diff --git a/assets/quicksettings.js b/assets/quicksettings.js
deleted file mode 100644
index 43cc3381..00000000
--- a/assets/quicksettings.js
+++ /dev/null
@@ -1 +0,0 @@
-!function(){function a(a,b){var d=c("div",null,"qs_label",b);return d.innerHTML=a,d}function b(a,b,d,e){var f=c("input",b,d,e);return f.type=a,f}function c(a,b,c,d){var e=document.createElement(a);if(e)return e.id=b,c&&(e.className=c),d&&d.appendChild(e),e}function d(){return navigator.userAgent.indexOf("rv:11")!=-1||navigator.userAgent.indexOf("MSIE")!=-1}function e(){var a=navigator.userAgent.toLowerCase();return!(a.indexOf("chrome")>-1||a.indexOf("firefox")>-1||a.indexOf("epiphany")>-1)&&a.indexOf("safari/")>-1}function f(){var a=navigator.userAgent.toLowerCase();return a.indexOf("edge")>-1}function g(){var a=document.createElement("style");a.innerText=i,document.head.appendChild(a),h=!0}var h=!1,i=".qs_main{background-color:#dddddd;text-align:left;position:absolute;width:200px;font:12px sans-serif;box-shadow:5px 5px 8px rgba(0,0,0,0.35);user-select:none;-webkit-user-select:none;color:#000000;border:none}.qs_content{background-color:#cccccc;overflow-y:auto}.qs_title_bar{background-color:#eeeeee;user-select:none;-webkit-user-select:none;cursor:pointer;padding:5px;font-weight:bold;border:none;color:#000000}.qs_container{margin:5px;padding:5px;background-color:#eeeeee;border:none;position:relative}.qs_container_selected{border:none;background-color:#ffffff}.qs_range{-webkit-appearance:none;-moz-appearance:none;width:100%;height:17px;padding:0;margin:0;background-color:transparent;border:none;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_range:focus{outline:none;border:none}.qs_range::-webkit-slider-runnable-track{width:100%;height:15px;cursor:pointer;background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-webkit-slider-runnable-track{background:#cccccc}.qs_range::-webkit-slider-thumb{-webkit-appearance:none;height:15px;width:15px;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer;margin-top:0}.qs_range::-moz-range-track{width:100%;height:15px;c
ursor:pointer;background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range::-moz-range-thumb{height:15px;width:15px;border:none;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer}.qs_range::-ms-track{width:100%;height:15px;cursor:pointer;visibility:hidden;background:transparent}.qs_range::-ms-thumb{height:15px;width:15px;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer;border:none}.qs_range::-ms-fill-lower{background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-ms-fill-lower{background:#cccccc}.qs_range::-ms-fill-upper{background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-ms-fill-upper{background:#cccccc}.qs_button{background-color:#f6f6f6;color:#000000;height:30px;border:1px solid #aaaaaa;font:12px sans-serif}.qs_button:active{background-color:#ffffff;border:1px solid #aaaaaa}.qs_button:focus{border:1px solid #aaaaaa;outline:none}.qs_checkbox{cursor:pointer}.qs_checkbox input{position:absolute;left:-99999px}.qs_checkbox span{height:16px;width:100%;display:block;text-indent:20px;background:url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAALklEQVQ4T2OcOXPmfwYKACPIgLS0NLKMmDVrFsOoAaNhMJoOGBioFwZkZUWoJgApdFaxjUM1YwAAAABJRU5ErkJggg==') no-repeat}.qs_checkbox input:checked+span{background:url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAvElEQVQ4T63Tyw2EIBAA0OFKBxBL40wDRovAUACcKc1IB1zZDAkG18GYZTmSmafzgTnnMgwchoDWGlJKheGcP3JtnPceCqCUAmttSZznuYtgchsXQrgC+77DNE0kUpPbmBOoJaBOIVQylnqWgAAeKhDve/AN+EaklJBzhhgjWRoJVGTbNjiOowAIret6a+4jYIwpX8aDwLIs74C2D0IIYIyVP6Gm898m9kbVm85ljHUTf16k4VUefkwDrxk+zoUEwCt0GbUAAAAASUVORK5CYII=') no-repeat}.qs_checkbox_label{position:absolute;top:7px;left:30px}.qs_label{margin-bottom:3px;user-select:none;-webkit-user-select:none;cursor:default;font:12px 
sans-serif}.qs_text_input{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;width:100%;padding:0 0 0 5px;height:24px;border:1px inset #ffffff;background-color:#ffffff;color:#000000;font-size:12px}.qs_text_input:focus{outline:none;background:#ffffff;border:1px inset #ffffff}.qs_select{background:url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAAp0lEQVRIS+2SsQ3FIAwF7RVYhA5mgQFhFuhYhJKWL0eKxI8SGylKZ0p4+OBsHGNM+HChAiS7qkgyBKrovaLeOxhjbgtxZ+cFtgelFMg5QwgBvPd/EO5sDbKAlBLUWo/8CjmL075zDmKMj6rEKbpCqBL9aqc4ZUQAhVbInBMQUXz5Vg/WfxOktXZsWWtZLds9uIqlqaH1NFV3jdhSJA47E1CAaE8ViYp+wGiWMZ/T+cgAAAAASUVORK5CYII=') no-repeat right #f6f6f6;-webkit-appearance:none;-moz-appearance:none;appearance:none;color:#000000;width:100%;height:24px;border:1px solid #aaaaaa;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;padding:0 5px;-moz-outline:none;font-size:14px}.qs_select option{font-size:14px}.qs_select::-ms-expand{display:none}.qs_select:focus{outline:none}.qs_number{height:24px}.qs_image{width:100%}.qs_progress{width:100%;height:15px;background-color:#cccccc;border:none;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_progress_value{height:100%;background-color:#999999}.qs_textarea{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;resize:vertical;width:100%;padding:3px 5px;border:1px inset #ffffff;background-color:#ffffff;color:#000000;font-size:12px}.qs_textarea:focus{outline:none;background:#ffffff;border:1px inset #ffffff}.qs_color{position:absolute;left:-999999px}.qs_color_label{width:100%;height:20px;display:block;border:1px solid #aaaaaa;cursor:pointer;padding:0 0 0 5px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_file_chooser{position:absolute;left:-999999px}.qs_file_chooser_label{background-color:#f6f6f6;color:#000000;height:30px;border:1px solid #aaaaaa;font:12px 
sans-serif;width:100%;display:block;cursor:pointer;padding:7px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;white-space:nowrap;overflow:hidden;text-overflow:ellipsis}",j={_version:"2.1",_topZ:1,_panel:null,_titleBar:null,_content:null,_startX:0,_startY:0,_hidden:!1,_collapsed:!1,_controls:null,_keyCode:-1,_draggable:!0,_collapsible:!0,_globalChangeHandler:null,useExtStyleSheet:function(){h=!0},create:function(a,b,c,d){var e=Object.create(this);return e._init(a,b,c,d),e},destroy:function(){this._panel.parentElement&&this._panel.parentElement.removeChild(this._panel);for(var a in this)this[a]=null},_init:function(a,b,c,d){h||g(),this._bindHandlers(),this._createPanel(a,b,d),this._createTitleBar(c||"QuickSettings"),this._createContent()},_bindHandlers:function(){this._startDrag=this._startDrag.bind(this),this._drag=this._drag.bind(this),this._endDrag=this._endDrag.bind(this),this._doubleClickTitle=this._doubleClickTitle.bind(this),this._onKeyUp=this._onKeyUp.bind(this)},getValuesAsJSON:function(a){var b={};for(var c in this._controls)this._controls[c].getValue&&(b[c]=this._controls[c].getValue());return a&&(b=JSON.stringify(b)),b},setValuesFromJSON:function(a){"string"==typeof a&&(a=JSON.parse(a));for(var b in a)this._controls[b].setValue&&this._controls[b].setValue(a[b]);return this},saveInLocalStorage:function(a){return this._localStorageName=a,this._readFromLocalStorage(a),this},clearLocalStorage:function(a){return localStorage.removeItem(a),this},_saveInLocalStorage:function(a){localStorage.setItem(a,this.getValuesAsJSON(!0))},_readFromLocalStorage:function(a){var 
b=localStorage.getItem(a);b&&this.setValuesFromJSON(b)},_createPanel:function(a,b,d){this._panel=c("div",null,"qs_main",d||document.body),this._panel.style.zIndex=++j._topZ,this.setPosition(a||0,b||0),this._controls={}},_createTitleBar:function(a){this._titleBar=c("div",null,"qs_title_bar",this._panel),this._titleBar.textContent=a,this._titleBar.addEventListener("mousedown",this._startDrag),this._titleBar.addEventListener("dblclick",this._doubleClickTitle)},_createContent:function(){this._content=c("div",null,"qs_content",this._panel)},_createContainer:function(){var a=c("div",null,"qs_container");return a.addEventListener("focus",function(){this.className+=" qs_container_selected"},!0),a.addEventListener("blur",function(){var a=this.className.indexOf(" qs_container_selected");a>-1&&(this.className=this.className.substr(0,a))},!0),this._content.appendChild(a),a},setPosition:function(a,b){return this._panel.style.left=a+"px",this._panel.style.top=Math.max(b,0)+"px",this},setSize:function(a,b){return this._panel.style.width=a+"px",this._content.style.width=a+"px",this._content.style.height=b-this._titleBar.offsetHeight+"px",this},setWidth:function(a){return this._panel.style.width=a+"px",this._content.style.width=a+"px",this},setHeight:function(a){return this._content.style.height=a-this._titleBar.offsetHeight+"px",this},setDraggable:function(a){return this._draggable=a,this._draggable||this._collapsible?this._titleBar.style.cursor="pointer":this._titleBar.style.cursor="default",this},_startDrag:function(a){this._draggable&&(this._panel.style.zIndex=++j._topZ,document.addEventListener("mousemove",this._drag),document.addEventListener("mouseup",this._endDrag),this._startX=a.clientX,this._startY=a.clientY),a.preventDefault()},_drag:function(a){var 
b=parseInt(this._panel.style.left),c=parseInt(this._panel.style.top),d=a.clientX,e=a.clientY;this.setPosition(b+d-this._startX,c+e-this._startY),this._startX=d,this._startY=e,a.preventDefault()},_endDrag:function(a){document.removeEventListener("mousemove",this._drag),document.removeEventListener("mouseup",this._endDrag),a.preventDefault()},setGlobalChangeHandler:function(a){return this._globalChangeHandler=a,this},_callGCH:function(){this._localStorageName&&this._saveInLocalStorage(this._localStorageName),this._globalChangeHandler&&this._globalChangeHandler()},hide:function(){return this._panel.style.visibility="hidden",this._hidden=!0,this},show:function(){return this._panel.style.visibility="visible",this._panel.style.zIndex=++j._topZ,this._hidden=!1,this},toggleVisibility:function(){return this._hidden?this.show():this.hide(),this},setCollapsible:function(a){return this._collapsible=a,this._draggable||this._collapsible?this._titleBar.style.cursor="pointer":this._titleBar.style.cursor="default",this},collapse:function(){return this._panel.removeChild(this._content),this._collapsed=!0,this},expand:function(){return this._panel.appendChild(this._content),this._collapsed=!1,this},toggleCollapsed:function(){return this._collapsed?this.expand():this.collapse(),this},setKey:function(a){return this._keyCode=a.toUpperCase().charCodeAt(0),document.body.addEventListener("keyup",this.onKeyUp),this},_onKeyUp:function(a){a.keyCode===this._keyCode&&this.toggleVisibility()},_doubleClickTitle:function(){this._collapsible&&this.toggleCollapsed()},removeControl:function(a){if(this._controls[a])var b=this._controls[a].container;return b&&b.parentElement&&b.parentElement.removeChild(b),this._controls[a]=null,this},enableControl:function(a){return this._controls[a]&&(this._controls[a].control.disabled=!1),this},disableControl:function(a){return this._controls[a]&&(this._controls[a].control.disabled=!0),this},hideControl:function(a){return 
this._controls[a]&&(this._controls[a].container.style.display="none"),this},showControl:function(a){return this._controls[a]&&(this._controls[a].container.style.display="block"),this},overrideStyle:function(a,b,c){return this._controls[a]&&(this._controls[a].control.style[b]=c),this},hideTitle:function(a){var b=this._controls[a].label;return b&&(b.style.display="none"),this},showTitle:function(a){var b=this._controls[a].label;return b&&(b.style.display="block"),this},hideAllTitles:function(){for(var a in this._controls){var b=this._controls[a].label;b&&(b.style.display="none")}return this},showAllTitles:function(){for(var a in this._controls){var b=this._controls[a].label;b&&(b.style.display="block")}return this},getValue:function(a){return this._controls[a].getValue()},setValue:function(a,b){return this._controls[a].setValue(b),this._callGCH(),this},addBoolean:function(a,d,e){var f=this._createContainer(),g=c("label",null,"qs_checkbox_label",f);g.textContent=a,g.setAttribute("for",a);var h=c("label",null,"qs_checkbox",f);h.setAttribute("for",a);var i=b("checkbox",a,null,h);i.checked=d;c("span",null,null,h);this._controls[a]={container:f,control:i,getValue:function(){return this.control.checked},setValue:function(a){this.control.checked=a,e&&e(a)}};var j=this;return i.addEventListener("change",function(){e&&e(i.checked),j._callGCH()}),this},bindBoolean:function(a,b,c){return this.addBoolean(a,b,function(b){c[a]=b})},addButton:function(a,c){var d=this._createContainer(),e=b("button",a,"qs_button",d);e.value=a,this._controls[a]={container:d,control:e};var f=this;return e.addEventListener("click",function(){c&&c(e),f._callGCH()}),this},addColor:function(g,h,i){if(e()||f()||d())return this.addText(g,h,i);var j=this._createContainer(),k=a(""+g+": "+h,j),l=b("color",g,"qs_color",j);l.value=h||"#ff0000";var 
m=c("label",null,"qs_color_label",j);m.setAttribute("for",g),m.style.backgroundColor=l.value,this._controls[g]={container:j,control:l,colorLabel:m,label:k,title:g,getValue:function(){return this.control.value},setValue:function(a){this.control.value=a,this.colorLabel.style.backgroundColor=l.value,this.label.innerHTML=""+this.title+": "+this.control.value,i&&i(a)}};var n=this;return l.addEventListener("input",function(){k.innerHTML=""+g+": "+l.value,m.style.backgroundColor=l.value,i&&i(l.value),n._callGCH()}),this},bindColor:function(a,b,c){return this.addColor(a,b,function(b){c[a]=b})},addDate:function(c,e,f){var g;if(e instanceof Date){var h=e.getFullYear(),i=e.getMonth()+1;i<10&&(i="0"+i);var j=e.getDate();g=h+"-"+i+"-"+j}else g=e;if(d())return this.addText(c,g,f);var k=this._createContainer(),l=a(""+c+" ",k),m=b("date",c,"qs_text_input",k);m.value=g||"",this._controls[c]={container:k,control:m,label:l,getValue:function(){return this.control.value},setValue:function(a){var b;if(a instanceof Date){var c=a.getFullYear(),d=a.getMonth()+1;d<10&&(d="0"+d);var e=a.getDate();e<10&&(e="0"+e),b=c+"-"+d+"-"+e}else b=a;this.control.value=b||"",f&&f(b)}};var n=this;return m.addEventListener("input",function(){f&&f(m.value),n._callGCH()}),this},bindDate:function(a,b,c){return this.addDate(a,b,function(b){c[a]=b})},addDropDown:function(b,d,e){for(var f=this._createContainer(),g=a(""+b+" ",f),h=c("select",null,"qs_select",f),i=0;i"+b+"",d);return d.appendChild(c),this._controls[b]={container:d,label:e},this},addFileChooser:function(d,e,f,g){var h=this._createContainer(),i=a(""+d+" ",h),j=b("file",d,"qs_file_chooser",h);f&&(j.accept=f);var k=c("label",null,"qs_file_chooser_label",h);k.setAttribute("for",d),k.textContent=e||"Choose a file...",this._controls[d]={container:h,control:j,label:i,getValue:function(){return this.control.files[0]}};var l=this;return 
j.addEventListener("change",function(){j.files&&j.files.length&&(k.textContent=j.files[0].name,g&&g(j.files[0]),l._callGCH())}),this},addHTML:function(b,d){var e=this._createContainer(),f=a(""+b+": ",e),g=c("div",null,null,e);return g.innerHTML=d,this._controls[b]={label:f,control:g,getValue:function(){return this.control.innerHTML},setValue:function(a){this.control.innerHTML=a}},this},addImage:function(b,d){var e=this._createContainer(),f=a(""+b+" ",e);return img=c("img",null,"qs_image",e),img.src=d,this._controls[b]={container:e,control:img,label:f,getValue:function(){return this.control.src},setValue:function(a){this.control.src=a}},this},addRange:function(a,b,c,d,e,f){return this._addNumber("range",a,b,c,d,e,f)},addNumber:function(a,b,c,d,e,f){return this._addNumber("number",a,b,c,d,e,f)},_addNumber:function(c,e,f,g,h,i,j){var k=this._createContainer(),l=a("",k),m="range"===c?"qs_range":"qs_text_input qs_number",n=b(c,e,m,k);n.min=f||0,n.max=g||100,n.step=i||1,n.value=h||0,l.innerHTML=""+e+": "+n.value,this._controls[e]={container:k,control:n,label:l,title:e,callback:j,getValue:function(){return parseFloat(this.control.value)},setValue:function(a){this.control.value=a,this.label.innerHTML=""+this.title+": "+this.control.value,j&&j(parseFloat(a))}};var o="input";"range"===c&&d()&&(o="change");var p=this;return n.addEventListener(o,function(){l.innerHTML=""+e+": "+n.value,j&&j(parseFloat(n.value)),p._callGCH()}),this},bindRange:function(a,b,c,d,e,f){return this.addRange(a,b,c,d,e,function(b){f[a]=b})},bindNumber:function(a,b,c,d,e,f){return this.addNumber(a,b,c,d,e,function(b){f[a]=b})},setRangeParameters:function(a,b,c,d){return this.setNumberParameters(a,b,c,d)},setNumberParameters:function(a,b,c,d){var e=this._controls[a],f=e.control.value;return e.control.min=b,e.control.max=c,e.control.step=d,e.control.value!==f&&e.callback&&e.callback(e.control.value),this},addPassword:function(a,b,c){return 
this._addText("password",a,b,c)},bindPassword:function(a,b,c){return this.addPassword(a,b,function(b){c[a]=b})},addProgressBar:function(b,d,e,f){var g=this._createContainer(),h=a("",g),i=c("div",null,"qs_progress",g),j=c("div",null,"qs_progress_value",i);return j.style.width=e/d*100+"%","numbers"===f?h.innerHTML=""+b+": "+e+" / "+d:"percent"===f?h.innerHTML=""+b+": "+Math.round(e/d*100)+"%":h.innerHTML=""+b+" ",this._controls[b]={container:g,control:i,valueDiv:j,valueDisplay:f,label:h,value:e,max:d,title:b,getValue:function(){return this.value},setValue:function(a){this.value=Math.max(0,Math.min(a,this.max)),this.valueDiv.style.width=this.value/this.max*100+"%","numbers"===this.valueDisplay?this.label.innerHTML=""+this.title+": "+this.value+" / "+this.max:"percent"===this.valueDisplay&&(this.label.innerHTML=""+this.title+": "+Math.round(this.value/this.max*100)+"%")}},this},setProgressMax:function(a,b){var c=this._controls[a];return c.max=b,c.value=Math.min(c.value,c.max),c.valueDiv.style.width=c.value/c.max*100+"%","numbers"===c.valueDisplay?c.label.innerHTML=""+c.title+": "+c.value+" / "+c.max:"percent"===c.valueDisplay?c.label.innerHTML=""+c.title+": "+Math.round(c.value/c.max*100)+"%":c.label.innerHTML=""+c.title+" ",this},addText:function(a,b,c){return this._addText("text",a,b,c)},_addText:function(d,e,f,g){var h,i=this._createContainer(),j=a(""+e+" ",i);"textarea"===d?(h=c("textarea",e,"qs_textarea",i),h.rows=5):h=b(d,e,"qs_text_input",i),h.value=f||"",this._controls[e]={container:i,control:h,label:j,getValue:function(){return this.control.value},setValue:function(a){this.control.value=a,g&&g(a)}};var k=this;return h.addEventListener("input",function(){g&&g(h.value),k._callGCH()}),this},bindText:function(a,b,c){return this.addText(a,b,function(b){c[a]=b})},addTextArea:function(a,b,c){return this._addText("textarea",a,b,c)},setTextAreaRows:function(a,b){return this._controls[a].control.rows=b,this},bindTextArea:function(a,b,c){return 
this.addTextArea(a,b,function(b){c[a]=b})},addTime:function(c,e,f){var g;if(e instanceof Date){var h=e.getHours();h<10&&(h="0"+h);var i=e.getMinutes();i<10&&(i="0"+i);var j=e.getSeconds();j<10&&(j="0"+j),g=h+":"+i+":"+j}else g=e;if(d())return this.addText(c,g,f);var k=this._createContainer(),l=a(""+c+" ",k),m=b("time",c,"qs_text_input",k);m.value=g||"",this._controls[c]={container:k,control:m,label:l,getValue:function(){return this.control.value},setValue:function(a){var b;if(a instanceof Date){var c=a.getHours();c<10&&(c="0"+c);var d=a.getMinutes();d<10&&(d="0"+d);var e=a.getSeconds();e<10&&(e="0"+e),b=c+":"+d+":"+e}else b=a;this.control.value=b||"",f&&f(b)}};var n=this;return m.addEventListener("input",function(){f&&f(m.value),n._callGCH()}),this},bindTime:function(a,b,c){return this.addTime(a,b,function(b){c[a]=b})}};"object"==typeof exports&&"object"==typeof module?module.exports=j:"function"==typeof define&&define.amd?define(j):window.QuickSettings=j}();
\ No newline at end of file
diff --git a/assets/tf.esnext.js b/assets/tf.esnext.js
deleted file mode 100644
index 1226e06f..00000000
--- a/assets/tf.esnext.js
+++ /dev/null
@@ -1,82571 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-(function (global, factory) {
- typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
- typeof define === 'function' && define.amd ? define(['exports'], factory) :
- (global = global || self, factory(global.tf = global.tf || {}));
-}(this, (function (exports) { 'use strict';
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const EPSILON_FLOAT32 = 1e-7;
- const EPSILON_FLOAT16 = 1e-4;
- /** Convenient class for storing tensor-related data. */
- class DataStorage {
- constructor(backend, dataMover) {
- this.backend = backend;
- this.dataMover = dataMover;
- this.data = new WeakMap();
- this.dataIdsCount = 0;
- }
- get(dataId) {
- if (!this.data.has(dataId)) {
- this.dataMover.moveData(this.backend, dataId);
- }
- return this.data.get(dataId);
- }
- set(dataId, value) {
- this.dataIdsCount++;
- this.data.set(dataId, value);
- }
- has(dataId) {
- return this.data.has(dataId);
- }
- delete(dataId) {
- this.dataIdsCount--;
- return this.data.delete(dataId);
- }
- numDataIds() {
- return this.dataIdsCount;
- }
- }
- /**
- * The interface that defines the kernels that should be implemented when
- * adding a new backend. New backends don't need to implement every one of the
- * methods, this can be done gradually (throw an error for unimplemented
- * methods).
- */
- class KernelBackend {
- time(f) {
- return notYetImplemented('time');
- }
- read(dataId) {
- return notYetImplemented('read');
- }
- readSync(dataId) {
- return notYetImplemented('readSync');
- }
- numDataIds() {
- return notYetImplemented('numDataIds');
- }
- disposeData(dataId) {
- return notYetImplemented('disposeData');
- }
- write(values, shape, dtype) {
- return notYetImplemented('write');
- }
- move(dataId, values, shape, dtype) {
- return notYetImplemented('move');
- }
- memory() {
- return notYetImplemented('memory');
- }
- /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
- floatPrecision() {
- return notYetImplemented('floatPrecision');
- }
- /** Returns the smallest representable number. */
- epsilon() {
- return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
- }
- batchMatMul(a, b, transposeA, transposeB) {
- return notYetImplemented('batchMatMul');
- }
- fusedBatchMatMul({ a, b, transposeA, transposeB, bias, activation, preluActivationWeights }) {
- return notYetImplemented('fusedBatchMatMul');
- }
- slice(x, begin, size) {
- return notYetImplemented('slice');
- }
- stridedSlice(x, begin, end, strides) {
- return notYetImplemented('stridedSlice');
- }
- unstack(x, axis) {
- return notYetImplemented('unstack');
- }
- reverse(a, axis) {
- return notYetImplemented('reverse');
- }
- concat(tensors, axis) {
- return notYetImplemented('concat');
- }
- neg(a) {
- return notYetImplemented('neg');
- }
- add(a, b) {
- return notYetImplemented('add');
- }
- addN(tensors) {
- return notYetImplemented('addN');
- }
- subtract(a, b) {
- return notYetImplemented('subtract');
- }
- multiply(a, b) {
- return notYetImplemented('multiply');
- }
- realDivide(a, b) {
- return notYetImplemented('realDivide');
- }
- floorDiv(a, b) {
- return notYetImplemented('floorDiv');
- }
- sum(x, axes) {
- return notYetImplemented('sum');
- }
- prod(x, axes) {
- return notYetImplemented('prod');
- }
- unsortedSegmentSum(x, segmentIds, numSegments) {
- return notYetImplemented('unsortedSegmentSum');
- }
- argMin(x, axis) {
- return notYetImplemented('argMin');
- }
- argMax(x, axis) {
- return notYetImplemented('argMax');
- }
- equal(a, b) {
- return notYetImplemented('equal');
- }
- notEqual(a, b) {
- return notYetImplemented('notEqual');
- }
- less(a, b) {
- return notYetImplemented('less');
- }
- lessEqual(a, b) {
- return notYetImplemented('lessEqual');
- }
- greater(a, b) {
- return notYetImplemented('greater');
- }
- greaterEqual(a, b) {
- return notYetImplemented('greaterEqual');
- }
- logicalNot(a) {
- return notYetImplemented('logicalNot');
- }
- logicalAnd(a, b) {
- return notYetImplemented('logicalAnd');
- }
- logicalOr(a, b) {
- return notYetImplemented('logicalOr');
- }
- where(condition) {
- return notYetImplemented('where');
- }
- select(condition, a, b) {
- return notYetImplemented('select');
- }
- topk(x, k, sorted) {
- return notYetImplemented('topk');
- }
- min(x, axes) {
- return notYetImplemented('min');
- }
- minimum(a, b) {
- return notYetImplemented('minimum');
- }
- mod(a, b) {
- return notYetImplemented('mod');
- }
- max(x, axes) {
- return notYetImplemented('max');
- }
- maximum(a, b) {
- return notYetImplemented('maximum');
- }
- all(x, axes) {
- return notYetImplemented('all');
- }
- any(x, axes) {
- return notYetImplemented('any');
- }
- squaredDifference(a, b) {
- return notYetImplemented('squaredDifference');
- }
- ceil(x) {
- return notYetImplemented('ceil');
- }
- floor(x) {
- return notYetImplemented('floor');
- }
- round(x) {
- return notYetImplemented('round');
- }
- sign(x) {
- return notYetImplemented('sign');
- }
- isNaN(x) {
- return notYetImplemented('isNaN');
- }
- isInf(x) {
- return notYetImplemented('isInf');
- }
- isFinite(x) {
- return notYetImplemented('isFinite');
- }
- pow(a, b) {
- return notYetImplemented('pow');
- }
- exp(x) {
- return notYetImplemented('exp');
- }
- expm1(x) {
- return notYetImplemented('expm1');
- }
- softmax(x, dim) {
- return notYetImplemented('softmax');
- }
- log(x) {
- return notYetImplemented('log');
- }
- log1p(x) {
- return notYetImplemented('log1p');
- }
- sqrt(x) {
- return notYetImplemented('sqrt');
- }
- rsqrt(x) {
- return notYetImplemented('rsqrt');
- }
- square(x) {
- return notYetImplemented('square');
- }
- reciprocal(x) {
- return notYetImplemented('reciprocal');
- }
- relu(x) {
- return notYetImplemented('relu');
- }
- relu6(x) {
- return notYetImplemented('relu6');
- }
- prelu(x, a) {
- return notYetImplemented('prelu');
- }
- elu(x) {
- return notYetImplemented('elu');
- }
- eluDer(dy, y) {
- return notYetImplemented('eluDer');
- }
- selu(x) {
- return notYetImplemented('selu');
- }
- int(x) {
- return notYetImplemented('int');
- }
- clip(x, min, max) {
- return notYetImplemented('clip');
- }
- abs(x) {
- return notYetImplemented('abs');
- }
- complexAbs(x) {
- return notYetImplemented('complexAbs');
- }
- sigmoid(x) {
- return notYetImplemented('sigmoid');
- }
- softplus(x) {
- return notYetImplemented('softplus');
- }
- sin(x) {
- return notYetImplemented('sin');
- }
- cos(x) {
- return notYetImplemented('cos');
- }
- tan(x) {
- return notYetImplemented('tan');
- }
- asin(x) {
- return notYetImplemented('asin');
- }
- acos(x) {
- return notYetImplemented('acos');
- }
- atan(x) {
- return notYetImplemented('atan');
- }
- atan2(a, b) {
- return notYetImplemented('atan2');
- }
- sinh(x) {
- return notYetImplemented('sinh');
- }
- cosh(x) {
- return notYetImplemented('cosh');
- }
- tanh(x) {
- return notYetImplemented('tanh');
- }
- asinh(x) {
- return notYetImplemented('asinh');
- }
- acosh(x) {
- return notYetImplemented('acosh');
- }
- atanh(x) {
- return notYetImplemented('atanh');
- }
- erf(x) {
- return notYetImplemented('erf');
- }
- step(x, alpha) {
- return notYetImplemented('step');
- }
- fusedConv2d({ input, filter, convInfo, bias, activation, preluActivationWeights }) {
- return notYetImplemented('fusedConv2d');
- }
- conv2d(x, filter, convInfo) {
- return notYetImplemented('conv2d');
- }
- conv2dDerInput(dy, filter, convInfo) {
- return notYetImplemented('conv2dDerInput');
- }
- conv2dDerFilter(x, dY, convInfo) {
- return notYetImplemented('conv2dDerFilter');
- }
- fusedDepthwiseConv2D({ input, filter, convInfo, bias, activation, preluActivationWeights }) {
- return notYetImplemented('fusedDepthwiseConv2D');
- }
- depthwiseConv2D(input, filter, convInfo) {
- return notYetImplemented('depthwiseConv2D');
- }
- depthwiseConv2DDerInput(dy, filter, convInfo) {
- return notYetImplemented('depthwiseConv2DDerInput');
- }
- depthwiseConv2DDerFilter(x, dY, convInfo) {
- return notYetImplemented('depthwiseConv2DDerFilter');
- }
- conv3d(x, filter, convInfo) {
- return notYetImplemented('conv3d');
- }
- conv3dDerInput(dy, filter, convInfo) {
- return notYetImplemented('conv3dDerInput');
- }
- conv3dDerFilter(x, dY, convInfo) {
- return notYetImplemented('conv3dDerFilter');
- }
- maxPool(x, convInfo) {
- return notYetImplemented('maxPool');
- }
- maxPoolBackprop(dy, x, y, convInfo) {
- return notYetImplemented('maxPoolBackprop');
- }
- avgPool(x, convInfo) {
- return notYetImplemented('avgPool');
- }
- avgPoolBackprop(dy, x, convInfo) {
- return notYetImplemented('avgPoolBackprop');
- }
- avgPool3d(x, convInfo) {
- return notYetImplemented('avgPool3d');
- }
- avgPool3dBackprop(dy, x, convInfo) {
- return notYetImplemented('avgPool3dBackprop');
- }
- maxPool3d(x, convInfo) {
- return notYetImplemented('maxPool3d');
- }
- maxPool3dBackprop(dy, x, y, convInfo) {
- return notYetImplemented('maxPool3dBackprop');
- }
- reshape(x, shape) {
- return notYetImplemented('reshape');
- }
- cast(x, dtype) {
- return notYetImplemented('cast');
- }
- tile(x, reps) {
- return notYetImplemented('tile');
- }
- pad(x, paddings, constantValue) {
- return notYetImplemented('pad');
- }
- transpose(x, perm) {
- return notYetImplemented('transpose');
- }
- gather(x, indices, axis) {
- return notYetImplemented('gather');
- }
- gatherND(x, indices) {
- return notYetImplemented('gatherND');
- }
- scatterND(indices, updates, shape) {
- return notYetImplemented('scatterND');
- }
- batchToSpaceND(x, blockShape, crops) {
- return notYetImplemented('batchToSpaceND');
- }
- spaceToBatchND(x, blockShape, paddings) {
- return notYetImplemented('spaceToBatchND');
- }
- resizeBilinear(x, newHeight, newWidth, alignCorners) {
- return notYetImplemented('resizeBilinear');
- }
- resizeBilinearBackprop(dy, x, alignCorners) {
- return notYetImplemented('resizeBilinearBackprop');
- }
- resizeNearestNeighbor(x, newHEight, newWidth, alignCorners) {
- return notYetImplemented('resizeNearestNeighbor');
- }
- resizeNearestNeighborBackprop(dy, x, alignCorners) {
- return notYetImplemented('resizeNearestNeighborBackprop');
- }
- batchNorm(x, mean, variance, offset, scale, varianceEpsilon) {
- return notYetImplemented('batchNorm');
- }
- localResponseNormalization4D(x, radius, bias, alpha, beta) {
- return notYetImplemented('localResponseNormalization4D');
- }
- LRNGrad(dy, inputImage, outputImage, radius, bias, alpha, beta) {
- return notYetImplemented('LRNGrad');
- }
- multinomial(logits, normalized, numSamples, seed) {
- return notYetImplemented('multinomial');
- }
- oneHot(indices, depth, onValue, offValue) {
- return notYetImplemented('oneHot');
- }
- cumsum(x, axis, exclusive, reverse) {
- return notYetImplemented('cumsum');
- }
- nonMaxSuppression(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
- return notYetImplemented('nonMaxSuppression');
- }
- fft(x) {
- return notYetImplemented('fft');
- }
- ifft(x) {
- return notYetImplemented('ifft');
- }
- complex(real, imag) {
- return notYetImplemented('complex');
- }
- real(input) {
- return notYetImplemented('real');
- }
- imag(input) {
- return notYetImplemented('imag');
- }
- cropAndResize(image, boxes, boxIndex, cropSize, method, extrapolationValue) {
- return notYetImplemented('cropAndResize');
- }
- depthToSpace(x, blockSize, dataFormat) {
- return notYetImplemented('depthToSpace');
- }
- // Aligns with the "SplitV" kernel in TensorFlow.
- split(value, sizeSplits, axis) {
- return notYetImplemented('split');
- }
- sparseToDense(sparseIndices, sparseValues, outputShape, defaultValue) {
- return notYetImplemented('sparseToDense');
- }
- diag(x) {
- return notYetImplemented('diag');
- }
- fill(shape, value, dtype) {
- return notYetImplemented('fill');
- }
- onesLike(x) {
- return notYetImplemented('onesLike');
- }
- zerosLike(x) {
- return notYetImplemented('zerosLike');
- }
- linspace(start, stop, num) {
- return notYetImplemented('linspace');
- }
- dispose() {
- return notYetImplemented('dispose');
- }
- }
- function notYetImplemented(kernelName) {
- throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` +
- `This kernel may not be supported by the tfjs backend you have chosen`);
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
- const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
- /**
- * The environment contains evaluated flags as well as the registered platform.
- * This is always used as a global singleton and can be retrieved with
- * `tf.env()`.
- *
- * @doc {heading: 'Environment'}
- */
- class Environment {
- // tslint:disable-next-line: no-any
- constructor(global) {
- this.global = global;
- this.flags = {};
- this.flagRegistry = {};
- this.urlFlags = {};
- this.populateURLFlags();
- }
- setPlatform(platformName, platform) {
- if (this.platform != null) {
- console.warn(`Platform ${this.platformName} has already been set. ` +
- `Overwriting the platform with ${platform}.`);
- }
- this.platformName = platformName;
- this.platform = platform;
- }
- registerFlag(flagName, evaluationFn, setHook) {
- this.flagRegistry[flagName] = { evaluationFn, setHook };
- // Override the flag value from the URL. This has to happen here because the
- // environment is initialized before flags get registered.
- if (this.urlFlags[flagName] != null) {
- const flagValue = this.urlFlags[flagName];
- console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
- this.set(flagName, flagValue);
- }
- }
- async getAsync(flagName) {
- if (flagName in this.flags) {
- return this.flags[flagName];
- }
- this.flags[flagName] = await this.evaluateFlag(flagName);
- return this.flags[flagName];
- }
- get(flagName) {
- if (flagName in this.flags) {
- return this.flags[flagName];
- }
- const flagValue = this.evaluateFlag(flagName);
- if (flagValue instanceof Promise) {
- throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
- `Please use getAsync() instead.`);
- }
- this.flags[flagName] = flagValue;
- return this.flags[flagName];
- }
- getNumber(flagName) {
- return this.get(flagName);
- }
- getBool(flagName) {
- return this.get(flagName);
- }
- getFlags() {
- return this.flags;
- }
- // For backwards compatibility.
- get features() {
- return this.flags;
- }
- set(flagName, value) {
- if (this.flagRegistry[flagName] == null) {
- throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
- }
- this.flags[flagName] = value;
- if (this.flagRegistry[flagName].setHook != null) {
- this.flagRegistry[flagName].setHook(value);
- }
- }
- evaluateFlag(flagName) {
- if (this.flagRegistry[flagName] == null) {
- throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
- }
- return this.flagRegistry[flagName].evaluationFn();
- }
- setFlags(flags) {
- this.flags = Object.assign({}, flags);
- }
- reset() {
- this.flags = {};
- this.urlFlags = {};
- this.populateURLFlags();
- }
- populateURLFlags() {
- if (typeof this.global === 'undefined' ||
- typeof this.global.location === 'undefined' ||
- typeof this.global.location.search === 'undefined') {
- return;
- }
- const urlParams = getQueryParams(this.global.location.search);
- if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
- const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
- keyValues.forEach(keyValue => {
- const [key, value] = keyValue.split(':');
- this.urlFlags[key] = parseValue(key, value);
- });
- }
- }
- }
- function getQueryParams(queryString) {
- const params = {};
- queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
- decodeParam(params, t[0], t[1]);
- return t.join('=');
- });
- return params;
- }
- function decodeParam(params, name, value) {
- params[decodeURIComponent(name)] = decodeURIComponent(value || '');
- }
- function parseValue(flagName, value) {
- value = value.toLowerCase();
- if (value === 'true' || value === 'false') {
- return value === 'true';
- }
- else if (`${+value}` === value) {
- return +value;
- }
- throw new Error(`Could not parse value flag value ${value} for flag ${flagName}.`);
- }
- /**
- * Returns the current environment (a global singleton).
- *
- * The environment object contains the evaluated feature values as well as the
- * active platform.
- *
- * @doc {heading: 'Environment'}
- */
- function env() {
- return exports.ENV;
- }
- exports.ENV = null;
- function setEnvironmentGlobal(environment) {
- exports.ENV = environment;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Note that the identifier globalNameSpace is scoped to this module, but will
- // always resolve to the same global object regardless of how the module is
- // resolved.
- // tslint:disable-next-line:no-any
- let globalNameSpace;
- // tslint:disable-next-line:no-any
- function getGlobalNamespace() {
- if (globalNameSpace == null) {
- // tslint:disable-next-line:no-any
- let ns;
- if (typeof (window) !== 'undefined') {
- ns = window;
- }
- else if (typeof (global) !== 'undefined') {
- ns = global;
- }
- else if (typeof (process) !== 'undefined') {
- ns = process;
- }
- else if (typeof (self) !== 'undefined') {
- ns = self;
- }
- else {
- throw new Error('Could not find a global object');
- }
- globalNameSpace = ns;
- }
- return globalNameSpace;
- }
- // tslint:disable-next-line:no-any
- function getGlobalMap() {
- const ns = getGlobalNamespace();
- if (ns._tfGlobals == null) {
- ns._tfGlobals = new Map();
- }
- return ns._tfGlobals;
- }
- /**
- * Returns a globally accessible 'singleton' object.
- *
- * @param key the name of the object
- * @param init a function to initialize to initialize this object
- * the first time it is fetched.
- */
- function getGlobal(key, init) {
- const globalMap = getGlobalMap();
- if (globalMap.has(key)) {
- return globalMap.get(key);
- }
- else {
- const singleton = init();
- globalMap.set(key, singleton);
- return globalMap.get(key);
- }
- }
-
- const Abs = 'Abs';
- const Acos = 'Acos';
- const Acosh = 'Acosh';
- const Add = 'Add';
- const AddN = 'AddN';
- const All = 'All';
- const Any = 'Any';
- const ArgMax = 'ArgMax';
- const ArgMin = 'ArgMin';
- const Asin = 'Asin';
- const Asinh = 'Asinh';
- const Atan = 'Atan';
- const Atanh = 'Atanh';
- const Atan2 = 'Atan2';
- const AvgPool = 'AvgPool';
- const AvgPoolBackprop = 'AvgPoolBackprop';
- const AvgPool3D = 'AvgPool3D';
- const AvgPool3DBackprop = 'AvgPool3DBackprop';
- const BatchMatMul = 'BatchMatMul';
- const BatchToSpaceND = 'BatchToSpaceND';
- const BroadcastTo = 'BroadcastTo';
- const Cast = 'Cast';
- const Ceil = 'Ceil';
- const ClipByValue = 'ClipByValue';
- const Complex = 'Complex';
- const Concat = 'Concat';
- const Conv2D = 'Conv2D';
- const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
- const Conv2DBackpropInput = 'Conv2DBackpropInput';
- const Conv3D = 'Conv3D';
- const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
- const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
- const Cos = 'Cos';
- const Cosh = 'Cosh';
- const Cumsum = 'Cumsum';
- const CropAndResize = 'CropAndResize';
- const DepthToSpace = 'DepthToSpace';
- const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
- const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
- const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
- const Diag = 'Diag';
- const Dilation2D = 'Dilation2D';
- const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
- const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
- const Div = 'Div';
- const Elu = 'Elu';
- const EluGrad = 'EluGrad';
- const Erf = 'Erf';
- const Equal = 'Equal';
- const Exp = 'Exp';
- const Expm1 = 'Expm1';
- const FFT = 'FFT';
- const Fill = 'Fill';
- const FlipLeftRight = 'FlipLeftRight';
- const Floor = 'Floor';
- const FloorDiv = 'FloorDiv';
- const FusedBatchNorm = 'FusedBatchNorm';
- const GatherV2 = 'GatherV2';
- const GatherNd = 'GatherNd';
- const Greater = 'Greater';
- const GreaterEqual = 'GreaterEqual';
- const Identity = 'Identity';
- const IFFT = 'IFFT';
- const Imag = 'Imag';
- const IsFinite = 'IsFinite';
- const IsInf = 'IsInf';
- const IsNan = 'IsNan';
- const Less = 'Less';
- const LessEqual = 'LessEqual';
- const LinSpace = 'LinSpace';
- const Log = 'Log';
- const Log1p = 'Log1p';
- const LogicalAnd = 'LogicalAnd';
- const LogicalNot = 'LogicalNot';
- const LogicalOr = 'LogicalOr';
- const LogSoftmax = 'LogSoftmax';
- const LRN = 'LRN';
- const LRNBackprop = 'LRNBackprop';
- const Max = 'Max';
- const Maximum = 'Maximum';
- const MaxPool = 'MaxPool';
- const MaxPoolBackprop = 'MaxPoolBackprop';
- const MaxPool3D = 'MaxPool3D';
- const MaxPool3DBackprop = 'MaxPool3DBackprop';
- const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
- const Mean = 'Mean';
- const Min = 'Min';
- const Minimum = 'Minimum';
- const MirrorPad = 'MirrorPad';
- const Mod = 'Mod';
- const Multiply = 'Multiply';
- const Negate = 'Negate';
- const NotEqual = 'NotEqual';
- const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
- const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
- const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
- const OnesLike = 'OnesLike';
- const OneHot = 'OneHot';
- const PadV2 = 'PadV2';
- const Pool = 'Pool';
- const Pow = 'Pow';
- const Prelu = 'Prelu';
- const Prod = 'Prod';
- const Range = 'Range';
- const Real = 'Real';
- const Reciprocal = 'Reciprocal';
- const Relu = 'Relu';
- const Reshape = 'Reshape';
- const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
- const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
- const ResizeBilinear = 'ResizeBilinear';
- const ResizeBilinearGrad = 'ResizeBilinearGrad';
- const Relu6 = 'Relu6';
- const Reverse = 'Reverse';
- const Round = 'Round';
- const Rsqrt = 'Rsqrt';
- const ScatterNd = 'ScatterNd';
- const SelectV2 = 'SelectV2';
- const Selu = 'Selu';
- const Slice = 'Slice';
- const Sin = 'Sin';
- const Sinh = 'Sinh';
- const Sign = 'Sign';
- const Sigmoid = 'Sigmoid';
- const Softplus = 'Softplus';
- const Sqrt = 'Sqrt';
- const Sum = 'Sum';
- const SpaceToBatchND = 'SpaceToBatchND';
- const SplitV = 'SplitV';
- const Softmax = 'Softmax';
- const SquaredDifference = 'SquaredDifference';
- const Square = 'Square';
- const Sub = 'Sub';
- const SparseToDense = 'SparseToDense';
- const StridedSlice = 'StridedSlice';
- const Tan = 'Tan';
- const Tanh = 'Tanh';
- const Tile = 'Tile';
- const TopK = 'TopK';
- const Transpose = 'Transpose';
- const Unique = 'Unique';
- const Unpack = 'Unpack';
- const UnsortedSegmentSum = 'UnsortedSegmentSum';
- const ZerosLike = 'ZerosLike';
- /**
- * TensorFlow.js-only kernels
- */
- const Step = 'Step';
- const FromPixels = 'FromPixels';
- const RotateWithOffset = 'RotateWithOffset';
- const _FusedMatMul = '_FusedMatMul';
- const FusedConv2D = 'FusedConv2D';
- const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
- const gradRegistry = getGlobal('gradRegistry', () => new Map());
- /**
- * Returns the kernel function (code) associated with the provided names.
- *
- * @param kernelName The official name of the kernel.
- * @param backendName The official name of the backend.
- */
- function getKernel(kernelName, backendName) {
- const key = makeKey(kernelName, backendName);
- return kernelRegistry.get(key);
- }
- /**
- * Returns the registered gradient info associated with the provided kernel.
- * @param kernelName The official TF kernel name.
- */
- function getGradient(kernelName) {
- return gradRegistry.get(kernelName);
- }
- function getKernelsForBackend(backendName) {
- const it = kernelRegistry.entries();
- const result = [];
- while (true) {
- const { done, value } = it.next();
- if (done) {
- break;
- }
- const [key, config] = value;
- const [backend,] = key.split('_');
- if (backend === backendName) {
- result.push(config);
- }
- }
- return result;
- }
- /**
- * Registers the function (forward pass) for the kernel in a global registry.
- *
- * @param config A config object with the following properties:
- * - `kernelName` The official name of the kernel.
- * - `backendName` The official name of the backend.
- * - `kernelFunc` The function to run during the forward pass of the kernel.
- * - `setupFunc` Optional. Gets called once, after the backend initializes.
- * - `disposeFunc` Optional. Gets called once, right before the backend is
- * disposed.
- */
- function registerKernel(config) {
- const { kernelName, backendName } = config;
- const key = makeKey(kernelName, backendName);
- if (kernelRegistry.has(key)) {
- console.warn(`The kernel '${kernelName}' for backend ` +
- `'${backendName}' is already registered`);
- }
- kernelRegistry.set(key, config);
- }
- /**
- * Registers a gradient function for a given kernel in the global registry,
- * to be used during the back-propagation of that kernel.
- *
- * @param config An object with the following properties:
- * - `kernelName` The name of the kernel that the gradient function is for.
- * - `gradFunc` The function to run during back-propagation.
- */
- function registerGradient(config) {
- const { kernelName } = config;
- if (gradRegistry.has(kernelName)) {
- // TODO (yassogba) after 3.0 assess whether we need to keep this gated
- // to debug mode.
- if (env().getBool('DEBUG')) {
- console.warn(`Overriding the gradient for '${kernelName}'`);
- }
- }
- gradRegistry.set(kernelName, config);
- }
- /**
- * Removes the kernel function from the registry.
- *
- * @param kernelName The official name of the kernel.
- * @param backendName The official name of the backend.
- *
- */
- function unregisterKernel(kernelName, backendName) {
- const key = makeKey(kernelName, backendName);
- if (!kernelRegistry.has(key)) {
- throw new Error(`The kernel '${kernelName}' for backend ` +
- `'${backendName}' is not registered`);
- }
- kernelRegistry.delete(key);
- }
- /** Removes the registered gradient from the global registry. */
- function unregisterGradient(kernelName) {
- if (!gradRegistry.has(kernelName)) {
- throw new Error(`The gradient '${kernelName}' for backend is not registered`);
- }
- gradRegistry.delete(kernelName);
- }
- /**
- * Finds kernels that have already been registered to a backend and re-registers
- * them for a new backend. Useful for registering custom backends.
- * @param registeredBackendName Already registered backend.
- * @param newBackendName New backend.
- */
- function copyRegisteredKernels(registeredBackendName, newBackendName) {
- const kernels = getKernelsForBackend(registeredBackendName);
- kernels.forEach(kernelConfig => {
- const newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName });
- registerKernel(newKernelConfig);
- });
- }
- function makeKey(kernelName, backendName) {
- return `${backendName}_${kernelName}`;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Shuffles the array in-place using Fisher-Yates algorithm.
- *
- * ```js
- * const a = [1, 2, 3, 4, 5];
- * tf.util.shuffle(a);
- * console.log(a);
- * ```
- *
- * @param array The array to shuffle in-place.
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- // tslint:disable-next-line:no-any
- function shuffle(array) {
- let counter = array.length;
- let temp = 0;
- let index = 0;
- // While there are elements in the array
- while (counter > 0) {
- // Pick a random index
- index = (Math.random() * counter) | 0;
- // Decrease counter by 1
- counter--;
- // And swap the last element with it
- temp = array[counter];
- array[counter] = array[index];
- array[index] = temp;
- }
- }
- /** Clamps a value to a specified range. */
- function clamp(min, x, max) {
- return Math.max(min, Math.min(x, max));
- }
- function nearestLargerEven(val) {
- return val % 2 === 0 ? val : val + 1;
- }
- function sum(arr) {
- let sum = 0;
- for (let i = 0; i < arr.length; i++) {
- sum += arr[i];
- }
- return sum;
- }
- /**
- * Returns a sample from a uniform [a, b) distribution.
- *
- * @param a The minimum support (inclusive).
- * @param b The maximum support (exclusive).
- * @return A pseudorandom number on the half-open interval [a,b).
- */
- function randUniform(a, b) {
- const r = Math.random();
- return (b * r) + (1 - r) * a;
- }
- /** Returns the squared Euclidean distance between two vectors. */
- function distSquared(a, b) {
- let result = 0;
- for (let i = 0; i < a.length; i++) {
- const diff = Number(a[i]) - Number(b[i]);
- result += diff * diff;
- }
- return result;
- }
- /**
- * Asserts that the expression is true. Otherwise throws an error with the
- * provided message.
- *
- * ```js
- * const x = 2;
- * tf.util.assert(x === 2, 'x is not 2');
- * ```
- *
- * @param expr The expression to assert (as a boolean).
- * @param msg A function that returns the message to report when throwing an
- * error. We use a function for performance reasons.
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- function assert(expr, msg) {
- if (!expr) {
- throw new Error(typeof msg === 'string' ? msg : msg());
- }
- }
- function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
- assert(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
- }
- function assertNonNull(a) {
- assert(a != null, () => `The input to the tensor constructor must be a non-null value.`);
- }
- // NOTE: We explicitly type out what T extends instead of any so that
- // util.flatten on a nested array of number doesn't try to infer T as a
- // number[][], causing us to explicitly type util.flatten().
- /**
- * Flattens an arbitrarily nested array.
- *
- * ```js
- * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
- * const flat = tf.util.flatten(a);
- * console.log(flat);
- * ```
- *
- * @param arr The nested array to flatten.
- * @param result The destination array which holds the elements.
- * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
- * to false.
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- function flatten(arr, result = [], skipTypedArray = false) {
- if (result == null) {
- result = [];
- }
- if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) {
- for (let i = 0; i < arr.length; ++i) {
- flatten(arr[i], result, skipTypedArray);
- }
- }
- else {
- result.push(arr);
- }
- return result;
- }
- /**
- * Returns the size (number of elements) of the tensor given its shape.
- *
- * ```js
- * const shape = [3, 4, 2];
- * const size = tf.util.sizeFromShape(shape);
- * console.log(size);
- * ```
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- function sizeFromShape(shape) {
- if (shape.length === 0) {
- // Scalar.
- return 1;
- }
- let size = shape[0];
- for (let i = 1; i < shape.length; i++) {
- size *= shape[i];
- }
- return size;
- }
- function isScalarShape(shape) {
- return shape.length === 0;
- }
- function arraysEqual(n1, n2) {
- if (n1 === n2) {
- return true;
- }
- if (n1 == null || n2 == null) {
- return false;
- }
- if (n1.length !== n2.length) {
- return false;
- }
- for (let i = 0; i < n1.length; i++) {
- if (n1[i] !== n2[i]) {
- return false;
- }
- }
- return true;
- }
- function isInt(a) {
- return a % 1 === 0;
- }
- function tanh(x) {
- // tslint:disable-next-line:no-any
- if (Math.tanh != null) {
- // tslint:disable-next-line:no-any
- return Math.tanh(x);
- }
- if (x === Infinity) {
- return 1;
- }
- else if (x === -Infinity) {
- return -1;
- }
- else {
- const e2x = Math.exp(2 * x);
- return (e2x - 1) / (e2x + 1);
- }
- }
- function sizeToSquarishShape(size) {
- const width = Math.ceil(Math.sqrt(size));
- return [width, Math.ceil(size / width)];
- }
- /**
- * Creates a new array with randomized indicies to a given quantity.
- *
- * ```js
- * const randomTen = tf.util.createShuffledIndices(10);
- * console.log(randomTen);
- * ```
- *
- * @param number Quantity of how many shuffled indicies to create.
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- function createShuffledIndices(n) {
- const shuffledIndices = new Uint32Array(n);
- for (let i = 0; i < n; ++i) {
- shuffledIndices[i] = i;
- }
- shuffle(shuffledIndices);
- return shuffledIndices;
- }
- function rightPad(a, size) {
- if (size <= a.length) {
- return a;
- }
- return a + ' '.repeat(size - a.length);
- }
- function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
- return new Promise((resolve, reject) => {
- let tryCount = 0;
- const tryFn = () => {
- if (checkFn()) {
- resolve();
- return;
- }
- tryCount++;
- const nextBackoff = delayFn(tryCount);
- if (maxCounter != null && tryCount >= maxCounter) {
- reject();
- return;
- }
- setTimeout(tryFn, nextBackoff);
- };
- tryFn();
- });
- }
- /**
- * Given the full size of the array and a shape that may contain -1 as the
- * implicit dimension, returns the inferred shape where -1 is replaced.
- * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
- *
- * @param shape The shape, which may contain -1 in some dimension.
- * @param size The full size (number of elements) of the array.
- * @return The inferred shape where -1 is replaced with the inferred size.
- */
- function inferFromImplicitShape(shape, size) {
- let shapeProd = 1;
- let implicitIdx = -1;
- for (let i = 0; i < shape.length; ++i) {
- if (shape[i] >= 0) {
- shapeProd *= shape[i];
- }
- else if (shape[i] === -1) {
- if (implicitIdx !== -1) {
- throw Error(`Shapes can only have 1 implicit size. ` +
- `Found -1 at dim ${implicitIdx} and dim ${i}`);
- }
- implicitIdx = i;
- }
- else if (shape[i] < 0) {
- throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);
- }
- }
- if (implicitIdx === -1) {
- if (size > 0 && size !== shapeProd) {
- throw Error(`Size(${size}) must match the product of shape ${shape}`);
- }
- return shape;
- }
- if (shapeProd === 0) {
- throw Error(`Cannot infer the missing size in [${shape}] when ` +
- `there are 0 elements`);
- }
- if (size % shapeProd !== 0) {
- throw Error(`The implicit shape can't be a fractional number. ` +
- `Got ${size} / ${shapeProd}`);
- }
- const newShape = shape.slice();
- newShape[implicitIdx] = size / shapeProd;
- return newShape;
- }
- function parseAxisParam(axis, shape) {
- const rank = shape.length;
- // Normalize input
- axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);
- // Check for valid range
- assert(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
- `got axis ${axis}`);
- // Check for only integers
- assert(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` +
- `got axis ${axis}`);
- // Handle negative axis.
- return axis.map(a => a < 0 ? rank + a : a);
- }
- /** Reduces the shape by removing all dimensions of shape 1. */
- function squeezeShape(shape, axis) {
- const newShape = [];
- const keptDims = [];
- const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
- const axes = (axis == null || isEmptyArray) ?
- null :
- parseAxisParam(axis, shape).sort();
- let j = 0;
- for (let i = 0; i < shape.length; ++i) {
- if (axes != null) {
- if (axes[j] === i && shape[i] !== 1) {
- throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
- }
- if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
- newShape.push(shape[i]);
- keptDims.push(i);
- }
- if (axes[j] <= i) {
- j++;
- }
- }
- if (shape[i] !== 1) {
- newShape.push(shape[i]);
- keptDims.push(i);
- }
- }
- return { newShape, keptDims };
- }
- function getTypedArrayFromDType(dtype, size) {
- let values = null;
- if (dtype == null || dtype === 'float32') {
- values = new Float32Array(size);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(size);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(size);
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- return values;
- }
- function getArrayFromDType(dtype, size) {
- let values = null;
- if (dtype == null || dtype === 'float32') {
- values = new Float32Array(size);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(size);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(size);
- }
- else if (dtype === 'string') {
- values = new Array(size);
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- return values;
- }
- function checkConversionForErrors(vals, dtype) {
- for (let i = 0; i < vals.length; i++) {
- const num = vals[i];
- if (isNaN(num) || !isFinite(num)) {
- throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
- }
- }
- }
- /** Returns true if the dtype is valid. */
- function isValidDtype(dtype) {
- return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
- dtype === 'int32' || dtype === 'string';
- }
- /**
- * Returns true if the new type can't encode the old type without loss of
- * precision.
- */
- function hasEncodingLoss(oldType, newType) {
- if (newType === 'complex64') {
- return false;
- }
- if (newType === 'float32' && oldType !== 'complex64') {
- return false;
- }
- if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
- return false;
- }
- if (newType === 'bool' && oldType === 'bool') {
- return false;
- }
- return true;
- }
- function isTypedArray(a) {
- return a instanceof Float32Array || a instanceof Int32Array ||
- a instanceof Uint8Array;
- }
- function bytesPerElement(dtype) {
- if (dtype === 'float32' || dtype === 'int32') {
- return 4;
- }
- else if (dtype === 'complex64') {
- return 8;
- }
- else if (dtype === 'bool') {
- return 1;
- }
- else {
- throw new Error(`Unknown dtype ${dtype}`);
- }
- }
- /**
- * Returns the approximate number of bytes allocated in the string array - 2
- * bytes per character. Computing the exact bytes for a native string in JS is
- * not possible since it depends on the encoding of the html page that serves
- * the website.
- */
- function bytesFromStringArray(arr) {
- if (arr == null) {
- return 0;
- }
- let bytes = 0;
- arr.forEach(x => bytes += x.length);
- return bytes;
- }
- /** Returns true if the value is a string. */
- function isString(value) {
- return typeof value === 'string' || value instanceof String;
- }
- function isBoolean(value) {
- return typeof value === 'boolean';
- }
- function isNumber(value) {
- return typeof value === 'number';
- }
- function inferDtype(values) {
- if (Array.isArray(values)) {
- return inferDtype(values[0]);
- }
- if (values instanceof Float32Array) {
- return 'float32';
- }
- else if (values instanceof Int32Array || values instanceof Uint8Array) {
- return 'int32';
- }
- else if (isNumber(values)) {
- return 'float32';
- }
- else if (isString(values)) {
- return 'string';
- }
- else if (isBoolean(values)) {
- return 'bool';
- }
- return 'float32';
- }
- function isFunction(f) {
- return !!(f && f.constructor && f.call && f.apply);
- }
- function nearestDivisor(size, start) {
- for (let i = start; i < size; ++i) {
- if (size % i === 0) {
- return i;
- }
- }
- return size;
- }
- function computeStrides(shape) {
- const rank = shape.length;
- if (rank < 2) {
- return [];
- }
- // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
- // strides.
- const strides = new Array(rank - 1);
- strides[rank - 2] = shape[rank - 1];
- for (let i = rank - 3; i >= 0; --i) {
- strides[i] = strides[i + 1] * shape[i + 1];
- }
- return strides;
- }
- /**
- * Create typed array for scalar value. Used for storing in `DataStorage`.
- */
- function createScalarValue(value, dtype) {
- if (dtype === 'string') {
- return encodeString(value);
- }
- return toTypedArray([value], dtype);
- }
- function toTypedArray(a, dtype) {
- if (dtype === 'string') {
- throw new Error('Cannot convert a string[] to a TypedArray');
- }
- if (Array.isArray(a)) {
- a = flatten(a);
- }
- if (env().getBool('DEBUG')) {
- checkConversionForErrors(a, dtype);
- }
- if (noConversionNeeded(a, dtype)) {
- return a;
- }
- if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
- return new Float32Array(a);
- }
- else if (dtype === 'int32') {
- return new Int32Array(a);
- }
- else if (dtype === 'bool') {
- const bool = new Uint8Array(a.length);
- for (let i = 0; i < bool.length; ++i) {
- if (Math.round(a[i]) !== 0) {
- bool[i] = 1;
- }
- }
- return bool;
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- }
- function createNestedArray(offset, shape, a) {
- const ret = new Array();
- if (shape.length === 1) {
- const d = shape[0];
- for (let i = 0; i < d; i++) {
- ret[i] = a[offset + i];
- }
- }
- else {
- const d = shape[0];
- const rest = shape.slice(1);
- const len = rest.reduce((acc, c) => acc * c);
- for (let i = 0; i < d; i++) {
- ret[i] = createNestedArray(offset + i * len, rest, a);
- }
- }
- return ret;
- }
- // Provide a nested array of TypedArray in given shape.
- function toNestedArray(shape, a) {
- if (shape.length === 0) {
- // Scalar type should return a single number.
- return a[0];
- }
- const size = shape.reduce((acc, c) => acc * c);
- if (size === 0) {
- // A tensor with shape zero should be turned into empty list.
- return [];
- }
- if (size !== a.length) {
- throw new Error(`[${shape}] does not match the input size ${a.length}.`);
- }
- return createNestedArray(0, shape, a);
- }
- function noConversionNeeded(a, dtype) {
- return (a instanceof Float32Array && dtype === 'float32') ||
- (a instanceof Int32Array && dtype === 'int32') ||
- (a instanceof Uint8Array && dtype === 'bool');
- }
- function makeOnesTypedArray(size, dtype) {
- const array = makeZerosTypedArray(size, dtype);
- for (let i = 0; i < array.length; i++) {
- array[i] = 1;
- }
- return array;
- }
- function makeZerosTypedArray(size, dtype) {
- if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
- return new Float32Array(size);
- }
- else if (dtype === 'int32') {
- return new Int32Array(size);
- }
- else if (dtype === 'bool') {
- return new Uint8Array(size);
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- }
- /**
- * Make nested `TypedArray` filled with zeros.
- * @param shape The shape information for the nested array.
- * @param dtype dtype of the array element.
- */
- function makeZerosNestedTypedArray(shape, dtype) {
- const size = shape.reduce((prev, curr) => prev * curr, 1);
- if (dtype == null || dtype === 'float32') {
- return toNestedArray(shape, new Float32Array(size));
- }
- else if (dtype === 'int32') {
- return toNestedArray(shape, new Int32Array(size));
- }
- else if (dtype === 'bool') {
- return toNestedArray(shape, new Uint8Array(size));
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- }
- /**
- * Returns the current high-resolution time in milliseconds relative to an
- * arbitrary time in the past. It works across different platforms (node.js,
- * browsers).
- *
- * ```js
- * console.log(tf.util.now());
- * ```
- *
- * @doc {heading: 'Util', namespace: 'util'}
- */
- function now() {
- return env().platform.now();
- }
- function assertNonNegativeIntegerDimensions(shape) {
- shape.forEach(dimSize => {
- assert(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` +
- `shape [${shape}].`);
- });
- }
- /**
- * Returns a platform-specific implementation of
- * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
- *
- * If `fetch` is defined on the global object (`window`, `process`, etc.),
- * `tf.util.fetch` returns that function.
- *
- * If not, `tf.util.fetch` returns a platform-specific solution.
- *
- * ```js
- * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
- * // handle response
- * ```
- *
- * @doc {heading: 'Util'}
- */
- function fetch$1(path, requestInits) {
- return env().platform.fetch(path, requestInits);
- }
- /**
- * Encodes the provided string into bytes using the provided encoding scheme.
- *
- * @param s The string to encode.
- * @param encoding The encoding scheme. Defaults to utf-8.
- *
- * @doc {heading: 'Util'}
- */
- function encodeString(s, encoding = 'utf-8') {
- encoding = encoding || 'utf-8';
- return env().platform.encode(s, encoding);
- }
- /**
- * Decodes the provided bytes into a string using the provided encoding scheme.
- * @param bytes The bytes to decode.
- *
- * @param encoding The encoding scheme. Defaults to utf-8.
- *
- * @doc {heading: 'Util'}
- */
- function decodeString(bytes, encoding = 'utf-8') {
- encoding = encoding || 'utf-8';
- return env().platform.decode(bytes, encoding);
- }
- /**
- * Computes flat index for a given location (multidimentionsal index) in a
- * Tensor/multidimensional array.
- *
- * @param locs Location in the tensor.
- * @param rank Rank of the tensor.
- * @param strides Tensor strides.
- */
- function locToIndex(locs, rank, strides) {
- if (rank === 0) {
- return 0;
- }
- else if (rank === 1) {
- return locs[0];
- }
- let index = locs[locs.length - 1];
- for (let i = 0; i < locs.length - 1; ++i) {
- index += strides[i] * locs[i];
- }
- return index;
- }
- /**
- * Computes the location (multidimensional index) in a tensor/multidimentional
- * array for a given flat index.
- *
- * @param index Index in flat array.
- * @param rank Rank of tensor.
- * @param strides Strides of tensor.
- */
- function indexToLoc(index, rank, strides) {
- if (rank === 0) {
- return [];
- }
- else if (rank === 1) {
- return [index];
- }
- const locs = new Array(rank);
- for (let i = 0; i < locs.length - 1; ++i) {
- locs[i] = Math.floor(index / strides[i]);
- index -= locs[i] * strides[i];
- }
- locs[locs.length - 1] = index;
- return locs;
- }
-
- var util = /*#__PURE__*/Object.freeze({
- __proto__: null,
- shuffle: shuffle,
- clamp: clamp,
- nearestLargerEven: nearestLargerEven,
- sum: sum,
- randUniform: randUniform,
- distSquared: distSquared,
- assert: assert,
- assertShapesMatch: assertShapesMatch,
- assertNonNull: assertNonNull,
- flatten: flatten,
- sizeFromShape: sizeFromShape,
- isScalarShape: isScalarShape,
- arraysEqual: arraysEqual,
- isInt: isInt,
- tanh: tanh,
- sizeToSquarishShape: sizeToSquarishShape,
- createShuffledIndices: createShuffledIndices,
- rightPad: rightPad,
- repeatedTry: repeatedTry,
- inferFromImplicitShape: inferFromImplicitShape,
- parseAxisParam: parseAxisParam,
- squeezeShape: squeezeShape,
- getTypedArrayFromDType: getTypedArrayFromDType,
- getArrayFromDType: getArrayFromDType,
- checkConversionForErrors: checkConversionForErrors,
- isValidDtype: isValidDtype,
- hasEncodingLoss: hasEncodingLoss,
- isTypedArray: isTypedArray,
- bytesPerElement: bytesPerElement,
- bytesFromStringArray: bytesFromStringArray,
- isString: isString,
- isBoolean: isBoolean,
- isNumber: isNumber,
- inferDtype: inferDtype,
- isFunction: isFunction,
- nearestDivisor: nearestDivisor,
- computeStrides: computeStrides,
- createScalarValue: createScalarValue,
- toTypedArray: toTypedArray,
- toNestedArray: toNestedArray,
- makeOnesTypedArray: makeOnesTypedArray,
- makeZerosTypedArray: makeZerosTypedArray,
- makeZerosNestedTypedArray: makeZerosNestedTypedArray,
- now: now,
- assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
- fetch: fetch$1,
- encodeString: encodeString,
- decodeString: decodeString,
- locToIndex: locToIndex,
- indexToLoc: indexToLoc
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class Profiler {
- constructor(backendTimer, logger) {
- this.backendTimer = backendTimer;
- this.logger = logger;
- if (logger == null) {
- this.logger = new Logger();
- }
- }
- profileKernel(kernelName, inputs, f) {
- let outputs;
- const holdResultWrapperFn = () => {
- outputs = f();
- };
- const timer = this.backendTimer.time(holdResultWrapperFn);
- for (let i = 0; i < outputs.length; i++) {
- const output = outputs[i];
- // Dangling promise here because we don't want to propagate up
- // asynchronicity.
- output.data().then(tensorVals => {
- checkComputationForErrors(tensorVals, output.dtype, kernelName);
- });
- }
- const kernelProfile = {
- kernelName,
- outputs,
- inputs,
- timeMs: timer.then(timing => timing.kernelMs),
- extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ?
- timing.getExtraProfileInfo() :
- '')
- };
- return kernelProfile;
- }
- logKernelProfile(kernelProfile) {
- const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
- outputs.forEach(result => {
- Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => {
- this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
- });
- });
- }
- }
- function checkComputationForErrors(vals, dtype, kernelName) {
- if (dtype !== 'float32') {
- // Only floating point computations will generate NaN values
- return false;
- }
- for (let i = 0; i < vals.length; i++) {
- const num = vals[i];
- if (isNaN(num) || !isFinite(num)) {
- // Throwing custom exception so behavior is testable.
- console.warn(`Found ${num} in the result of '${kernelName}'`);
- return true;
- }
- }
- return false;
- }
- class Logger {
- logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
- const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
- timeMs['error'];
- const paddedName = rightPad(name, 25);
- const rank = result.rank;
- const size = result.size;
- const shape = rightPad(result.shape.toString(), 14);
- let inputShapesDescription = '';
- for (const name in inputs) {
- const input = inputs[name];
- if (input != null) {
- // The input might be a non-tensor (e.g HTMLImageElement), in which case
- // we claim the output shape as input shape.
- const inputShape = input.shape || result.shape;
- const inputRank = inputShape.length;
- inputShapesDescription +=
- `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
- }
- }
- console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes a list of TapeNodes that connect x to y, filtering everything else
- * out and preserving the order of the original tape elements.
- *
- * @param tape The tape elements to filter.
- * @param xs The input Tensors.
- * @param y The output Tensor.
- */
- function getFilteredNodesXToY(tape, xs, y) {
- // Forward pass to compute all the nodes and Tensors that are transitively a
- // function of x.
- const tensorsFromX = {};
- const nodesFromX = {};
- for (let i = 0; i < xs.length; i++) {
- tensorsFromX[xs[i].id] = true;
- }
- for (let i = 0; i < tape.length; i++) {
- const node = tape[i];
- const nodeInputs = node.inputs;
- for (const inputName in nodeInputs) {
- const input = nodeInputs[inputName];
- let anyInputFromX = false;
- for (let j = 0; j < xs.length; j++) {
- if (tensorsFromX[input.id]) {
- node.outputs.forEach(output => tensorsFromX[output.id] = true);
- anyInputFromX = true;
- nodesFromX[node.id] = true;
- break;
- }
- }
- if (anyInputFromX) {
- break;
- }
- }
- }
- // Backward pass to find all of the nodes and Tensors that lead to y.
- const tensorsLeadToY = {};
- tensorsLeadToY[y.id] = true;
- const nodesToY = {};
- for (let i = tape.length - 1; i >= 0; i--) {
- const node = tape[i];
- const nodeInputs = node.inputs;
- // If any of the outputs lead to y, mark all of the inputs as leading to y.
- for (let j = 0; j < node.outputs.length; j++) {
- if (tensorsLeadToY[node.outputs[j].id]) {
- for (const inputName in nodeInputs) {
- tensorsLeadToY[nodeInputs[inputName].id] = true;
- nodesToY[node.id] = true;
- }
- break;
- }
- }
- }
- // Return the paths that come from x and lead to y.
- const filteredTape = [];
- for (let i = 0; i < tape.length; i++) {
- const node = tape[i];
- if (nodesFromX[node.id] && nodesToY[node.id]) {
- // Prune the inputs from the node that aren't a function of x.
- const prunedInputs = {};
- for (const inputName in node.inputs) {
- const nodeInput = node.inputs[inputName];
- if (tensorsFromX[nodeInput.id]) {
- prunedInputs[inputName] = nodeInput;
- }
- }
- // Copy the node and overwrite inputsAndArgs to the pruned version.
- const prunedNode = Object.assign({}, node);
- prunedNode.inputs = prunedInputs;
- prunedNode.outputs = node.outputs;
- filteredTape.push(prunedNode);
- }
- }
- return filteredTape;
- }
- /**
- * Backpropagate gradients through the filtered TapeNodes.
- *
- * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
- * is mutated by this method.
- * @param filteredTape The filtered TapeNodes to backprop through.
- */
- function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
- // Walk the tape backward and keep a map of Tensor to its gradient.
- for (let i = filteredTape.length - 1; i >= 0; i--) {
- const node = filteredTape[i];
- const dys = [];
- node.outputs.forEach(o => {
- const gradTensor = tensorAccumulatedGradientMap[o.id];
- if (gradTensor != null) {
- dys.push(gradTensor);
- }
- else {
- // This particular output is not in the back-propagation subgraph, so it
- // does not affect the final output, thus we put null for its dy.
- dys.push(null);
- }
- });
- if (node.gradient == null) {
- throw new Error(`Cannot compute gradient: gradient function not found ` +
- `for ${node.kernelName}.`);
- }
- // Backprop dy through this node and accumulate gradients over the inputs.
- const inputGradients = node.gradient(dys);
- for (const inputName in node.inputs) {
- if (!(inputName in inputGradients)) {
- throw new Error(`Cannot backprop through input ${inputName}. ` +
- `Available gradients found: ${Object.keys(inputGradients)}.`);
- }
- // Call the gradient function.
- const dx = tidy(() => inputGradients[inputName]());
- if (dx.dtype !== 'float32') {
- throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
- `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
- }
- const x = node.inputs[inputName];
- if (!arraysEqual(dx.shape, x.shape)) {
- throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
- `'${inputName}' has shape '${dx.shape}', which does not match ` +
- `the shape of the input '${x.shape}'`);
- }
- if (tensorAccumulatedGradientMap[x.id] == null) {
- tensorAccumulatedGradientMap[x.id] = dx;
- }
- else {
- const curGradient = tensorAccumulatedGradientMap[x.id];
- tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
- curGradient.dispose();
- }
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Maximum number of values before we decide to show ellipsis.
- const FORMAT_LIMIT_NUM_VALS = 20;
- // Number of first and last values to show when displaying a, b,...,y, z.
- const FORMAT_NUM_FIRST_LAST_VALS = 3;
- // Number of significant digits to show.
- const FORMAT_NUM_SIG_DIGITS = 7;
- function tensorToString(vals, shape, dtype, verbose) {
- const strides = computeStrides(shape);
- const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
- const rank = shape.length;
- const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
- const lines = ['Tensor'];
- if (verbose) {
- lines.push(` dtype: ${dtype}`);
- lines.push(` rank: ${rank}`);
- lines.push(` shape: [${shape}]`);
- lines.push(` values:`);
- }
- lines.push(valsLines.map(l => ' ' + l).join('\n'));
- return lines.join('\n');
- }
- function computeMaxSizePerColumn(vals, shape, dtype, strides) {
- const n = sizeFromShape(shape);
- const numCols = strides[strides.length - 1];
- const padPerCol = new Array(numCols).fill(0);
- const rank = shape.length;
- const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
- if (rank > 1) {
- for (let row = 0; row < n / numCols; row++) {
- const offset = row * numCols;
- for (let j = 0; j < numCols; j++) {
- padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
- }
- }
- }
- return padPerCol;
- }
- function valToString(val, pad, dtype) {
- let valStr;
- if (Array.isArray(val)) {
- valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +
- `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;
- }
- else if (isString(val)) {
- valStr = `'${val}'`;
- }
- else if (dtype === 'bool') {
- valStr = boolNumToString(val);
- }
- else {
- valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
- }
- return rightPad(valStr, pad);
- }
- function boolNumToString(v) {
- return v === 0 ? 'false' : 'true';
- }
- function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
- const storagePerElement = dtype === 'complex64' ? 2 : 1;
- const size = shape[0];
- const rank = shape.length;
- if (rank === 0) {
- if (dtype === 'complex64') {
- const complexTuple = createComplexTuples(vals);
- return [valToString(complexTuple[0], 0, dtype)];
- }
- if (dtype === 'bool') {
- return [boolNumToString(vals[0])];
- }
- return [vals[0].toString()];
- }
- if (rank === 1) {
- if (size > FORMAT_LIMIT_NUM_VALS) {
- const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
- let firstVals = Array.from(vals.slice(0, firstValsSize));
- let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
- if (dtype === 'complex64') {
- firstVals = createComplexTuples(firstVals);
- lastVals = createComplexTuples(lastVals);
- }
- return [
- '[' +
- firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
- .join(', ') +
- ', ..., ' +
- lastVals
- .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
- .join(', ') +
- ']'
- ];
- }
- const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
- Array.from(vals);
- return [
- '[' +
- displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
- .join(', ') +
- ']'
- ];
- }
- // The array is rank 2 or more.
- const subshape = shape.slice(1);
- const substrides = strides.slice(1);
- const stride = strides[0] * storagePerElement;
- const lines = [];
- if (size > FORMAT_LIMIT_NUM_VALS) {
- for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
- const start = i * stride;
- const end = start + stride;
- lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
- }
- lines.push('...');
- for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
- const start = i * stride;
- const end = start + stride;
- lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
- }
- }
- else {
- for (let i = 0; i < size; i++) {
- const start = i * stride;
- const end = start + stride;
- lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
- }
- }
- const sep = rank === 2 ? ',' : '';
- lines[0] = '[' + lines[0] + sep;
- for (let i = 1; i < lines.length - 1; i++) {
- lines[i] = ' ' + lines[i] + sep;
- }
- let newLineSep = ',\n';
- for (let i = 2; i < rank; i++) {
- newLineSep += '\n';
- }
- lines[lines.length - 1] =
- ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
- return lines;
- }
- function createComplexTuples(vals) {
- const complexTuples = [];
- for (let i = 0; i < vals.length; i += 2) {
- complexTuples.push([vals[i], vals[i + 1]]);
- }
- return complexTuples;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * A mutable object, similar to `tf.Tensor`, that allows users to set values
- * at locations before converting to an immutable `tf.Tensor`.
- *
- * See `tf.buffer` for creating a tensor buffer.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- class TensorBuffer {
- constructor(shape, dtype, values) {
- this.dtype = dtype;
- this.shape = shape.slice();
- this.size = sizeFromShape(shape);
- if (values != null) {
- const n = values.length;
- assert(n === this.size, () => `Length of values '${n}' does not match the size ` +
- `inferred by the shape '${this.size}'.`);
- }
- if (dtype === 'complex64') {
- throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
- `a TensorBuffer for the real and imaginary parts separately and ` +
- `call tf.complex(real, imag).`);
- }
- this.values = values || getArrayFromDType(dtype, this.size);
- this.strides = computeStrides(shape);
- }
- /**
- * Sets a value in the buffer at a given location.
- *
- * @param value The value to set.
- * @param locs The location indices.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- set(value, ...locs) {
- if (locs.length === 0) {
- locs = [0];
- }
- assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
- `match the rank (${this.rank})`);
- const index = this.locToIndex(locs);
- this.values[index] = value;
- }
- /**
- * Returns the value in the buffer at the provided location.
- *
- * @param locs The location indices.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- get(...locs) {
- if (locs.length === 0) {
- locs = [0];
- }
- let i = 0;
- for (const loc of locs) {
- if (loc < 0 || loc >= this.shape[i]) {
- const msg = `Requested out of range element at ${locs}. ` +
- ` Buffer shape=${this.shape}`;
- throw new Error(msg);
- }
- i++;
- }
- let index = locs[locs.length - 1];
- for (let i = 0; i < locs.length - 1; ++i) {
- index += this.strides[i] * locs[i];
- }
- return this.values[index];
- }
- locToIndex(locs) {
- if (this.rank === 0) {
- return 0;
- }
- else if (this.rank === 1) {
- return locs[0];
- }
- let index = locs[locs.length - 1];
- for (let i = 0; i < locs.length - 1; ++i) {
- index += this.strides[i] * locs[i];
- }
- return index;
- }
- indexToLoc(index) {
- if (this.rank === 0) {
- return [];
- }
- else if (this.rank === 1) {
- return [index];
- }
- const locs = new Array(this.shape.length);
- for (let i = 0; i < locs.length - 1; ++i) {
- locs[i] = Math.floor(index / this.strides[i]);
- index -= locs[i] * this.strides[i];
- }
- locs[locs.length - 1] = index;
- return locs;
- }
- get rank() {
- return this.shape.length;
- }
- /**
- * Creates an immutable `tf.Tensor` object from the buffer.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- toTensor() {
- return trackerFn().makeTensor(this.values, this.shape, this.dtype);
- }
- }
- // For tracking tensor creation and disposal.
- let trackerFn = null;
- // Used by chaining methods to call into ops.
- let opHandler = null;
- // Used to warn about deprecated methods.
- let deprecationWarningFn = null;
- // This here so that we can use this method on dev branches and keep the
- // functionality at master.
- // tslint:disable-next-line:no-unused-expression
- [deprecationWarningFn];
- /**
- * An external consumer can register itself as the tensor tracker. This way
- * the Tensor class can notify the tracker for every tensor created and
- * disposed.
- */
- function setTensorTracker(fn) {
- trackerFn = fn;
- }
- /**
- * An external consumer can register itself as the op handler. This way the
- * Tensor class can have chaining methods that call into ops via the op
- * handler.
- */
- function setOpHandler(handler) {
- opHandler = handler;
- }
- /**
- * Sets the deprecation warning function to be used by this file. This way the
- * Tensor class can be a leaf but still use the environment.
- */
- function setDeprecationWarningFn(fn) {
- deprecationWarningFn = fn;
- }
- /**
- * A `tf.Tensor` object represents an immutable, multidimensional array of
- * numbers that has a shape and a data type.
- *
- * See `tf.tensor` for details on how to create a `tf.Tensor`.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- class Tensor {
- constructor(shape, dtype, dataId, id) {
- /** Whether this tensor has been globally kept. */
- this.kept = false;
- this.isDisposedInternal = false;
- this.shape = shape.slice();
- this.dtype = dtype || 'float32';
- this.size = sizeFromShape(shape);
- this.strides = computeStrides(shape);
- this.dataId = dataId;
- this.id = id;
- this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
- }
- get rank() {
- return this.shape.length;
- }
- /**
- * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- async buffer() {
- const vals = await this.data();
- return opHandler.buffer(this.shape, this.dtype, vals);
- }
- /**
- * Returns a `tf.TensorBuffer` that holds the underlying data.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- bufferSync() {
- return opHandler.buffer(this.shape, this.dtype, this.dataSync());
- }
- /**
- * Returns the tensor data as a nested array. The transfer of data is done
- * asynchronously.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- async array() {
- const vals = await this.data();
- return toNestedArray(this.shape, vals);
- }
- /**
- * Returns the tensor data as a nested array. The transfer of data is done
- * synchronously.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- arraySync() {
- return toNestedArray(this.shape, this.dataSync());
- }
- /**
- * Asynchronously downloads the values from the `tf.Tensor`. Returns a
- * promise of `TypedArray` that resolves when the computation has finished.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- async data() {
- this.throwIfDisposed();
- const data = trackerFn().read(this.dataId);
- if (this.dtype === 'string') {
- const bytes = await data;
- try {
- return bytes.map(b => decodeString(b));
- }
- catch {
- throw new Error('Failed to decode the string bytes into utf-8. ' +
- 'To get the original bytes, call tensor.bytes().');
- }
- }
- return data;
- }
- /**
- * Synchronously downloads the values from the `tf.Tensor`. This blocks the
- * UI thread until the values are ready, which can cause performance issues.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- dataSync() {
- this.throwIfDisposed();
- const data = trackerFn().readSync(this.dataId);
- if (this.dtype === 'string') {
- try {
- return data.map(b => decodeString(b));
- }
- catch {
- throw new Error('Failed to decode the string bytes into utf-8. ' +
- 'To get the original bytes, call tensor.bytes().');
- }
- }
- return data;
- }
- /** Returns the underlying bytes of the tensor's data. */
- async bytes() {
- this.throwIfDisposed();
- const data = await trackerFn().read(this.dataId);
- if (this.dtype === 'string') {
- return data;
- }
- else {
- return new Uint8Array(data.buffer);
- }
- }
- /**
- * Disposes `tf.Tensor` from memory.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- dispose() {
- if (this.isDisposed) {
- return;
- }
- trackerFn().disposeTensor(this);
- this.isDisposedInternal = true;
- }
- get isDisposed() {
- return this.isDisposedInternal;
- }
- throwIfDisposed() {
- if (this.isDisposed) {
- throw new Error(`Tensor is disposed.`);
- }
- }
- /**
- * Prints the `tf.Tensor`. See `tf.print` for details.
- *
- * @param verbose Whether to print verbose information about the tensor,
- * including dtype and size.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- print(verbose = false) {
- return opHandler.print(this, verbose);
- }
- /**
- * Returns a copy of the tensor. See `tf.clone` for details.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- clone() {
- this.throwIfDisposed();
- return opHandler.clone(this);
- }
- /**
- * Returns a human-readable description of the tensor. Useful for logging.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- toString(verbose = false) {
- const vals = this.dataSync();
- return tensorToString(vals, this.shape, this.dtype, verbose);
- }
- cast(dtype) {
- this.throwIfDisposed();
- return opHandler.cast(this, dtype);
- }
- variable(trainable = true, name, dtype) {
- this.throwIfDisposed();
- return trackerFn().makeVariable(this, trainable, name, dtype);
- }
- }
- Object.defineProperty(Tensor, Symbol.hasInstance, {
- value: (instance) => {
- // Implementation note: we should use properties of the object that will be
- // defined before the constructor body has finished executing (methods).
- // This is because when this code is transpiled by babel, babel will call
- // classCallCheck before the constructor body is run.
- // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
- return !!instance && instance.data != null && instance.dataSync != null &&
- instance.throwIfDisposed != null;
- }
- });
- /**
- * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- class Variable extends Tensor {
- constructor(initialValue, trainable, name, tensorId) {
- super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
- this.trainable = trainable;
- this.name = name;
- }
- /**
- * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
- * the same shape and dtype as the old `tf.Tensor`.
- *
- * @param newValue New tensor to be assigned to this variable.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- assign(newValue) {
- if (newValue.dtype !== this.dtype) {
- throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
- `previous value (${this.dtype}) must match`);
- }
- if (!arraysEqual(newValue.shape, this.shape)) {
- throw new Error(`shape of the new value (${newValue.shape}) and ` +
- `previous value (${this.shape}) must match`);
- }
- trackerFn().disposeTensor(this);
- this.dataId = newValue.dataId;
- trackerFn().incRef(this, null /* backend */);
- }
- dispose() {
- trackerFn().disposeVariable(this);
- this.isDisposedInternal = true;
- }
- }
- Object.defineProperty(Variable, Symbol.hasInstance, {
- value: (instance) => {
- return instance instanceof Tensor && instance.assign != null &&
- instance.assign instanceof Function;
- }
- });
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- (function (Rank) {
- Rank["R0"] = "R0";
- Rank["R1"] = "R1";
- Rank["R2"] = "R2";
- Rank["R3"] = "R3";
- Rank["R4"] = "R4";
- Rank["R5"] = "R5";
- Rank["R6"] = "R6";
- })(exports.Rank || (exports.Rank = {}));
- // Looks for upcasting types. Used, for example, in operations with mixed dtype
- // inputs.
- var UpcastInt32AndMap;
- (function (UpcastInt32AndMap) {
- UpcastInt32AndMap["float32"] = "float32";
- UpcastInt32AndMap["int32"] = "int32";
- UpcastInt32AndMap["bool"] = "int32";
- UpcastInt32AndMap["complex64"] = "complex64";
- })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
- var UpcastBoolAndMap;
- (function (UpcastBoolAndMap) {
- UpcastBoolAndMap["float32"] = "float32";
- UpcastBoolAndMap["int32"] = "int32";
- UpcastBoolAndMap["bool"] = "bool";
- UpcastBoolAndMap["complex64"] = "complex64";
- })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
- var UpcastFloat32AndMap;
- (function (UpcastFloat32AndMap) {
- UpcastFloat32AndMap["float32"] = "float32";
- UpcastFloat32AndMap["int32"] = "float32";
- UpcastFloat32AndMap["bool"] = "float32";
- UpcastFloat32AndMap["complex64"] = "complex64";
- })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
- var UpcastComplex64AndMap;
- (function (UpcastComplex64AndMap) {
- UpcastComplex64AndMap["float32"] = "complex64";
- UpcastComplex64AndMap["int32"] = "complex64";
- UpcastComplex64AndMap["bool"] = "complex64";
- UpcastComplex64AndMap["complex64"] = "complex64";
- })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
- const upcastTypeMap = {
- 'float32': UpcastFloat32AndMap,
- 'int32': UpcastInt32AndMap,
- 'bool': UpcastBoolAndMap,
- 'complex64': UpcastComplex64AndMap
- };
- function upcastType(typeA, typeB) {
- if (typeA === 'string' || typeB === 'string') {
- if (typeA === 'string' && typeB === 'string') {
- return 'string';
- }
- throw new Error(`Can not upcast ${typeA} with ${typeB}`);
- }
- return upcastTypeMap[typeA][typeB];
- }
- /** Returns the output type after summation. */
- function sumOutType(type) {
- return upcastType(type, 'int32');
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function makeTypesMatch(a, b) {
- if (a.dtype === b.dtype) {
- return [a, b];
- }
- const dtype = upcastType(a.dtype, b.dtype);
- return [a.cast(dtype), b.cast(dtype)];
- }
- function assertTypesMatch(a, b) {
- assert(a.dtype === b.dtype, () => `The dtypes of the first(${a.dtype}) and` +
- ` second(${b.dtype}) input must match`);
- }
- function isTensorInList(tensor, tensorList) {
- return tensorList.some(x => x.id === tensor.id);
- }
- /**
- * Extracts any `Tensor`s found within the provided object.
- *
- * @param container an object that may be a `Tensor` or may directly contain
- * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
- * is safe to pass any object here, except that `Promise`s are not
- * supported.
- * @returns An array of `Tensors` found within the passed object. If the
- * argument is simply a `Tensor', a list containing that `Tensor` is
- * returned. If the object is not a `Tensor` or does not
- * contain `Tensors`, an empty list is returned.
- */
- function getTensorsInContainer(result) {
- const list = [];
- const seen = new Set();
- walkTensorContainer(result, list, seen);
- return list;
- }
- function walkTensorContainer(container, list, seen) {
- if (container == null) {
- return;
- }
- if (container instanceof Tensor) {
- list.push(container);
- return;
- }
- if (!isIterable(container)) {
- return;
- }
- // Iteration over keys works also for arrays.
- const iterable = container;
- for (const k in iterable) {
- const val = iterable[k];
- if (!seen.has(val)) {
- seen.add(val);
- walkTensorContainer(val, list, seen);
- }
- }
- }
- // tslint:disable-next-line:no-any
- function isIterable(obj) {
- return Array.isArray(obj) || typeof obj === 'object';
- }
-
- var tensor_util = /*#__PURE__*/Object.freeze({
- __proto__: null,
- makeTypesMatch: makeTypesMatch,
- assertTypesMatch: assertTypesMatch,
- isTensorInList: isTensorInList,
- getTensorsInContainer: getTensorsInContainer
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class EngineState {
- constructor() {
- // Public since optimizers will use it.
- this.registeredVariables = {};
- this.nextTapeNodeId = 0;
- this.numBytes = 0;
- this.numTensors = 0;
- this.numStringTensors = 0;
- this.numDataBuffers = 0;
- // Number of nested tf.grad() statements when computing higher-order
- // gradients. E.g. `1` for first-order gradients and `2` for second-order
- // gradients. Used to track if the tape should be removed after a backprop.
- this.gradientDepth = 0;
- // Number of nested kernel calls. When kernel depth is greater than 1, we turn
- // off the tape.
- this.kernelDepth = 0;
- this.scopeStack = [];
- /**
- * Keeps track of the number of data moves during a kernel execution. We
- * maintain a stack since kernels can call other kernels, recursively.
- */
- this.numDataMovesStack = [];
- this.nextScopeId = 0;
- this.tensorInfo = new WeakMap();
- this.profiling = false;
- this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null };
- }
- dispose() {
- for (const variableName in this.registeredVariables) {
- this.registeredVariables[variableName].dispose();
- }
- }
- }
- class Engine {
- constructor(ENV) {
- this.ENV = ENV;
- this.registry = {};
- this.registryFactory = {};
- this.pendingBackendInitId = 0;
- this.state = new EngineState();
- }
- async ready() {
- if (this.pendingBackendInit != null) {
- return this.pendingBackendInit.then(() => { });
- }
- if (this.backendInstance != null) {
- return;
- }
- const sortedBackends = this.getSortedBackends();
- for (let i = 0; i < sortedBackends.length; i++) {
- const backendName = sortedBackends[i];
- const success = await this.initializeBackend(backendName).success;
- if (success) {
- await this.setBackend(backendName);
- return;
- }
- }
- throw new Error(`Could not initialize any backends, all backend initializations ` +
- `failed.`);
- }
- get backend() {
- if (this.pendingBackendInit != null) {
- throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
- `sure to await tf.ready() or await tf.setBackend() before calling ` +
- `other methods`);
- }
- if (this.backendInstance == null) {
- const { name, asyncInit } = this.initializeBackendsAndReturnBest();
- if (asyncInit) {
- throw new Error(`The highest priority backend '${name}' has not yet been ` +
- `initialized. Make sure to await tf.ready() or ` +
- `await tf.setBackend() before calling other methods`);
- }
- this.setBackend(name);
- }
- return this.backendInstance;
- }
- backendNames() {
- return Object.keys(this.registryFactory);
- }
- findBackend(backendName) {
- if (!(backendName in this.registry)) {
- // If the backend hasn't been initialized but we have a registry entry for
- // it, initialize it and return it.
- if (backendName in this.registryFactory) {
- const { asyncInit } = this.initializeBackend(backendName);
- if (asyncInit) {
- // Backend is not ready yet.
- return null;
- }
- }
- else {
- return null;
- }
- }
- return this.registry[backendName];
- }
- findBackendFactory(backendName) {
- if (!(backendName in this.registryFactory)) {
- return null;
- }
- return this.registryFactory[backendName].factory;
- }
- registerBackend(backendName, factory, priority = 1) {
- if (backendName in this.registryFactory) {
- console.warn(`${backendName} backend was already registered. ` +
- `Reusing existing backend factory.`);
- return false;
- }
- this.registryFactory[backendName] = { factory, priority };
- return true;
- }
- async setBackend(backendName) {
- if (this.registryFactory[backendName] == null) {
- throw new Error(`Backend name '${backendName}' not found in registry`);
- }
- this.backendName = backendName;
- if (this.registry[backendName] == null) {
- this.backendInstance = null;
- const { success, asyncInit } = this.initializeBackend(backendName);
- const result = asyncInit ? await success : success;
- if (!result) {
- return false;
- }
- }
- this.backendInstance = this.registry[backendName];
- this.setupRegisteredKernels();
- // Reset the profiler.
- this.profiler = new Profiler(this.backendInstance);
- return true;
- }
- setupRegisteredKernels() {
- const kernels = getKernelsForBackend(this.backendName);
- kernels.forEach(kernel => {
- if (kernel.setupFunc != null) {
- kernel.setupFunc(this.backendInstance);
- }
- });
- }
- disposeRegisteredKernels(backendName) {
- const kernels = getKernelsForBackend(backendName);
- kernels.forEach(kernel => {
- if (kernel.disposeFunc != null) {
- kernel.disposeFunc(this.registry[backendName]);
- }
- });
- }
- /**
- * Initializes a backend by looking up the backend name in the factory
- * registry and calling the factory method. Returns a boolean representing
- * whether the initialization of the backend suceeded. Throws an error if
- * there is no backend in the factory registry.
- */
- initializeBackend(backendName) {
- const registryFactoryEntry = this.registryFactory[backendName];
- if (registryFactoryEntry == null) {
- throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
- }
- try {
- const backend = registryFactoryEntry.factory();
- /* Test if the factory returns a promise.
- Done in a more liberal way than
- previous 'Promise.resolve(backend)===backend'
- as we needed to account for custom Promise
- implementations (e.g. Angular) */
- if (backend && !(backend instanceof KernelBackend)
- && typeof backend.then === 'function') {
- const promiseId = ++this.pendingBackendInitId;
- const success = backend
- .then(backendInstance => {
- // Outdated promise. Another backend was set in the meantime.
- if (promiseId < this.pendingBackendInitId) {
- return false;
- }
- this.registry[backendName] = backendInstance;
- this.pendingBackendInit = null;
- return true;
- })
- .catch(err => {
- // Outdated promise. Another backend was set in the meantime.
- if (promiseId < this.pendingBackendInitId) {
- return false;
- }
- this.pendingBackendInit = null;
- console.warn(`Initialization of backend ${backendName} failed`);
- console.warn(err.stack || err.message);
- return false;
- });
- this.pendingBackendInit = success;
- return { success, asyncInit: true };
- }
- else {
- this.registry[backendName] = backend;
- return { success: true, asyncInit: false };
- }
- }
- catch (err) {
- console.warn(`Initialization of backend ${backendName} failed`);
- console.warn(err.stack || err.message);
- return { success: false, asyncInit: false };
- }
- }
- removeBackend(backendName) {
- if (!(backendName in this.registryFactory)) {
- throw new Error(`${backendName} backend not found in registry`);
- }
- if (this.backendName === backendName && this.pendingBackendInit != null) {
- // There is a pending promise of the backend we want to remove. Make it
- // obsolete.
- this.pendingBackendInitId++;
- }
- if (backendName in this.registry) {
- this.disposeRegisteredKernels(backendName);
- this.registry[backendName].dispose();
- delete this.registry[backendName];
- }
- delete this.registryFactory[backendName];
- // Unset the backend if it is active.
- if (this.backendName === backendName) {
- this.pendingBackendInit = null;
- this.backendName = null;
- this.backendInstance = null;
- }
- }
- getSortedBackends() {
- if (Object.keys(this.registryFactory).length === 0) {
- throw new Error('No backend found in registry.');
- }
- return Object.keys(this.registryFactory).sort((a, b) => {
- // Highest priority comes first.
- return this.registryFactory[b].priority -
- this.registryFactory[a].priority;
- });
- }
- initializeBackendsAndReturnBest() {
- const sortedBackends = this.getSortedBackends();
- for (let i = 0; i < sortedBackends.length; i++) {
- const backendName = sortedBackends[i];
- const { success, asyncInit } = this.initializeBackend(backendName);
- if (asyncInit || success) {
- return { name: backendName, asyncInit };
- }
- }
- throw new Error(`Could not initialize any backends, all backend initializations ` +
- `failed.`);
- }
- moveData(backend, dataId) {
- const info = this.state.tensorInfo.get(dataId);
- const srcBackend = info.backend;
- const values = this.readSync(dataId);
- // Delete the tensor from the old backend and move it to the new
- // backend.
- srcBackend.disposeData(dataId);
- info.backend = backend;
- backend.move(dataId, values, info.shape, info.dtype);
- if (this.shouldCheckForMemLeaks()) {
- // Track the number of moves during a kernel execution to correctly
- // detect memory leaks.
- this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
- }
- }
- tidy(nameOrFn, fn) {
- let name = null;
- if (fn == null) {
- // Called with only 1 argument.
- if (typeof nameOrFn !== 'function') {
- throw new Error('Please provide a function to tidy()');
- }
- fn = nameOrFn;
- }
- else {
- // Called with 2 arguments.
- if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
- throw new Error('When calling with two arguments, the first argument ' +
- 'to tidy() must be a string');
- }
- if (typeof fn !== 'function') {
- throw new Error('When calling with two arguments, the 2nd argument ' +
- 'to tidy() must be a function');
- }
- name = nameOrFn;
- // TODO(nsthorat,smilkov): Do operation logging and performance
- // profiling.
- }
- let result;
- return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
- result = fn();
- if (result instanceof Promise) {
- console.error('Cannot return a Promise inside of tidy.');
- }
- return result;
- });
- }
- scopedRun(start, end, f) {
- start();
- try {
- const res = f();
- end();
- return res;
- }
- catch (ex) {
- end();
- throw ex;
- }
- }
- nextTensorId() {
- return Engine.nextTensorId++;
- }
- nextVariableId() {
- return Engine.nextVariableId++;
- }
- /**
- * This method is called instead of the public-facing tensor.clone() when
- * saving a tensor for backwards pass. It makes sure to add the clone
- * operation to the tape regardless of being called inside a kernel
- * execution.
- *
- * This method will go away once all kernels are modularized since we won't
- * need to turn off the tape inside runKernel().
- */
- clone(x) {
- const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
- const inputs = { x };
- const grad = (dy) => ({
- x: () => {
- const dtype = 'float32';
- const gradInputs = { x: dy };
- const attrs = { dtype };
- return ENGINE.runKernelFunc(backend => backend.cast(dy, dtype), gradInputs, null /* grad */, Cast, attrs);
- }
- });
- const saved = [];
- this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
- return y;
- }
- /**
- * Execute a kernel with the given name and return the output tensor.
- *
- * @param kernelName The name of the kernel to execute.
- * @param inputs A map of input names to tensors.
- * @param attrs A map of attribute names to their values. An attribute is a
- * primitive (non-tensor) input to the kernel.
- * @param inputsToSave A list of tensors, inputs to save for the backprop
- * computation.
- * @param outputsToSave A list of booleans, specifying which output to save
- * for the backprop computation. These are booleans since the output
- * tensors are not visible to the user.
- */
- runKernel(kernelName, inputs, attrs, inputsToSave, outputsToSave) {
- const forwardFunc = null;
- const backwardsFunc = null;
- // Call runKernel as a stop-gap until we modularize all kernels.
- // Once we modularize all kernels, we will remove the existing
- // `runKernelFunc`.
- return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
- }
- shouldCheckForMemLeaks() {
- return this.ENV.getBool('IS_TEST');
- }
- checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
- const numDataIdsAfter = this.backend.numDataIds();
- // Count the number of data ids associated with the result of the kernel.
- let numOutputDataIds = 0;
- outInfos.forEach(info => {
- // Complex numbers allocate 3 data ids, one for 'real', one for
- // 'imaginary', and one for the container that holds the former two.
- numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
- });
- // Account for the number of moves during kernel execution. A "data move"
- // can happen in the middle of a kernel execution, placing a new (key,value)
- // pair in the data storage. Since data moves have net zero effect (we
- // always remove the data from the old backend), we have to cancel them out
- // when detecting memory leaks.
- const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
- const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
- if (dataIdsLeaked > 0) {
- throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
- `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
- }
- }
- /**
- * @deprecated Use `runKernel` for newly added kernels. Keep using this method
- * only for kernels that are not yet fully modularized.
- */
- runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
- let outputs;
- let saved = [];
- const isTapeOn = this.isTapeOn();
- if (kernelName == null) {
- kernelName =
- this.state.activeScope != null ? this.state.activeScope.name : '';
- }
- const startingBytecount = this.state.numBytes;
- const startingNumTensors = this.state.numTensors;
- if (this.shouldCheckForMemLeaks()) {
- this.state.numDataMovesStack.push(0);
- }
- let kernelFunc;
- const kernel = getKernel(kernelName, this.backendName);
- let out;
- if (kernel != null) {
- kernelFunc = () => {
- const numDataIdsBefore = this.backend.numDataIds();
- out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
- const outInfos = Array.isArray(out) ? out : [out];
- if (this.shouldCheckForMemLeaks()) {
- this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
- }
- const outTensors = outInfos.map(({ dataId, shape, dtype }) => this.makeTensorFromDataId(dataId, shape, dtype));
- // Save the inputs and outputs.
- // Do not save unless we are recording to the tape. Otherwise it would
- // cause a mem leak since we would never run backprop, which disposes
- // the kept tensors.
- if (isTapeOn) {
- let tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
- if (tensorsToSave == null) {
- // Fallback for ops that call runKernelFunc and pass in
- // inputsToSave and outputsToSave. Currently this is the set of ops
- // with kernel support in the WASM backend. Once those ops and
- // respective gradients are modularised we can remove this path.
- if (outputsToSave == null) {
- outputsToSave = [];
- }
- const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
- tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
- }
- saved = this.saveTensorsForBackwardMode(tensorsToSave);
- }
- return outTensors;
- };
- }
- else {
- const saveFunc = (tensors) => {
- // Do not save unless we are recording to the tape. Otherwise it would
- // cause a mem leak since we would never run backprop, which disposes
- // the kept tensors.
- if (!isTapeOn) {
- return;
- }
- saved = tensors.map(tensor => this.keep(this.clone(tensor)));
- };
- kernelFunc = () => {
- const numDataIdsBefore = this.backend.numDataIds();
- out = this.tidy(() => forwardFunc(this.backend, saveFunc));
- const outs = (Array.isArray(out) ? out : [out]);
- if (this.shouldCheckForMemLeaks()) {
- this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
- }
- return outs;
- };
- }
- // Stop recording to a tape when running a kernel.
- let kernelProfile;
- this.scopedRun(() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
- if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
- outputs = kernelFunc();
- }
- else {
- kernelProfile = this.profiler.profileKernel(kernelName, inputs, () => kernelFunc());
- if (this.ENV.getBool('DEBUG')) {
- this.profiler.logKernelProfile(kernelProfile);
- }
- outputs = kernelProfile.outputs;
- }
- });
- if (isTapeOn) {
- this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
- }
- if (this.state.profiling) {
- this.state.activeProfile.kernels.push({
- name: kernelName,
- bytesAdded: this.state.numBytes - startingBytecount,
- totalBytesSnapshot: this.state.numBytes,
- tensorsAdded: this.state.numTensors - startingNumTensors,
- totalTensorsSnapshot: this.state.numTensors,
- inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
- outputShapes: outputs.map(item => item.shape),
- kernelTimeMs: kernelProfile.timeMs,
- extraInfo: kernelProfile.extraInfo
- });
- }
- return (Array.isArray(out) ? outputs : outputs[0]);
- }
- /**
- * Saves tensors used in forward mode for use in backward mode.
- *
- * @param tensors the list of tensors to save.
- */
- saveTensorsForBackwardMode(tensors) {
- const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
- return saved;
- }
- /**
- * Returns a list of tensors to save for a given gradient calculation.
- *
- * Returns undefined if their is no registered gradient for this kernel in the
- * gradient registry.
- *
- * @param kernelName name of kernel to look up gradient for.
- * @param inputs a map of input tensors.
- * @param outputs an array of output tensors from forward mode of kernel.
- */
- getTensorsForGradient(kernelName, inputs, outputs) {
- const gradConfig = getGradient(kernelName);
- if (gradConfig != null) {
- const inputsToSave = gradConfig.inputsToSave || [];
- const outputsToSave = gradConfig.outputsToSave || [];
- // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
- // specified in inputsToSave will be saved.
- let inputTensorsToSave;
- if (gradConfig.saveAllInputs) {
- assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
- inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
- }
- else {
- inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
- }
- const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
- return inputTensorsToSave.concat(outputTensorsToSave);
- }
- // TODO(yassogba) throw exception here once all runkernelFunc calls with
- // inputsToSave/outputsToSave are removed
- return null;
- }
- /**
- * Internal method used by public APIs for tensor creation. Makes a new
- * tensor with the provided shape, dtype and values. It always
- * creates a new data id and writes the values to the underlying backend.
- */
- makeTensor(values, shape, dtype, backend) {
- if (values == null) {
- throw new Error('Values passed to engine.makeTensor() are null');
- }
- dtype = dtype || 'float32';
- backend = backend || this.backend;
- let backendVals = values;
- if (dtype === 'string' && isString(values[0])) {
- backendVals = values.map(d => encodeString(d));
- }
- const dataId = backend.write(backendVals, shape, dtype);
- const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
- this.incRef(t, backend);
- // Count bytes for string tensors.
- if (dtype === 'string') {
- const info = this.state.tensorInfo.get(dataId);
- const newBytes = bytesFromStringArray(backendVals);
- this.state.numBytes += newBytes - info.bytes;
- info.bytes = newBytes;
- }
- return t;
- }
- /**
- * Internal method used by backends. Makes a new tensor
- * that is a wrapper around an existing data id. It doesn't create
- * a new data id, only increments the ref count used in memory tracking.
- */
- makeTensorFromDataId(dataId, shape, dtype, backend) {
- dtype = dtype || 'float32';
- const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
- this.incRef(t, backend);
- return t;
- }
- makeVariable(initialValue, trainable = true, name, dtype) {
- name = name || this.nextVariableId().toString();
- if (dtype != null && dtype !== initialValue.dtype) {
- initialValue = initialValue.cast(dtype);
- }
- const v = new Variable(initialValue, trainable, name, this.nextTensorId());
- if (this.state.registeredVariables[v.name] != null) {
- throw new Error(`Variable with name ${v.name} was already registered`);
- }
- this.state.registeredVariables[v.name] = v;
- this.incRef(v, this.backend);
- return v;
- }
- incRef(a, backend) {
- const refCount = this.state.tensorInfo.has(a.dataId) ?
- this.state.tensorInfo.get(a.dataId).refCount :
- 0;
- this.state.numTensors++;
- if (a.dtype === 'string') {
- this.state.numStringTensors++;
- }
- if (refCount === 0) {
- this.state.numDataBuffers++;
- // Bytes for complex numbers are counted by their components. Bytes for
- // string tensors are counted when writing values.
- let bytes = 0;
- if (a.dtype !== 'complex64' && a.dtype !== 'string') {
- bytes = a.size * bytesPerElement(a.dtype);
- }
- this.state.tensorInfo.set(a.dataId, {
- backend: backend || this.backend,
- dtype: a.dtype,
- shape: a.shape,
- bytes,
- refCount: 0
- });
- this.state.numBytes += bytes;
- }
- this.state.tensorInfo.get(a.dataId).refCount++;
- if (!(a instanceof Variable)) {
- this.track(a);
- }
- }
- disposeTensor(a) {
- if (!this.state.tensorInfo.has(a.dataId)) {
- return;
- }
- this.state.numTensors--;
- if (a.dtype === 'string') {
- this.state.numStringTensors--;
- }
- const info = this.state.tensorInfo.get(a.dataId);
- const refCount = info.refCount;
- if (refCount <= 1) {
- // Don't count bytes for complex numbers as they are counted by their
- // components.
- if (a.dtype !== 'complex64') {
- this.state.numBytes -= info.bytes;
- }
- this.state.numDataBuffers--;
- info.backend.disposeData(a.dataId);
- this.state.tensorInfo.delete(a.dataId);
- }
- else {
- this.state.tensorInfo.get(a.dataId).refCount--;
- }
- // TODO(nsthorat): Construct an error and save the stack trace for
- // debugging when in debug mode. Creating a stack trace is too expensive
- // to do unconditionally.
- }
- disposeVariables() {
- for (const varName in this.state.registeredVariables) {
- const v = this.state.registeredVariables[varName];
- this.disposeVariable(v);
- }
- }
- disposeVariable(v) {
- this.disposeTensor(v);
- if (this.state.registeredVariables[v.name] != null) {
- delete this.state.registeredVariables[v.name];
- }
- }
- memory() {
- const info = this.backend.memory();
- info.numTensors = this.state.numTensors;
- info.numDataBuffers = this.state.numDataBuffers;
- info.numBytes = this.state.numBytes;
- if (this.state.numStringTensors > 0) {
- info.unreliable = true;
- if (info.reasons == null) {
- info.reasons = [];
- }
- info.reasons.push('Memory usage by string tensors is approximate ' +
- '(2 bytes per character)');
- }
- return info;
- }
- async profile(query) {
- this.state.profiling = true;
- const startBytes = this.state.numBytes;
- const startNumTensors = this.state.numTensors;
- this.state.activeProfile.kernels = [];
- this.state.activeProfile.result = await query();
- this.state.profiling = false;
- this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
- this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
- this.state.activeProfile.newTensors =
- this.state.numTensors - startNumTensors;
- for (const kernel of this.state.activeProfile.kernels) {
- kernel.kernelTimeMs = await kernel.kernelTimeMs;
- kernel.extraInfo = await kernel.extraInfo;
- }
- return this.state.activeProfile;
- }
- isTapeOn() {
- return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
- }
- addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
- const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
- const gradConfig = getGradient(kernelName);
- if (gradConfig != null) {
- gradientsFunc = gradConfig.gradFunc;
- }
- if (gradientsFunc != null) {
- tapeNode.gradient = (dys) => {
- // TODO(smilkov): To optimize back-prop, pass dys that are not used in
- // the backprop graph to the user as null instead of zeros
- dys = dys.map((dy, i) => {
- if (dy == null) {
- const output = outputs[i];
- const vals = makeZerosTypedArray(output.size, output.dtype);
- return this.makeTensor(vals, output.shape, output.dtype);
- }
- return dy;
- });
- // Grad functions of ops with single outputs expect a dy, while ops
- // with multiple outputs expect dys (array of dy).
- return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
- };
- }
- this.state.activeTape.push(tapeNode);
- }
- keep(result) {
- result.kept = true;
- return result;
- }
- startTape() {
- if (this.state.gradientDepth === 0) {
- this.state.activeTape = [];
- }
- this.state.gradientDepth++;
- }
- endTape() {
- this.state.gradientDepth--;
- }
- /**
- * Start a scope. Use this with endScope() to achieve the same functionality
- * as scope() without the need for a function closure.
- */
- startScope(name) {
- const scopeInfo = {
- track: [],
- name: 'unnamed scope',
- id: this.state.nextScopeId++
- };
- if (name) {
- scopeInfo.name = name;
- }
- this.state.scopeStack.push(scopeInfo);
- this.state.activeScope = scopeInfo;
- }
- /**
- * End a scope. Use this with startScope() to achieve the same functionality
- * as scope() without the need for a function closure.
- */
- endScope(result) {
- const tensorsToTrackInParent = getTensorsInContainer(result);
- const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
- // Dispose the arrays tracked in this scope.
- for (let i = 0; i < this.state.activeScope.track.length; i++) {
- const tensor = this.state.activeScope.track[i];
- if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
- tensor.dispose();
- }
- }
- const oldScope = this.state.scopeStack.pop();
- this.state.activeScope = this.state.scopeStack.length === 0 ?
- null :
- this.state.scopeStack[this.state.scopeStack.length - 1];
- // Track the current result in the parent scope.
- tensorsToTrackInParent.forEach(tensor => {
- // Only track the tensor if was allocated in the inner scope and is not
- // globally kept.
- if (!tensor.kept && tensor.scopeId === oldScope.id) {
- this.track(tensor);
- }
- });
- }
- /**
- * Returns gradients of `f` with respect to each of the `xs`. The gradients
- * returned are of the same length as `xs`, but some might be null if `f`
- * was not a function of that `x`. It also takes optional dy to multiply the
- * gradient, which defaults to `1`.
- */
- gradients(f, xs, dy, allowNoGradients = false) {
- assert(xs.length > 0, () => 'gradients() received an empty list of xs.');
- if (dy != null && dy.dtype !== 'float32') {
- throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
- }
- const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
- assert(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
- // Filter out the nodes that don't connect x => y.
- const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
- if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
- throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
- 'that the f you passed encloses all operations that lead from x ' +
- 'to y.');
- }
- return this.tidy('backward', () => {
- const accumulatedGradientMap = {};
- accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
- // Backprop gradients through the filtered nodes.
- backpropagateGradients(accumulatedGradientMap, filteredTape,
- // Pass the tidy function to avoid circular dep with `tape.ts`.
- f => this.tidy(f),
- // Pass an add function to avoide a circular dep with `tape.ts`.
- add);
- const grads = xs.map(x => accumulatedGradientMap[x.id]);
- if (this.state.gradientDepth === 0) {
- // This means that we are not computing higher-order gradients
- // and can clean up the tape.
- this.state.activeTape.forEach(node => {
- for (const tensor of node.saved) {
- tensor.dispose();
- }
- });
- this.state.activeTape = null;
- }
- return { value: y, grads };
- });
- }
- customGrad(f) {
- assert(isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
- return (...inputs) => {
- assert(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
- 'tensors');
- let res;
- const inputMap = {};
- inputs.forEach((input, i) => {
- inputMap[i] = input;
- });
- return this.runKernelFunc((_, save) => {
- res = f(...[...inputs, save]);
- assert(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.value` is a tensor');
- assert(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function.');
- return res.value;
- }, inputMap, (dy, saved) => {
- const gradRes = res.gradFunc(dy, saved);
- const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
- assert(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function that returns ' +
- 'the same number of tensors as inputs passed to f(...).');
- assert(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function that returns ' +
- 'a list of only tensors.');
- const gradMap = {};
- grads.forEach((grad, i) => {
- gradMap[i] = () => grad;
- });
- return gradMap;
- });
- };
- }
- readSync(dataId) {
- // Route the read to the correct backend.
- const info = this.state.tensorInfo.get(dataId);
- return info.backend.readSync(dataId);
- }
- read(dataId) {
- // Route the read to the correct backend.
- const info = this.state.tensorInfo.get(dataId);
- return info.backend.read(dataId);
- }
- async time(query) {
- const start = now();
- const timingInfo = await this.backend.time(query);
- timingInfo.wallMs = now() - start;
- return timingInfo;
- }
- /**
- * Tracks a Tensor in the current scope to be automatically cleaned up
- * when the current scope ends, and returns the value.
- *
- * @param result The Tensor to track in the current scope.
- */
- track(result) {
- if (this.state.activeScope != null) {
- result.scopeId = this.state.activeScope.id;
- this.state.activeScope.track.push(result);
- }
- return result;
- }
- get registeredVariables() {
- return this.state.registeredVariables;
- }
- /**
- * Resets the engine state. Removes all backends but does not remove
- * registered backend factories.
- */
- reset() {
- // Make any pending promise obsolete.
- this.pendingBackendInitId++;
- this.state.dispose();
- this.ENV.reset();
- this.state = new EngineState();
- for (const backendName in this.registry) {
- this.disposeRegisteredKernels(backendName);
- this.registry[backendName].dispose();
- delete this.registry[backendName];
- }
- this.backendName = null;
- this.backendInstance = null;
- this.pendingBackendInit = null;
- }
- }
- Engine.nextTensorId = 0;
- Engine.nextVariableId = 0;
- function ones(shape) {
- const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
- return ENGINE.makeTensor(values, shape, 'float32');
- }
- function getOrMakeEngine() {
- const ns = getGlobalNamespace();
- if (ns._tfengine == null) {
- const environment = new Environment(ns);
- ns._tfengine = new Engine(environment);
- }
- setEnvironmentGlobal(ns._tfengine.ENV);
- // Tell the current tensor interface that the global engine is responsible
- // for tracking.
- setTensorTracker(() => ns._tfengine);
- return ns._tfengine;
- }
- const ENGINE = getOrMakeEngine();
- /**
- * A implementation of the add op for use within engine and tape.
- *
- * This allows us to avoid a circular dependency between add.ts and engine.
- * It is exported to be available in tape tests.
- */
- function add(a, b) {
- // We duplicate Add here to avoid a circular dependency with add.ts.
- const inputs = { a, b };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.add(a, b);
- save([a, b]);
- return res;
- }, inputs, null /* gradient */, Add);
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // tslint:disable-next-line:no-any
- function _isNavigatorDefined() {
- return typeof navigator !== 'undefined' && navigator != null;
- }
- function isMobile() {
- if (_isNavigatorDefined()) {
- // tslint:disable-next-line:no-any
- const a = navigator.userAgent || navigator.vendor || window.opera;
- // tslint:disable-next-line:max-line-length
- return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
- .test(a) ||
- // tslint:disable-next-line:max-line-length
- /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
- .test(a.substr(0, 4));
- }
- return false;
- }
- function isBrowser() {
- return (typeof window !== 'undefined' && window.document != null) ||
- //@ts-ignore
- (typeof WorkerGlobalScope !== 'undefined');
- }
-
- var device_util = /*#__PURE__*/Object.freeze({
- __proto__: null,
- isMobile: isMobile,
- isBrowser: isBrowser
- });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ENV = env();
- /**
- * This file contains environment-related flag registrations.
- */
- /** Whether to enable debug mode. */
- ENV.registerFlag('DEBUG', () => false, debugValue => {
- if (debugValue) {
- console.warn('Debugging mode is ON. The output of every math call will ' +
- 'be downloaded to CPU and checked for NaNs. ' +
- 'This significantly impacts performance.');
- }
- });
- /** Whether we are in a browser (as versus, say, node.js) environment. */
- ENV.registerFlag('IS_BROWSER', () => isBrowser());
- /** Whether we are in a browser (as versus, say, node.js) environment. */
- ENV.registerFlag('IS_NODE', () => (typeof process !== 'undefined') &&
- (typeof process.versions !== 'undefined') &&
- (typeof process.versions.node !== 'undefined'));
- /** Whether this browser is Chrome. */
- ENV.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null &&
- navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
- /Google Inc/.test(navigator.vendor));
- /**
- * True when the environment is "production" where we disable safety checks
- * to gain performance.
- */
- ENV.registerFlag('PROD', () => false);
- /**
- * Whether to do sanity checks when inferring a shape from user-provided
- * values, used when creating a new tensor.
- */
- ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV.getBool('DEBUG'));
- /** Whether deprecation warnings are enabled. */
- ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);
- /** True if running unit tests. */
- ENV.registerFlag('IS_TEST', () => false);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function inferShape(val, dtype) {
- let firstElem = val;
- if (isTypedArray(val)) {
- return dtype === 'string' ? [] : [val.length];
- }
- if (!Array.isArray(val)) {
- return []; // Scalar.
- }
- const shape = [];
- while (Array.isArray(firstElem) ||
- isTypedArray(firstElem) && dtype !== 'string') {
- shape.push(firstElem.length);
- firstElem = firstElem[0];
- }
- if (Array.isArray(val) &&
- env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
- deepAssertShapeConsistency(val, shape, []);
- }
- return shape;
- }
- function deepAssertShapeConsistency(val, shape, indices) {
- indices = indices || [];
- if (!(Array.isArray(val)) && !isTypedArray(val)) {
- assert(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` +
- `but should be an array/TypedArray of ${shape[0]} elements`);
- return;
- }
- assert(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` +
- `but is an array of ${val.length} elements`);
- assert(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +
- `elements, but has ${val.length} elements`);
- const subShape = shape.slice(1);
- for (let i = 0; i < val.length; ++i) {
- deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
- }
- }
- function assertDtype(expectedDtype, actualDType, argName, functionName) {
- if (expectedDtype == null) {
- return;
- }
- if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
- expectedDtype === 'numeric' && actualDType === 'string') {
- throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
- `be ${expectedDtype} tensor, but got ${actualDType} tensor`);
- }
- }
- function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
- if (x instanceof Tensor) {
- assertDtype(parseAsDtype, x.dtype, argName, functionName);
- return x;
- }
- let inferredDtype = inferDtype(x);
- // If the user expects a bool/int/float, use that info to update the
- // inferredDtype when it is not a string.
- if (inferredDtype !== 'string' &&
- ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
- inferredDtype = parseAsDtype;
- }
- assertDtype(parseAsDtype, inferredDtype, argName, functionName);
- if ((x == null) ||
- (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
- typeof x !== 'boolean' && typeof x !== 'string')) {
- const type = x == null ? 'null' : x.constructor.name;
- throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
- `Tensor or TensorLike, but got '${type}'`);
- }
- const inferredShape = inferShape(x, inferredDtype);
- if (!isTypedArray(x) && !Array.isArray(x)) {
- x = [x];
- }
- const skipTypedArray = true;
- const values = inferredDtype !== 'string' ?
- toTypedArray(x, inferredDtype) :
- flatten(x, [], skipTypedArray);
- return ENGINE.makeTensor(values, inferredShape, inferredDtype);
- }
- function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
- if (!Array.isArray(arg)) {
- throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
- '`Tensor[]` or `TensorLike[]`');
- }
- const tensors = arg;
- return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName), parseAsDtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const OP_SCOPE_SUFFIX = '__op';
- /**
- * Used for wrapping functions that perform math operations on
- * Tensors. The function will be wrapped in a named scope that cleans all
- * memory usage after the function is done.
- */
- function op(f) {
- const keys = Object.keys(f);
- if (keys.length !== 1) {
- throw new Error(`Please provide an object with a single key ` +
- `(operation name) mapping to a function. Got an object with ` +
- `${keys.length} keys.`);
- }
- let opName = keys[0];
- const fn = f[opName];
- // Strip the underscore from the end of the function name.
- if (opName.endsWith('_')) {
- opName = opName.substring(0, opName.length - 1);
- }
- // add an __op suffix to distinguish ops from kernels in tf.profile
- opName = opName + OP_SCOPE_SUFFIX;
- // tslint:disable-next-line:no-any
- const f2 = (...args) => {
- ENGINE.startScope(opName);
- try {
- const result = fn(...args);
- if (result instanceof Promise) {
- console.error('Cannot return a Promise inside of tidy.');
- }
- ENGINE.endScope(result);
- return result;
- }
- catch (ex) {
- ENGINE.endScope(null);
- throw ex;
- }
- };
- Object.defineProperty(f2, 'name', { value: opName, configurable: true });
- // tslint:disable-next-line:no-any
- return f2;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts two real numbers to a complex number.
- *
- * Given a tensor `real` representing the real part of a complex number, and a
- * tensor `imag` representing the imaginary part of a complex number, this
- * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
- * where r represents the real part and i represents the imag part.
- *
- * The input tensors real and imag must have the same shape.
- *
- * ```js
- * const real = tf.tensor1d([2.25, 3.25]);
- * const imag = tf.tensor1d([4.75, 5.75]);
- * const complex = tf.complex(real, imag);
- *
- * complex.print();
- * ```
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function complex_(real, imag) {
- const $real = convertToTensor(real, 'real', 'complex');
- const $imag = convertToTensor(imag, 'imag', 'complex');
- assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` +
- `must match in call to tf.complex().`);
- const forward = (backend) => {
- return backend.complex($real, $imag);
- };
- const inputs = { real: $real, imag: $imag };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Complex);
- }
- const complex = op({ complex_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** This is shared code across all tensor creation methods. */
- function makeTensor(values, shape, inferredShape, dtype) {
- if (dtype == null) {
- dtype = inferDtype(values);
- }
- if (dtype === 'complex64') {
- throw new Error(`Cannot construct a complex64 tensor directly. ` +
- `Please use tf.complex(real, imag).`);
- }
- if (!isTypedArray(values) && !Array.isArray(values) &&
- typeof values !== 'number' && typeof values !== 'boolean' &&
- typeof values !== 'string') {
- throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
- 'an array of numbers/booleans/strings, or a TypedArray');
- }
- if (shape != null) {
- assertNonNegativeIntegerDimensions(shape);
- const providedSize = sizeFromShape(shape);
- const inferredSize = sizeFromShape(inferredShape);
- assert(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` +
- `${providedSize} values but has ${inferredSize}`);
- for (let i = 0; i < inferredShape.length; ++i) {
- const inferred = inferredShape[i];
- const flatDimsDontMatch = i === inferredShape.length - 1 ?
- inferred !== sizeFromShape(shape.slice(i)) :
- true;
- assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` +
- `(${inferredShape}) does not match the provided ` +
- `shape (${shape}). `);
- }
- }
- if (!isTypedArray(values) && !Array.isArray(values)) {
- values = [values];
- }
- shape = shape || inferredShape;
- values = dtype !== 'string' ?
- toTypedArray(values, dtype) :
- flatten(values, [], true);
- return ENGINE.makeTensor(values, shape, dtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with the provided values, shape and dtype.
- *
- * ```js
- * // Pass an array of values to create a vector.
- * tf.tensor([1, 2, 3, 4]).print();
- * ```
- *
- * ```js
- * // Pass a nested array of values to make a matrix or a higher
- * // dimensional tensor.
- * tf.tensor([[1, 2], [3, 4]]).print();
- * ```
- *
- * ```js
- * // Pass a flat array and specify a shape yourself.
- * tf.tensor([1, 2, 3, 4], [2, 2]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`. If the values are strings,
- * they will be encoded as utf-8 and kept as `Uint8Array[]`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function tensor(values, shape, dtype) {
- const inferredShape = inferShape(values, dtype);
- return makeTensor(values, shape, inferredShape, dtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /* Type definitions for exporting and importing of models. */
- /**
- * A map from Tensor dtype to number of bytes per element of the Tensor.
- */
- const DTYPE_VALUE_SIZE_MAP = {
- 'float32': 4,
- 'float16': 2,
- 'int32': 4,
- 'uint16': 2,
- 'uint8': 1,
- 'bool': 1,
- 'complex64': 8
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Number of bytes reserved for the length of the string. (32bit integer). */
- const NUM_BYTES_STRING_LENGTH = 4;
- /**
- * Encode a map from names to weight values as an ArrayBuffer, along with an
- * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
- *
- * This function does not perform sharding.
- *
- * This function is the reverse of `decodeWeights`.
- *
- * @param tensors A map ("dict") from names to tensors.
- * @param group Group to which the weights belong (optional).
- * @returns A `Promise` of
- * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
- * concatenated.
- * - An `Array` of `WeightManifestEntry`s, carrying information including
- * tensor names, `dtype`s and shapes.
- * @throws Error: on unsupported tensor `dtype`.
- */
- async function encodeWeights(tensors, group) {
- // TODO(adarob, cais): Support quantization.
- const specs = [];
- const dataPromises = [];
- const names = Array.isArray(tensors) ?
- tensors.map(tensor => tensor.name) :
- Object.keys(tensors);
- for (let i = 0; i < names.length; ++i) {
- const name = names[i];
- const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
- if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
- t.dtype !== 'string' && t.dtype !== 'complex64') {
- throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
- }
- const spec = { name, shape: t.shape, dtype: t.dtype };
- if (t.dtype === 'string') {
- const utf8bytes = new Promise(async (resolve) => {
- const vals = await t.bytes();
- const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
- NUM_BYTES_STRING_LENGTH * vals.length;
- const bytes = new Uint8Array(totalNumBytes);
- let offset = 0;
- for (let i = 0; i < vals.length; i++) {
- const val = vals[i];
- const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
- bytes.set(bytesOfLength, offset);
- offset += NUM_BYTES_STRING_LENGTH;
- bytes.set(val, offset);
- offset += val.length;
- }
- resolve(bytes);
- });
- dataPromises.push(utf8bytes);
- }
- else {
- dataPromises.push(t.data());
- }
- if (group != null) {
- spec.group = group;
- }
- specs.push(spec);
- }
- const tensorValues = await Promise.all(dataPromises);
- return { data: concatenateTypedArrays(tensorValues), specs };
- }
- /**
- * Decode flat ArrayBuffer as weights.
- *
- * This function does not handle sharding.
- *
- * This function is the reverse of `encodeWeights`.
- *
- * @param buffer A flat ArrayBuffer carrying the binary values of the tensors
- * concatenated in the order specified in `specs`.
- * @param specs Specifications of the names, dtypes and shapes of the tensors
- * whose value are encoded by `buffer`.
- * @return A map from tensor name to tensor value, with the names corresponding
- * to names in `specs`.
- * @throws Error, if any of the tensors has unsupported dtype.
- */
- function decodeWeights(buffer, specs) {
- // TODO(adarob, cais): Support quantization.
- const out = {};
- let float16Decode;
- let offset = 0;
- for (const spec of specs) {
- const name = spec.name;
- const dtype = spec.dtype;
- const shape = spec.shape;
- const size = sizeFromShape(shape);
- let values;
- if ('quantization' in spec) {
- const quantization = spec.quantization;
- if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
- if (!('min' in quantization && 'scale' in quantization)) {
- throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
- `doesn't have corresponding metadata min and scale.`);
- }
- }
- else if (quantization.dtype === 'float16') {
- if (dtype !== 'float32') {
- throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
- `which only supports weights of type float32 not ${dtype}.`);
- }
- }
- else {
- throw new Error(`Weight ${spec.name} has unknown ` +
- `quantization dtype ${quantization.dtype}. ` +
- `Supported quantization dtypes are: ` +
- `'uint8', 'uint16', and 'float16'.`);
- }
- const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
- const byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor);
- const quantizedArray = (quantization.dtype === 'uint8') ?
- new Uint8Array(byteBuffer) :
- new Uint16Array(byteBuffer);
- if (dtype === 'float32') {
- if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
- values = new Float32Array(quantizedArray.length);
- for (let i = 0; i < quantizedArray.length; i++) {
- const v = quantizedArray[i];
- values[i] = v * quantization.scale + quantization.min;
- }
- }
- else if (quantization.dtype === 'float16') {
- if (float16Decode === undefined) {
- float16Decode = getFloat16Decoder();
- }
- values = float16Decode(quantizedArray);
- }
- else {
- throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
- `for weight type float32.`);
- }
- }
- else if (dtype === 'int32') {
- if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
- throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
- `for weight type int32.`);
- }
- values = new Int32Array(quantizedArray.length);
- for (let i = 0; i < quantizedArray.length; i++) {
- const v = quantizedArray[i];
- values[i] = Math.round(v * quantization.scale + quantization.min);
- }
- }
- else {
- throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
- }
- offset += size * quantizationSizeFactor;
- }
- else if (dtype === 'string') {
- const size = sizeFromShape(spec.shape);
- values = [];
- for (let i = 0; i < size; i++) {
- const byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
- offset += NUM_BYTES_STRING_LENGTH;
- const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength));
- values.push(bytes);
- offset += byteLength;
- }
- }
- else {
- const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
- const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);
- if (dtype === 'float32') {
- values = new Float32Array(byteBuffer);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(byteBuffer);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(byteBuffer);
- }
- else if (dtype === 'complex64') {
- values = new Float32Array(byteBuffer);
- const real = new Float32Array(values.length / 2);
- const image = new Float32Array(values.length / 2);
- for (let i = 0; i < real.length; i++) {
- real[i] = values[i * 2];
- image[i] = values[i * 2 + 1];
- }
- const realTensor = tensor(real, shape, 'float32');
- const imageTensor = tensor(image, shape, 'float32');
- out[name] = complex(realTensor, imageTensor);
- realTensor.dispose();
- imageTensor.dispose();
- }
- else {
- throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
- }
- offset += size * dtypeFactor;
- }
- if (dtype !== 'complex64') {
- out[name] = tensor(values, shape, dtype);
- }
- }
- return out;
- }
- /**
- * Concatenate TypedArrays into an ArrayBuffer.
- */
- function concatenateTypedArrays(xs) {
- // TODO(adarob, cais): Support quantization.
- if (xs === null) {
- throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
- }
- let totalByteLength = 0;
- // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'
- // can have a different byte length from that of the `TypedArray` itself,
- // for example, when the `TypedArray` is created from an offset in an
- // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match
- // the `TypedArray` in byte length. If an element of `xs` does not show
- // this property, a new `TypedArray` that satisfy this property will be
- // constructed and pushed into `normalizedXs`.
- const normalizedXs = [];
- xs.forEach((x) => {
- totalByteLength += x.byteLength;
- // tslint:disable:no-any
- normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
- new x.constructor(x));
- if (!(x instanceof Float32Array || x instanceof Int32Array ||
- x instanceof Uint8Array)) {
- throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
- }
- // tslint:enable:no-any
- });
- const y = new Uint8Array(totalByteLength);
- let offset = 0;
- normalizedXs.forEach((x) => {
- y.set(new Uint8Array(x.buffer), offset);
- offset += x.byteLength;
- });
- return y.buffer;
- }
- // Use Buffer on Node.js instead of Blob/atob/btoa
- const useNodeBuffer = typeof Buffer !== 'undefined' &&
- (typeof Blob === 'undefined' || typeof atob === 'undefined' ||
- typeof btoa === 'undefined');
- /**
- * Calculate the byte length of a JavaScript string.
- *
- * Note that a JavaScript string can contain wide characters, therefore the
- * length of the string is not necessarily equal to the byte length.
- *
- * @param str Input string.
- * @returns Byte length.
- */
- function stringByteLength(str) {
- if (useNodeBuffer) {
- return Buffer.byteLength(str);
- }
- return new Blob([str]).size;
- }
- /**
- * Encode an ArrayBuffer as a base64 encoded string.
- *
- * @param buffer `ArrayBuffer` to be converted.
- * @returns A string that base64-encodes `buffer`.
- */
- function arrayBufferToBase64String(buffer) {
- if (useNodeBuffer) {
- return Buffer.from(buffer).toString('base64');
- }
- const buf = new Uint8Array(buffer);
- let s = '';
- for (let i = 0, l = buf.length; i < l; i++) {
- s += String.fromCharCode(buf[i]);
- }
- return btoa(s);
- }
- /**
- * Decode a base64 string as an ArrayBuffer.
- *
- * @param str Base64 string.
- * @returns Decoded `ArrayBuffer`.
- */
- function base64StringToArrayBuffer(str) {
- if (useNodeBuffer) {
- const buf = Buffer.from(str, 'base64');
- return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
- }
- const s = atob(str);
- const buffer = new Uint8Array(s.length);
- for (let i = 0; i < s.length; ++i) {
- buffer.set([s.charCodeAt(i)], i);
- }
- return buffer.buffer;
- }
- /**
- * Concatenate a number of ArrayBuffers into one.
- *
- * @param buffers A number of array buffers to concatenate.
- * @returns Result of concatenating `buffers` in order.
- */
- function concatenateArrayBuffers(buffers) {
- if (buffers.length === 1) {
- return buffers[0];
- }
- let totalByteLength = 0;
- buffers.forEach((buffer) => {
- totalByteLength += buffer.byteLength;
- });
- const temp = new Uint8Array(totalByteLength);
- let offset = 0;
- buffers.forEach((buffer) => {
- temp.set(new Uint8Array(buffer), offset);
- offset += buffer.byteLength;
- });
- return temp.buffer;
- }
- /**
- * Get the basename of a path.
- *
- * Behaves in a way analogous to Linux's basename command.
- *
- * @param path
- */
- function basename(path) {
- const SEPARATOR = '/';
- path = path.trim();
- while (path.endsWith(SEPARATOR)) {
- path = path.slice(0, path.length - 1);
- }
- const items = path.split(SEPARATOR);
- return items[items.length - 1];
- }
- /**
- * Populate ModelArtifactsInfo fields for a model with JSON topology.
- * @param modelArtifacts
- * @returns A ModelArtifactsInfo object.
- */
- function getModelArtifactsInfoForJSON(modelArtifacts) {
- if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
- throw new Error('Expected JSON model topology, received ArrayBuffer.');
- }
- return {
- dateSaved: new Date(),
- modelTopologyType: 'JSON',
- modelTopologyBytes: modelArtifacts.modelTopology == null ?
- 0 :
- stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
- weightSpecsBytes: modelArtifacts.weightSpecs == null ?
- 0 :
- stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
- weightDataBytes: modelArtifacts.weightData == null ?
- 0 :
- modelArtifacts.weightData.byteLength,
- };
- }
- /**
- * Computes mantisa table for casting Float16 to Float32
- * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
- *
- * @returns Uint32Array, 2048 mantissa lookup values.
- */
- function computeFloat16MantisaTable() {
- const convertMantissa = (i) => {
- let m = i << 13;
- let e = 0;
- while ((m & 0x00800000) === 0) {
- e -= 0x00800000;
- m <<= 1;
- }
- m &= ~0x00800000;
- e += 0x38800000;
- return m | e;
- };
- const mantisaTable = new Uint32Array(2048);
- mantisaTable[0] = 0;
- for (let i = 1; i < 1024; i++) {
- mantisaTable[i] = convertMantissa(i);
- }
- for (let i = 1024; i < 2048; i++) {
- mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
- }
- return mantisaTable;
- }
- /**
- * Computes exponent table for casting Float16 to Float32
- * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
- *
- * @returns Uint32Array, 64 exponent lookup values.
- */
- function computeFloat16ExponentTable() {
- const exponentTable = new Uint32Array(64);
- exponentTable[0] = 0;
- exponentTable[31] = 0x47800000;
- exponentTable[32] = 0x80000000;
- exponentTable[63] = 0xc7800000;
- for (let i = 1; i < 31; i++) {
- exponentTable[i] = i << 23;
- }
- for (let i = 33; i < 63; i++) {
- exponentTable[i] = 0x80000000 + ((i - 32) << 23);
- }
- return exponentTable;
- }
- /**
- * Computes offset table for casting Float16 to Float32
- * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
- *
- * @returns Uint32Array, 6d offset values.
- */
- function computeFloat16OffsetTable() {
- const offsetTable = new Uint32Array(64);
- for (let i = 0; i < 64; i++) {
- offsetTable[i] = 1024;
- }
- offsetTable[0] = offsetTable[32] = 0;
- return offsetTable;
- }
- /**
- * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
- * to a Float32Array.
- *
- * @returns Function (buffer: Uint16Array) => Float32Array which decodes
- * the Uint16Array of Float16 bytes to a Float32Array.
- */
- function getFloat16Decoder() {
- // Algorithm is based off of
- // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
- // Cache lookup tables
- const mantisaTable = computeFloat16MantisaTable();
- const exponentTable = computeFloat16ExponentTable();
- const offsetTable = computeFloat16OffsetTable();
- return (quantizedArray) => {
- const buffer = new ArrayBuffer(4 * quantizedArray.length);
- const bufferUint32View = new Uint32Array(buffer);
- for (let index = 0; index < quantizedArray.length; index++) {
- const float16Bits = quantizedArray[index];
- const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
- exponentTable[float16Bits >> 10];
- bufferUint32View[index] = float32Bits;
- }
- return new Float32Array(buffer);
- };
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class IORouterRegistry {
- constructor() {
- this.saveRouters = [];
- this.loadRouters = [];
- }
- static getInstance() {
- if (IORouterRegistry.instance == null) {
- IORouterRegistry.instance = new IORouterRegistry();
- }
- return IORouterRegistry.instance;
- }
- /**
- * Register a save-handler router.
- *
- * @param saveRouter A function that maps a URL-like string onto an instance
- * of `IOHandler` with the `save` method defined or `null`.
- */
- static registerSaveRouter(saveRouter) {
- IORouterRegistry.getInstance().saveRouters.push(saveRouter);
- }
- /**
- * Register a load-handler router.
- *
- * @param loadRouter A function that maps a URL-like string onto an instance
- * of `IOHandler` with the `load` method defined or `null`.
- */
- static registerLoadRouter(loadRouter) {
- IORouterRegistry.getInstance().loadRouters.push(loadRouter);
- }
- /**
- * Look up IOHandler for saving, given a URL-like string.
- *
- * @param url
- * @returns If only one match is found, an instance of IOHandler with the
- * `save` method defined. If no match is found, `null`.
- * @throws Error, if more than one match is found.
- */
- static getSaveHandlers(url) {
- return IORouterRegistry.getHandlers(url, 'save');
- }
- /**
- * Look up IOHandler for loading, given a URL-like string.
- *
- * @param url
- * @param loadOptions Optional, custom load options.
- * @returns All valid handlers for `url`, given the currently registered
- * handler routers.
- */
- static getLoadHandlers(url, loadOptions) {
- return IORouterRegistry.getHandlers(url, 'load', loadOptions);
- }
- static getHandlers(url, handlerType, loadOptions) {
- const validHandlers = [];
- const routers = handlerType === 'load' ?
- IORouterRegistry.getInstance().loadRouters :
- IORouterRegistry.getInstance().saveRouters;
- routers.forEach(router => {
- const handler = router(url, loadOptions);
- if (handler !== null) {
- validHandlers.push(handler);
- }
- });
- return validHandlers;
- }
- }
- const registerSaveRouter = (loudRouter) => IORouterRegistry.registerSaveRouter(loudRouter);
- const registerLoadRouter = (loudRouter) => IORouterRegistry.registerLoadRouter(loudRouter);
- const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
- const getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const DATABASE_NAME = 'tensorflowjs';
- const DATABASE_VERSION = 1;
- // Model data and ModelArtifactsInfo (metadata) are stored in two separate
- // stores for efficient access of the list of stored models and their metadata.
- // 1. The object store for model data: topology, weights and weight manifests.
- const MODEL_STORE_NAME = 'models_store';
- // 2. The object store for ModelArtifactsInfo, including meta-information such
- // as the type of topology (JSON vs binary), byte size of the topology, byte
- // size of the weights, etc.
- const INFO_STORE_NAME = 'model_info_store';
- /**
- * Delete the entire database for tensorflow.js, including the models store.
- */
- async function deleteDatabase() {
- const idbFactory = getIndexedDBFactory();
- return new Promise((resolve, reject) => {
- const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);
- deleteRequest.onsuccess = () => resolve();
- deleteRequest.onerror = error => reject(error);
- });
- }
- function getIndexedDBFactory() {
- if (!env().getBool('IS_BROWSER')) {
- // TODO(cais): Add more info about what IOHandler subtypes are available.
- // Maybe point to a doc page on the web and/or automatically determine
- // the available IOHandlers and print them in the error message.
- throw new Error('Failed to obtain IndexedDB factory because the current environment' +
- 'is not a web browser.');
- }
- // tslint:disable-next-line:no-any
- const theWindow = typeof window === 'undefined' ? self : window;
- const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
- theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
- theWindow.shimIndexedDB;
- if (factory == null) {
- throw new Error('The current browser does not appear to support IndexedDB.');
- }
- return factory;
- }
- function setUpDatabase(openRequest) {
- const db = openRequest.result;
- db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
- db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
- }
- /**
- * IOHandler subclass: Browser IndexedDB.
- *
- * See the doc string of `browserIndexedDB` for more details.
- */
- class BrowserIndexedDB {
- constructor(modelPath) {
- this.indexedDB = getIndexedDBFactory();
- if (modelPath == null || !modelPath) {
- throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
- }
- this.modelPath = modelPath;
- }
- async save(modelArtifacts) {
- // TODO(cais): Support saving GraphDef models.
- if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
- throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
- 'in binary formats yet.');
- }
- return this.databaseAction(this.modelPath, modelArtifacts);
- }
- async load() {
- return this.databaseAction(this.modelPath);
- }
- /**
- * Perform database action to put model artifacts into or read model artifacts
- * from IndexedDB object store.
- *
- * Whether the action is put or get depends on whether `modelArtifacts` is
- * specified. If it is specified, the action will be put; otherwise the action
- * will be get.
- *
- * @param modelPath A unique string path for the model.
- * @param modelArtifacts If specified, it will be the model artifacts to be
- * stored in IndexedDB.
- * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
- * of `ModelArtifacts`, if the action is get.
- */
- databaseAction(modelPath, modelArtifacts) {
- return new Promise((resolve, reject) => {
- const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
- openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
- openRequest.onsuccess = () => {
- const db = openRequest.result;
- if (modelArtifacts == null) {
- // Read model out from object store.
- const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
- const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
- const getRequest = modelStore.get(this.modelPath);
- getRequest.onsuccess = () => {
- if (getRequest.result == null) {
- db.close();
- return reject(new Error(`Cannot find model with path '${this.modelPath}' ` +
- `in IndexedDB.`));
- }
- else {
- resolve(getRequest.result.modelArtifacts);
- }
- };
- getRequest.onerror = error => {
- db.close();
- return reject(getRequest.error);
- };
- modelTx.oncomplete = () => db.close();
- }
- else {
- // Put model into object store.
- const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
- // First, put ModelArtifactsInfo into info store.
- const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
- let infoStore = infoTx.objectStore(INFO_STORE_NAME);
- const putInfoRequest = infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo });
- let modelTx;
- putInfoRequest.onsuccess = () => {
- // Second, put model data into model store.
- modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
- const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
- const putModelRequest = modelStore.put({
- modelPath: this.modelPath,
- modelArtifacts,
- modelArtifactsInfo
- });
- putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo });
- putModelRequest.onerror = error => {
- // If the put-model request fails, roll back the info entry as
- // well.
- infoStore = infoTx.objectStore(INFO_STORE_NAME);
- const deleteInfoRequest = infoStore.delete(this.modelPath);
- deleteInfoRequest.onsuccess = () => {
- db.close();
- return reject(putModelRequest.error);
- };
- deleteInfoRequest.onerror = error => {
- db.close();
- return reject(putModelRequest.error);
- };
- };
- };
- putInfoRequest.onerror = error => {
- db.close();
- return reject(putInfoRequest.error);
- };
- infoTx.oncomplete = () => {
- if (modelTx == null) {
- db.close();
- }
- else {
- modelTx.oncomplete = () => db.close();
- }
- };
- }
- };
- openRequest.onerror = error => reject(openRequest.error);
- });
- }
- }
- BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
- const indexedDBRouter = (url) => {
- if (!env().getBool('IS_BROWSER')) {
- return null;
- }
- else {
- if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
- return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
- }
- else {
- return null;
- }
- }
- };
- IORouterRegistry.registerSaveRouter(indexedDBRouter);
- IORouterRegistry.registerLoadRouter(indexedDBRouter);
- /**
- * Creates a browser IndexedDB IOHandler for saving and loading models.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(
- * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
- *
- * const saveResult = await model.save('indexeddb://MyModel'));
- * console.log(saveResult);
- * ```
- *
- * @param modelPath A unique identifier for the model to be saved. Must be a
- * non-empty string.
- * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`),
- * which can be used with, e.g., `tf.Model.save`.
- */
- function browserIndexedDB(modelPath) {
- return new BrowserIndexedDB(modelPath);
- }
- function maybeStripScheme(key) {
- return key.startsWith(BrowserIndexedDB.URL_SCHEME) ?
- key.slice(BrowserIndexedDB.URL_SCHEME.length) :
- key;
- }
- class BrowserIndexedDBManager {
- constructor() {
- this.indexedDB = getIndexedDBFactory();
- }
- async listModels() {
- return new Promise((resolve, reject) => {
- const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
- openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
- openRequest.onsuccess = () => {
- const db = openRequest.result;
- const tx = db.transaction(INFO_STORE_NAME, 'readonly');
- const store = tx.objectStore(INFO_STORE_NAME);
- // tslint:disable:max-line-length
- // Need to cast `store` as `any` here because TypeScript's DOM
- // library does not have the `getAll()` method even though the
- // method is supported in the latest version of most mainstream
- // browsers:
- // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
- // tslint:enable:max-line-length
- // tslint:disable-next-line:no-any
- const getAllInfoRequest = store.getAll();
- getAllInfoRequest.onsuccess = () => {
- const out = {};
- for (const item of getAllInfoRequest.result) {
- out[item.modelPath] = item.modelArtifactsInfo;
- }
- resolve(out);
- };
- getAllInfoRequest.onerror = error => {
- db.close();
- return reject(getAllInfoRequest.error);
- };
- tx.oncomplete = () => db.close();
- };
- openRequest.onerror = error => reject(openRequest.error);
- });
- }
- async removeModel(path) {
- path = maybeStripScheme(path);
- return new Promise((resolve, reject) => {
- const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
- openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
- openRequest.onsuccess = () => {
- const db = openRequest.result;
- const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
- const infoStore = infoTx.objectStore(INFO_STORE_NAME);
- const getInfoRequest = infoStore.get(path);
- let modelTx;
- getInfoRequest.onsuccess = () => {
- if (getInfoRequest.result == null) {
- db.close();
- return reject(new Error(`Cannot find model with path '${path}' ` +
- `in IndexedDB.`));
- }
- else {
- // First, delete the entry in the info store.
- const deleteInfoRequest = infoStore.delete(path);
- const deleteModelData = () => {
- // Second, delete the entry in the model store.
- modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
- const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
- const deleteModelRequest = modelStore.delete(path);
- deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo);
- deleteModelRequest.onerror = error => reject(getInfoRequest.error);
- };
- // Proceed with deleting model data regardless of whether deletion
- // of info data succeeds or not.
- deleteInfoRequest.onsuccess = deleteModelData;
- deleteInfoRequest.onerror = error => {
- deleteModelData();
- db.close();
- return reject(getInfoRequest.error);
- };
- }
- };
- getInfoRequest.onerror = error => {
- db.close();
- return reject(getInfoRequest.error);
- };
- infoTx.oncomplete = () => {
- if (modelTx == null) {
- db.close();
- }
- else {
- modelTx.oncomplete = () => db.close();
- }
- };
- };
- openRequest.onerror = error => reject(openRequest.error);
- });
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const PATH_SEPARATOR = '/';
- const PATH_PREFIX = 'tensorflowjs_models';
- const INFO_SUFFIX = 'info';
- const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
- const WEIGHT_SPECS_SUFFIX = 'weight_specs';
- const WEIGHT_DATA_SUFFIX = 'weight_data';
- const MODEL_METADATA_SUFFIX = 'model_metadata';
- /**
- * Purge all tensorflow.js-saved model artifacts from local storage.
- *
- * @returns Paths of the models purged.
- */
- function purgeLocalStorageArtifacts() {
- if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
- typeof window.localStorage === 'undefined') {
- throw new Error('purgeLocalStorageModels() cannot proceed because local storage is ' +
- 'unavailable in the current environment.');
- }
- const LS = window.localStorage;
- const purgedModelPaths = [];
- for (let i = 0; i < LS.length; ++i) {
- const key = LS.key(i);
- const prefix = PATH_PREFIX + PATH_SEPARATOR;
- if (key.startsWith(prefix) && key.length > prefix.length) {
- LS.removeItem(key);
- const modelName = getModelPathFromKey(key);
- if (purgedModelPaths.indexOf(modelName) === -1) {
- purgedModelPaths.push(modelName);
- }
- }
- }
- return purgedModelPaths;
- }
- function getModelKeys(path) {
- return {
- info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
- topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
- weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
- weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
- modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
- };
- }
- /**
- * Get model path from a local-storage key.
- *
- * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
- *
- * @param key
- */
- function getModelPathFromKey(key) {
- const items = key.split(PATH_SEPARATOR);
- if (items.length < 3) {
- throw new Error(`Invalid key format: ${key}`);
- }
- return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
- }
- function maybeStripScheme$1(key) {
- return key.startsWith(BrowserLocalStorage.URL_SCHEME) ?
- key.slice(BrowserLocalStorage.URL_SCHEME.length) :
- key;
- }
- /**
- * IOHandler subclass: Browser Local Storage.
- *
- * See the doc string to `browserLocalStorage` for more details.
- */
- class BrowserLocalStorage {
- constructor(modelPath) {
- if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
- typeof window.localStorage === 'undefined') {
- // TODO(cais): Add more info about what IOHandler subtypes are
- // available.
- // Maybe point to a doc page on the web and/or automatically determine
- // the available IOHandlers and print them in the error message.
- throw new Error('The current environment does not support local storage.');
- }
- this.LS = window.localStorage;
- if (modelPath == null || !modelPath) {
- throw new Error('For local storage, modelPath must not be null, undefined or empty.');
- }
- this.modelPath = modelPath;
- this.keys = getModelKeys(this.modelPath);
- }
- /**
- * Save model artifacts to browser local storage.
- *
- * See the documentation to `browserLocalStorage` for details on the saved
- * artifacts.
- *
- * @param modelArtifacts The model artifacts to be stored.
- * @returns An instance of SaveResult.
- */
- async save(modelArtifacts) {
- if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
- throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
- 'in binary formats yet.');
- }
- else {
- const topology = JSON.stringify(modelArtifacts.modelTopology);
- const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
- const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
- try {
- this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
- this.LS.setItem(this.keys.topology, topology);
- this.LS.setItem(this.keys.weightSpecs, weightSpecs);
- this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData));
- this.LS.setItem(this.keys.modelMetadata, JSON.stringify({
- format: modelArtifacts.format,
- generatedBy: modelArtifacts.generatedBy,
- convertedBy: modelArtifacts.convertedBy,
- userDefinedMetadata: modelArtifacts.userDefinedMetadata
- }));
- return { modelArtifactsInfo };
- }
- catch (err) {
- // If saving failed, clean up all items saved so far.
- this.LS.removeItem(this.keys.info);
- this.LS.removeItem(this.keys.topology);
- this.LS.removeItem(this.keys.weightSpecs);
- this.LS.removeItem(this.keys.weightData);
- this.LS.removeItem(this.keys.modelMetadata);
- throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` +
- `size quota being exceeded is a possible cause of this failure: ` +
- `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +
- `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +
- `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);
- }
- }
- }
- /**
- * Load a model from local storage.
- *
- * See the documentation to `browserLocalStorage` for details on the saved
- * artifacts.
- *
- * @returns The loaded model (if loading succeeds).
- */
- async load() {
- const info = JSON.parse(this.LS.getItem(this.keys.info));
- if (info == null) {
- throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
- }
- if (info.modelTopologyType !== 'JSON') {
- throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
- 'topology yet.');
- }
- const out = {};
- // Load topology.
- const topology = JSON.parse(this.LS.getItem(this.keys.topology));
- if (topology == null) {
- throw new Error(`In local storage, the topology of model '${this.modelPath}' ` +
- `is missing.`);
- }
- out.modelTopology = topology;
- // Load weight specs.
- const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
- if (weightSpecs == null) {
- throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` +
- `are missing.`);
- }
- out.weightSpecs = weightSpecs;
- // Load meta-data fields.
- const metadataString = this.LS.getItem(this.keys.modelMetadata);
- if (metadataString != null) {
- const metadata = JSON.parse(metadataString);
- out.format = metadata['format'];
- out.generatedBy = metadata['generatedBy'];
- out.convertedBy = metadata['convertedBy'];
- out.userDefinedMetadata = metadata['userDefinedMetadata'];
- }
- // Load weight data.
- const weightDataBase64 = this.LS.getItem(this.keys.weightData);
- if (weightDataBase64 == null) {
- throw new Error(`In local storage, the binary weight values of model ` +
- `'${this.modelPath}' are missing.`);
- }
- out.weightData = base64StringToArrayBuffer(weightDataBase64);
- return out;
- }
- }
- BrowserLocalStorage.URL_SCHEME = 'localstorage://';
- const localStorageRouter = (url) => {
- if (!env().getBool('IS_BROWSER')) {
- return null;
- }
- else {
- if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
- return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
- }
- else {
- return null;
- }
- }
- };
- IORouterRegistry.registerSaveRouter(localStorageRouter);
- IORouterRegistry.registerLoadRouter(localStorageRouter);
- /**
- * Factory function for local storage IOHandler.
- *
- * This `IOHandler` supports both `save` and `load`.
- *
- * For each model's saved artifacts, four items are saved to local storage.
- * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the
- * model, such as date saved, type of the topology, size in bytes, etc.
- * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras-
- * style models, this is a stringized JSON.
- * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the
- * model, can be used to decode the saved binary weight values (see
- * item below).
- * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary
- * weight values, stored as a base64-encoded string.
- *
- * Saving may throw an `Error` if the total size of the artifacts exceed the
- * browser-specific quota.
- *
- * @param modelPath A unique identifier for the model to be saved. Must be a
- * non-empty string.
- * @returns An instance of `IOHandler`, which can be used with, e.g.,
- * `tf.Model.save`.
- */
- function browserLocalStorage(modelPath) {
- return new BrowserLocalStorage(modelPath);
- }
- class BrowserLocalStorageManager {
- constructor() {
- assert(env().getBool('IS_BROWSER'), () => 'Current environment is not a web browser');
- assert(typeof window === 'undefined' ||
- typeof window.localStorage !== 'undefined', () => 'Current browser does not appear to support localStorage');
- this.LS = window.localStorage;
- }
- async listModels() {
- const out = {};
- const prefix = PATH_PREFIX + PATH_SEPARATOR;
- const suffix = PATH_SEPARATOR + INFO_SUFFIX;
- for (let i = 0; i < this.LS.length; ++i) {
- const key = this.LS.key(i);
- if (key.startsWith(prefix) && key.endsWith(suffix)) {
- const modelPath = getModelPathFromKey(key);
- out[modelPath] = JSON.parse(this.LS.getItem(key));
- }
- }
- return out;
- }
- async removeModel(path) {
- path = maybeStripScheme$1(path);
- const keys = getModelKeys(path);
- if (this.LS.getItem(keys.info) == null) {
- throw new Error(`Cannot find model at path '${path}'`);
- }
- const info = JSON.parse(this.LS.getItem(keys.info));
- this.LS.removeItem(keys.info);
- this.LS.removeItem(keys.topology);
- this.LS.removeItem(keys.weightSpecs);
- this.LS.removeItem(keys.weightData);
- return info;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const URL_SCHEME_SUFFIX = '://';
- class ModelStoreManagerRegistry {
- constructor() {
- this.managers = {};
- }
- static getInstance() {
- if (ModelStoreManagerRegistry.instance == null) {
- ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
- }
- return ModelStoreManagerRegistry.instance;
- }
- /**
- * Register a save-handler router.
- *
- * @param saveRouter A function that maps a URL-like string onto an instance
- * of `IOHandler` with the `save` method defined or `null`.
- */
- static registerManager(scheme, manager) {
- assert(scheme != null, () => 'scheme must not be undefined or null.');
- if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
- scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
- }
- assert(scheme.length > 0, () => 'scheme must not be an empty string.');
- const registry = ModelStoreManagerRegistry.getInstance();
- assert(registry.managers[scheme] == null, () => `A model store manager is already registered for scheme '${scheme}'.`);
- registry.managers[scheme] = manager;
- }
- static getManager(scheme) {
- const manager = this.getInstance().managers[scheme];
- if (manager == null) {
- throw new Error(`Cannot find model manager for scheme '${scheme}'`);
- }
- return manager;
- }
- static getSchemes() {
- return Object.keys(this.getInstance().managers);
- }
- }
- /**
- * Helper method for parsing a URL string into a scheme and a path.
- *
- * @param url E.g., 'localstorage://my-model'
- * @returns A dictionary with two fields: scheme and path.
- * Scheme: e.g., 'localstorage' in the example above.
- * Path: e.g., 'my-model' in the example above.
- */
- function parseURL(url) {
- if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
- throw new Error(`The url string provided does not contain a scheme. ` +
- `Supported schemes are: ` +
- `${ModelStoreManagerRegistry.getSchemes().join(',')}`);
- }
- return {
- scheme: url.split(URL_SCHEME_SUFFIX)[0],
- path: url.split(URL_SCHEME_SUFFIX)[1],
- };
- }
- async function cloneModelInternal(sourceURL, destURL, deleteSource = false) {
- assert(sourceURL !== destURL, () => `Old path and new path are the same: '${sourceURL}'`);
- const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
- assert(loadHandlers.length > 0, () => `Copying failed because no load handler is found for source URL ${sourceURL}.`);
- assert(loadHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
- `load handlers for source URL ${sourceURL}.`);
- const loadHandler = loadHandlers[0];
- const saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
- assert(saveHandlers.length > 0, () => `Copying failed because no save handler is found for destination ` +
- `URL ${destURL}.`);
- assert(saveHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
- `save handlers for destination URL ${destURL}.`);
- const saveHandler = saveHandlers[0];
- const sourceScheme = parseURL(sourceURL).scheme;
- const sourcePath = parseURL(sourceURL).path;
- const sameMedium = sourceScheme === parseURL(sourceURL).scheme;
- const modelArtifacts = await loadHandler.load();
- // If moving within the same storage medium, remove the old model as soon as
- // the loading is done. Without doing this, it is possible that the combined
- // size of the two models will cause the cloning to fail.
- if (deleteSource && sameMedium) {
- await ModelStoreManagerRegistry.getManager(sourceScheme)
- .removeModel(sourcePath);
- }
- const saveResult = await saveHandler.save(modelArtifacts);
- // If moving between mediums, the deletion is done after the save succeeds.
- // This guards against the case in which saving to the destination medium
- // fails.
- if (deleteSource && !sameMedium) {
- await ModelStoreManagerRegistry.getManager(sourceScheme)
- .removeModel(sourcePath);
- }
- return saveResult.modelArtifactsInfo;
- }
- /**
- * List all models stored in registered storage mediums.
- *
- * For a web browser environment, the registered mediums are Local Storage and
- * IndexedDB.
- *
- * ```js
- * // First create and save a model.
- * const model = tf.sequential();
- * model.add(tf.layers.dense(
- * {units: 1, inputShape: [10], activation: 'sigmoid'}));
- * await model.save('localstorage://demo/management/model1');
- *
- * // Then list existing models.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Delete the model.
- * await tf.io.removeModel('localstorage://demo/management/model1');
- *
- * // List models again.
- * console.log(JSON.stringify(await tf.io.listModels()));
- * ```
- *
- * @returns A `Promise` of a dictionary mapping URLs of existing models to
- * their model artifacts info. URLs include medium-specific schemes, e.g.,
- * 'indexeddb://my/model/1'. Model artifacts info include type of the
- * model's topology, byte sizes of the topology, weights, etc.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Management',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
- async function listModels() {
- const schemes = ModelStoreManagerRegistry.getSchemes();
- const out = {};
- for (const scheme of schemes) {
- const schemeOut = await ModelStoreManagerRegistry.getManager(scheme).listModels();
- for (const path in schemeOut) {
- const url = scheme + URL_SCHEME_SUFFIX + path;
- out[url] = schemeOut[path];
- }
- }
- return out;
- }
- /**
- * Remove a model specified by URL from a reigstered storage medium.
- *
- * ```js
- * // First create and save a model.
- * const model = tf.sequential();
- * model.add(tf.layers.dense(
- * {units: 1, inputShape: [10], activation: 'sigmoid'}));
- * await model.save('localstorage://demo/management/model1');
- *
- * // Then list existing models.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Delete the model.
- * await tf.io.removeModel('localstorage://demo/management/model1');
- *
- * // List models again.
- * console.log(JSON.stringify(await tf.io.listModels()));
- * ```
- *
- * @param url A URL to a stored model, with a scheme prefix, e.g.,
- * 'localstorage://my-model-1', 'indexeddb://my/model/2'.
- * @returns ModelArtifactsInfo of the deleted model (if and only if deletion
- * is successful).
- * @throws Error if deletion fails, e.g., if no model exists at `path`.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Management',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
- async function removeModel(url) {
- const schemeAndPath = parseURL(url);
- const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
- return manager.removeModel(schemeAndPath.path);
- }
- /**
- * Copy a model from one URL to another.
- *
- * This function supports:
- *
- * 1. Copying within a storage medium, e.g.,
- * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
- * 2. Copying between two storage mediums, e.g.,
- * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
- *
- * ```js
- * // First create and save a model.
- * const model = tf.sequential();
- * model.add(tf.layers.dense(
- * {units: 1, inputShape: [10], activation: 'sigmoid'}));
- * await model.save('localstorage://demo/management/model1');
- *
- * // Then list existing models.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Copy the model, from Local Storage to IndexedDB.
- * await tf.io.copyModel(
- * 'localstorage://demo/management/model1',
- * 'indexeddb://demo/management/model1');
- *
- * // List models again.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Remove both models.
- * await tf.io.removeModel('localstorage://demo/management/model1');
- * await tf.io.removeModel('indexeddb://demo/management/model1');
- * ```
- *
- * @param sourceURL Source URL of copying.
- * @param destURL Destination URL of copying.
- * @returns ModelArtifactsInfo of the copied model (if and only if copying
- * is successful).
- * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
- * if `oldPath` and `newPath` are identical.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Management',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
- async function copyModel(sourceURL, destURL) {
- const deleteSource = false;
- return cloneModelInternal(sourceURL, destURL, deleteSource);
- }
- /**
- * Move a model from one URL to another.
- *
- * This function supports:
- *
- * 1. Moving within a storage medium, e.g.,
- * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
- * 2. Moving between two storage mediums, e.g.,
- * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
- *
- * ```js
- * // First create and save a model.
- * const model = tf.sequential();
- * model.add(tf.layers.dense(
- * {units: 1, inputShape: [10], activation: 'sigmoid'}));
- * await model.save('localstorage://demo/management/model1');
- *
- * // Then list existing models.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Move the model, from Local Storage to IndexedDB.
- * await tf.io.moveModel(
- * 'localstorage://demo/management/model1',
- * 'indexeddb://demo/management/model1');
- *
- * // List models again.
- * console.log(JSON.stringify(await tf.io.listModels()));
- *
- * // Remove the moved model.
- * await tf.io.removeModel('indexeddb://demo/management/model1');
- * ```
- *
- * @param sourceURL Source URL of moving.
- * @param destURL Destination URL of moving.
- * @returns ModelArtifactsInfo of the copied model (if and only if copying
- * is successful).
- * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
- * if `oldPath` and `newPath` are identical.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Management',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
- async function moveModel(sourceURL, destURL) {
- const deleteSource = true;
- return cloneModelInternal(sourceURL, destURL, deleteSource);
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class PlatformBrowser {
- fetch(path, init) {
- return fetch(path, init);
- }
- now() {
- return performance.now();
- }
- encode(text, encoding) {
- if (encoding !== 'utf-8' && encoding !== 'utf8') {
- throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
- }
- if (this.textEncoder == null) {
- this.textEncoder = new TextEncoder();
- }
- return this.textEncoder.encode(text);
- }
- decode(bytes, encoding) {
- return new TextDecoder(encoding).decode(bytes);
- }
- }
- if (env().get('IS_BROWSER')) {
- env().setPlatform('browser', new PlatformBrowser());
- // Register LocalStorage IOHandler
- try {
- ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
- }
- catch (err) {
- }
- // Register IndexedDB IOHandler
- try {
- ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
- }
- catch (err) {
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // We are wrapping this within an object so it can be stubbed by Jasmine.
- const getNodeFetch = {
- // tslint:disable-next-line:no-require-imports
- importFetch: () => require('node-fetch')
- };
- let systemFetch;
- // These getters and setters are for testing so we don't export a mutable
- // variable.
- function resetSystemFetch() {
- systemFetch = null;
- }
- function setSystemFetch(fetchFn) {
- systemFetch = fetchFn;
- }
- function getSystemFetch() {
- return systemFetch;
- }
- class PlatformNode {
- constructor() {
- // tslint:disable-next-line:no-require-imports
- this.util = require('util');
- // According to the spec, the built-in encoder can do only UTF-8 encoding.
- // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
- this.textEncoder = new this.util.TextEncoder();
- }
- fetch(path, requestInits) {
- if (env().global.fetch != null) {
- return env().global.fetch(path, requestInits);
- }
- if (systemFetch == null) {
- systemFetch = getNodeFetch.importFetch();
- }
- return systemFetch(path, requestInits);
- }
- now() {
- const time = process.hrtime();
- return time[0] * 1000 + time[1] / 1000000;
- }
- encode(text, encoding) {
- if (encoding !== 'utf-8' && encoding !== 'utf8') {
- throw new Error(`Node built-in encoder only supports utf-8, but got ${encoding}`);
- }
- return this.textEncoder.encode(text);
- }
- decode(bytes, encoding) {
- if (bytes.length === 0) {
- return '';
- }
- return new this.util.TextDecoder(encoding).decode(bytes);
- }
- }
- if (env().get('IS_NODE')) {
- env().setPlatform('node', new PlatformNode());
- }
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
- *
- * The values are stored in CPU as `TypedArray`. Fill the buffer using
- * `buffer.set()`, or by modifying directly `buffer.values`.
- *
- * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
- * those values.
- *
- * ```js
- * // Create a buffer and set values at particular indices.
- * const buffer = tf.buffer([2, 2]);
- * buffer.set(3, 0, 0);
- * buffer.set(5, 1, 0);
- *
- * // Convert the buffer back to a tensor.
- * buffer.toTensor().print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The dtype of the buffer. Defaults to 'float32'.
- * @param values The values of the buffer as `TypedArray`. Defaults to
- * zeros.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function buffer(shape, dtype = 'float32', values) {
- dtype = dtype || 'float32';
- assertNonNegativeIntegerDimensions(shape);
- return new TensorBuffer(shape, dtype, values);
- }
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Casts a `tf.Tensor` to a new dtype.
- *
- * ```js
- * const x = tf.tensor1d([1.5, 2.5, 3]);
- * tf.cast(x, 'int32').print();
- * ```
- * @param x The input tensor to be casted.
- * @param dtype The dtype to cast the input tensor to.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function cast_(x, dtype) {
- const $x = convertToTensor(x, 'x', 'cast');
- // Sanity checks.
- if (!isValidDtype(dtype)) {
- throw new Error(`Failed to cast to unknown dtype ${dtype}`);
- }
- if (dtype === 'string' && $x.dtype !== 'string' ||
- dtype !== 'string' && $x.dtype === 'string') {
- throw new Error('Only strings can be casted to strings');
- }
- const inputs = { x: $x };
- const attrs = { dtype };
- return ENGINE.runKernelFunc(backend => backend.cast($x, dtype), inputs, null /* grad */, Cast, attrs);
- }
- const cast = op({ cast_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a new tensor with the same values and shape as the specified
- * tensor.
- *
- * ```js
- * const x = tf.tensor([1, 2]);
- *
- * x.clone().print();
- * ```
- *
- * @param x The tensor to clone.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function clone_(x) {
- const $x = convertToTensor(x, 'x', 'clone', null);
- const forward = () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype);
- const inputs = { x: $x };
- // Note this op is called tf.identity in python. Hence the kernel name used
- // here.
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Identity);
- }
- const clone = op({ clone_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Prints information about the `tf.Tensor` including its data.
- *
- * ```js
- * const verbose = true;
- * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
- * ```
- * @param x The tensor to be printed.
- * @param verbose Whether to print verbose information about the ` Tensor`,
- * including dtype and size.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function print(x, verbose = false) {
- console.log(x.toString(verbose));
- }
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- getOrMakeEngine();
- const opHandler$1 = {
- buffer,
- cast,
- clone,
- print
- };
- setOpHandler(opHandler$1);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const DEFAULT_FILE_NAME_PREFIX = 'model';
- const DEFAULT_JSON_EXTENSION_NAME = '.json';
- const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
- function defer(f) {
- return new Promise(resolve => setTimeout(resolve)).then(f);
- }
- class BrowserDownloads {
- constructor(fileNamePrefix) {
- if (!env().getBool('IS_BROWSER')) {
- // TODO(cais): Provide info on what IOHandlers are available under the
- // current environment.
- throw new Error('browserDownloads() cannot proceed because the current environment ' +
- 'is not a browser.');
- }
- if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
- fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
- }
- if (fileNamePrefix == null || fileNamePrefix.length === 0) {
- fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
- }
- this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
- this.weightDataFileName =
- fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
- }
- async save(modelArtifacts) {
- if (typeof (document) === 'undefined') {
- throw new Error('Browser downloads are not supported in ' +
- 'this environment since `document` is not present');
- }
- const weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' }));
- if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
- throw new Error('BrowserDownloads.save() does not support saving model topology ' +
- 'in binary formats yet.');
- }
- else {
- const weightsManifest = [{
- paths: ['./' + this.weightDataFileName],
- weights: modelArtifacts.weightSpecs
- }];
- const modelTopologyAndWeightManifest = {
- modelTopology: modelArtifacts.modelTopology,
- format: modelArtifacts.format,
- generatedBy: modelArtifacts.generatedBy,
- convertedBy: modelArtifacts.convertedBy,
- weightsManifest
- };
- const modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: 'application/json' }));
- // If anchor elements are not provided, create them without attaching them
- // to parents, so that the downloaded file names can be controlled.
- const jsonAnchor = this.jsonAnchor == null ? document.createElement('a') :
- this.jsonAnchor;
- jsonAnchor.download = this.modelTopologyFileName;
- jsonAnchor.href = modelTopologyAndWeightManifestURL;
- // Trigger downloads by evoking a click event on the download anchors.
- // When multiple downloads are started synchronously, Firefox will only
- // save the last one.
- await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));
- if (modelArtifacts.weightData != null) {
- const weightDataAnchor = this.weightDataAnchor == null ?
- document.createElement('a') :
- this.weightDataAnchor;
- weightDataAnchor.download = this.weightDataFileName;
- weightDataAnchor.href = weightsURL;
- await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
- }
- return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) };
- }
- }
- }
- BrowserDownloads.URL_SCHEME = 'downloads://';
- class BrowserFiles {
- constructor(files) {
- if (files == null || files.length < 1) {
- throw new Error(`When calling browserFiles, at least 1 file is required, ` +
- `but received ${files}`);
- }
- this.files = files;
- }
- async load() {
- const jsonFile = this.files[0];
- const weightFiles = this.files.slice(1);
- return new Promise((resolve, reject) => {
- const jsonReader = new FileReader();
- jsonReader.onload = (event) => {
- // tslint:disable-next-line:no-any
- const modelJSON = JSON.parse(event.target.result);
- const modelTopology = modelJSON.modelTopology;
- if (modelTopology == null) {
- reject(new Error(`modelTopology field is missing from file ${jsonFile.name}`));
- return;
- }
- if (weightFiles.length === 0) {
- resolve({ modelTopology });
- }
- const weightsManifest = modelJSON.weightsManifest;
- if (weightsManifest == null) {
- reject(new Error(`weightManifest field is missing from file ${jsonFile.name}`));
- return;
- }
- let pathToFile;
- try {
- pathToFile =
- this.checkManifestAndWeightFiles(weightsManifest, weightFiles);
- }
- catch (err) {
- reject(err);
- return;
- }
- const weightSpecs = [];
- const paths = [];
- const perFileBuffers = [];
- weightsManifest.forEach(weightsGroup => {
- weightsGroup.paths.forEach(path => {
- paths.push(path);
- perFileBuffers.push(null);
- });
- weightSpecs.push(...weightsGroup.weights);
- });
- weightsManifest.forEach(weightsGroup => {
- weightsGroup.paths.forEach(path => {
- const weightFileReader = new FileReader();
- weightFileReader.onload = (event) => {
- // tslint:disable-next-line:no-any
- const weightData = event.target.result;
- const index = paths.indexOf(path);
- perFileBuffers[index] = weightData;
- if (perFileBuffers.indexOf(null) === -1) {
- resolve({
- modelTopology,
- weightSpecs,
- weightData: concatenateArrayBuffers(perFileBuffers),
- format: modelJSON.format,
- generatedBy: modelJSON.generatedBy,
- convertedBy: modelJSON.convertedBy,
- userDefinedMetadata: modelJSON.userDefinedMetadata
- });
- }
- };
- weightFileReader.onerror = error => reject(`Failed to weights data from file of path '${path}'.`);
- weightFileReader.readAsArrayBuffer(pathToFile[path]);
- });
- });
- };
- jsonReader.onerror = error => reject(`Failed to read model topology and weights manifest JSON ` +
- `from file '${jsonFile.name}'. BrowserFiles supports loading ` +
- `Keras-style tf.Model artifacts only.`);
- jsonReader.readAsText(jsonFile);
- });
- }
- /**
- * Check the compatibility between weights manifest and weight files.
- */
- checkManifestAndWeightFiles(manifest, files) {
- const basenames = [];
- const fileNames = files.map(file => basename(file.name));
- const pathToFile = {};
- for (const group of manifest) {
- group.paths.forEach(path => {
- const pathBasename = basename(path);
- if (basenames.indexOf(pathBasename) !== -1) {
- throw new Error(`Duplicate file basename found in weights manifest: ` +
- `'${pathBasename}'`);
- }
- basenames.push(pathBasename);
- if (fileNames.indexOf(pathBasename) === -1) {
- throw new Error(`Weight file with basename '${pathBasename}' is not provided.`);
- }
- else {
- pathToFile[path] = files[fileNames.indexOf(pathBasename)];
- }
- });
- }
- if (basenames.length !== files.length) {
- throw new Error(`Mismatch in the number of files in weights manifest ` +
- `(${basenames.length}) and the number of weight files provided ` +
- `(${files.length}).`);
- }
- return pathToFile;
- }
- }
- const browserDownloadsRouter = (url) => {
- if (!env().getBool('IS_BROWSER')) {
- return null;
- }
- else {
- if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
- return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
- }
- else {
- return null;
- }
- }
- };
- IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
- /**
- * Creates an IOHandler that triggers file downloads from the browser.
- *
- * The returned `IOHandler` instance can be used as model exporting methods such
- * as `tf.Model.save` and supports only saving.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.dense(
- * {units: 1, inputShape: [10], activation: 'sigmoid'}));
- * const saveResult = await model.save('downloads://mymodel');
- * // This will trigger downloading of two files:
- * // 'mymodel.json' and 'mymodel.weights.bin'.
- * console.log(saveResult);
- * ```
- *
- * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
- * `tf.Model`, `fileNamePrefix` should follow either of the following two
- * formats:
- * 1. `null` or `undefined`, in which case the default file
- * names will be used:
- * - 'model.json' for the JSON file containing the model topology and
- * weights manifest.
- * - 'model.weights.bin' for the binary file containing the binary weight
- * values.
- * 2. A single string or an Array of a single string, as the file name prefix.
- * For example, if `'foo'` is provided, the downloaded JSON
- * file and binary weights file will be named 'foo.json' and
- * 'foo.weights.bin', respectively.
- * @param config Additional configuration for triggering downloads.
- * @returns An instance of `BrowserDownloads` `IOHandler`.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Loading',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
- function browserDownloads(fileNamePrefix = 'model') {
- return new BrowserDownloads(fileNamePrefix);
- }
/**
 * Creates an IOHandler that loads model artifacts from user-selected files.
 *
 * This method can be used for loading from files such as user-selected files
 * in the browser.
 * When used in conjunction with `tf.loadLayersModel`, an instance of
 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
 *
 * ```js
 * // Note: This code snippet won't run properly without the actual file input
 * // elements in the HTML DOM.
 *
 * // Suppose there are two HTML file input (`<input type="file">`)
 * // elements.
 * const uploadJSONInput = document.getElementById('upload-json');
 * const uploadWeightsInput = document.getElementById('upload-weights');
 * const model = await tf.loadLayersModel(tf.io.browserFiles(
 *     [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
 * ```
 *
 * @param files `File`s to load from. Currently, this function supports only
 *   loading from files that contain Keras-style models (i.e., `tf.Model`s), for
 *   which an `Array` of `File`s is expected (in that order):
 *   - A JSON file containing the model topology and weight manifest.
 *   - Optionally, one or more binary files containing the binary weights.
 *     These files must have names that match the paths in the `weightsManifest`
 *     contained by the aforementioned JSON file, or errors will be thrown
 *     during loading. These weights files have the same format as the ones
 *     generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
 *     Python PIP package. If no weights files are provided, only the model
 *     topology will be loaded from the JSON file above.
 * @returns An instance of `Files` `IOHandler`.
 *
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
function browserFiles(files) {
    return new BrowserFiles(files);
}
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Monitor Promise.all progress, fire onProgress callback function.
 *
 * @param promises Promise list going to be monitored. Must be a non-empty
 *     array.
 * @param onProgress Callback function. Fired with the cumulative progress
 *     fraction each time a promise resolves.
 * @param startFraction Optional fraction start. Default to 0.
 * @param endFraction Optional fraction end. Default to 1.
 * @returns A promise resolving with the values of all `promises`, like
 *     `Promise.all`.
 */
function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
    checkPromises(promises);
    startFraction = startFraction == null ? 0 : startFraction;
    endFraction = endFraction == null ? 1 : endFraction;
    checkFraction(startFraction, endFraction);
    let resolvedPromise = 0;
    // Attach the progress listener on a side-chain and hand the ORIGINAL
    // promise to Promise.all below, so a throwing onProgress callback cannot
    // reject the aggregate result.
    const registerMonitor = (promise) => {
        promise.then(value => {
            const fraction = startFraction +
                ++resolvedPromise / promises.length * (endFraction - startFraction);
            // pass fraction as parameter to callback function.
            onProgress(fraction);
            return value;
        });
        return promise;
    };
    function checkPromises(promises) {
        // Fixed message typo: was 'none empty array'.
        assert(promises != null && Array.isArray(promises) && promises.length > 0, () => 'promises must be a non-empty array');
    }
    function checkFraction(startFraction, endFraction) {
        assert(startFraction >= 0 && startFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
            `got startFraction ${startFraction}`);
        assert(endFraction >= 0 && endFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
            `got endFraction ${endFraction}`);
        assert(endFraction >= startFraction, () => `startFraction must be no more than endFraction, but ` +
            `got startFraction ${startFraction} and endFraction ` +
            `${endFraction}`);
    }
    return Promise.all(promises.map(registerMonitor));
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Reads binary weights data from a number of URLs.
 *
 * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
 * @param loadOptions Optional loading options, of which the following fields
 *   are used here:
 *   - requestInit: RequestInit (options) for the HTTP requests.
 *   - fetchFunc: Optional overriding value for the `window.fetch` function.
 *   - onProgress: Optional progress callback function, fired periodically
 *     before the load is completed.
 * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
 *   length as `fetchURLs`.
 */
async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
    if (loadOptions == null) {
        loadOptions = {};
    }
    // Fall back to the platform-provided fetch when no override is given.
    const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
        loadOptions.fetchFunc;
    // Create the requests for all of the weights in parallel.
    const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));
    // The first half of the progress range covers fetching the responses...
    const fetchStartFraction = 0;
    const fetchEndFraction = 0.5;
    const responses = loadOptions.onProgress == null ?
        await Promise.all(requests) :
        await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
    // ...and the second half covers reading the bodies into ArrayBuffers.
    const bufferPromises = responses.map(response => response.arrayBuffer());
    const bufferStartFraction = 0.5;
    const bufferEndFraction = 1;
    const buffers = loadOptions.onProgress == null ?
        await Promise.all(bufferPromises) :
        await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
    return buffers;
}
/**
 * Reads a weights manifest JSON configuration, fetches the weights and
 * returns them as `Tensor`s.
 *
 * @param manifest The weights manifest JSON.
 * @param filePathPrefix The path prefix for filenames given in the manifest.
 *     Defaults to the empty string.
 * @param weightNames The names of the weights to be fetched.
 * @param requestInit RequestInit (options) for the HTTP requests.
 */
async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) {
    // TODO(nsthorat): Groups are currently fetched atomically. If you need a
    // single weight from a group, the whole group will be fetched. At a future
    // date, we should support fetching only the individual shards within a
    // group that are needed to reconstruct the requested weight.
    // TODO(cais): Use `decodeWeights` for implementation.
    // Build an HTTP-backed fetcher, then delegate to the generic loader.
    // (Renamed the local so it no longer shadows this function's own name.)
    const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, { requestInit });
    const loadWeightsFn = weightsLoaderFactory(fetchWeights);
    return loadWeightsFn(manifest, filePathPrefix, weightNames);
}
- /**
- * Creates a function, which reads a weights manifest JSON configuration,
- * fetches the weight files using the specified function and returns them as
- * `Tensor`s.
- *
- * ```js
- * // example for creating a nodejs weight loader, which reads the weight files
- * // from disk using fs.readFileSync
- *
- * import * as fs from 'fs'
- *
- * const fetchWeightsFromDisk = (filePaths: string[]) =>
- * filePaths.map(filePath => fs.readFileSync(filePath).buffer)
- *
- * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
- *
- * const manifest = JSON.parse(
- * fs.readFileSync('./my_model-weights_manifest').toString()
- * )
- * const weightMap = await loadWeights(manifest, './')
- * ```
- * @param fetchWeightsFunction The function used for fetching the weight files.
- * @returns Weight loading function.
- */
// Returns an async (manifest, filePathPrefix, weightNames) => tensor-map
// loader that fetches only the manifest groups containing requested weights,
// concatenates each group's shards and decodes the named entries.
function weightsLoaderFactory(fetchWeightsFunction) {
    return async (manifest, filePathPrefix = '', weightNames) => {
        // Collect all the groups, weights, and their relative offsets to be
        // fetched.
        const groupIndicesToFetchMap = manifest.map(() => false);
        const groupWeightsToFetch = {};
        const weightsFound = weightNames != null ? weightNames.map(() => false) : [];
        const allManifestWeightNames = [];
        manifest.forEach((manifestGroupConfig, groupIndex) => {
            let groupOffset = 0;
            manifestGroupConfig.weights.forEach(weightsEntry => {
                // Quantized entries record their on-disk dtype under `quantization`.
                const rawDtype = ('quantization' in weightsEntry) ?
                    weightsEntry.quantization.dtype :
                    weightsEntry.dtype;
                const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
                    sizeFromShape(weightsEntry.shape);
                // Marks this group for fetching and records where, within the
                // concatenated group buffer, this entry's bytes begin.
                const enqueueWeightsForFetchingFn = () => {
                    groupIndicesToFetchMap[groupIndex] = true;
                    if (groupWeightsToFetch[groupIndex] == null) {
                        groupWeightsToFetch[groupIndex] = [];
                    }
                    groupWeightsToFetch[groupIndex].push({
                        manifestEntry: weightsEntry,
                        groupOffset,
                        sizeBytes: weightsBytes
                    });
                };
                if (weightNames != null) {
                    // Only enqueue entries whose names were explicitly requested.
                    weightNames.forEach((weightName, weightIndex) => {
                        if (weightName === weightsEntry.name) {
                            enqueueWeightsForFetchingFn();
                            weightsFound[weightIndex] = true;
                        }
                    });
                }
                else {
                    enqueueWeightsForFetchingFn();
                }
                allManifestWeightNames.push(weightsEntry.name);
                groupOffset += weightsBytes;
            });
        });
        if (!weightsFound.every(found => found)) {
            const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);
            throw new Error(`Could not find weights in manifest with names: ` +
                `${weightsNotFound.join(', ')}. \n` +
                `Manifest JSON has weights with names: ` +
                `${allManifestWeightNames.join(', ')}.`);
        }
        // Convert the one-hot boolean groupId => shouldFetch map to a list of group
        // IDs.
        const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {
            if (shouldFetch) {
                accumulator.push(i);
            }
            return accumulator;
        }, []);
        const fetchUrls = [];
        groupIndicesToFetch.forEach(i => {
            manifest[i].paths.forEach(filepath => {
                // Ensure exactly one '/' between the prefix and the shard path.
                const fetchUrl = filePathPrefix +
                    (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
                fetchUrls.push(fetchUrl);
            });
        });
        const buffers = await fetchWeightsFunction(fetchUrls);
        const weightsTensorMap = {};
        // Index of the first buffer belonging to the current group.
        let bufferIndexOffset = 0;
        groupIndicesToFetch.forEach(i => {
            const numBuffers = manifest[i].paths.length;
            // Total byte length of all shards belonging to this group.
            let groupBytes = 0;
            // NOTE(review): the inner `i` loop indices below shadow the outer
            // group index `i`; scoping is correct but read carefully.
            for (let i = 0; i < numBuffers; i++) {
                groupBytes += buffers[bufferIndexOffset + i].byteLength;
            }
            // Create a buffer for the whole group.
            const groupBuffer = new ArrayBuffer(groupBytes);
            const groupByteBuffer = new Uint8Array(groupBuffer);
            let groupBufferOffset = 0;
            for (let i = 0; i < numBuffers; i++) {
                const buffer = new Uint8Array(buffers[bufferIndexOffset + i]);
                groupByteBuffer.set(buffer, groupBufferOffset);
                groupBufferOffset += buffer.byteLength;
            }
            const weightsEntries = groupWeightsToFetch[i];
            weightsEntries.forEach(weightsEntry => {
                // Slice this entry's bytes out of the group and decode to tensors.
                const byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
                const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
                for (const name in nameToTensorMap) {
                    weightsTensorMap[name] = nameToTensorMap[name];
                }
            });
            bufferIndexOffset += numBuffers;
        });
        return weightsTensorMap;
    };
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
- const JSON_TYPE = 'application/json';
/**
 * IOHandler that saves to and loads from an HTTP(S) endpoint.
 *
 * Saving posts a multipart/form-data body containing 'model.json' and
 * (optionally) 'model.weights.bin'; loading fetches the model JSON and then
 * the weight shards listed in its weightsManifest.
 */
class HTTPRequest {
    constructor(path, loadOptions) {
        // Used for save() when the caller's requestInit specifies no method.
        this.DEFAULT_METHOD = 'POST';
        if (loadOptions == null) {
            loadOptions = {};
        }
        this.weightPathPrefix = loadOptions.weightPathPrefix;
        this.onProgress = loadOptions.onProgress;
        this.weightUrlConverter = loadOptions.weightUrlConverter;
        if (loadOptions.fetchFunc != null) {
            assert(typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' +
                '`fetch` (see ' +
                'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
            this.fetch = loadOptions.fetchFunc;
        }
        else {
            // Fall back to the platform-provided fetch implementation.
            this.fetch = env().platform.fetch;
        }
        assert(path != null && path.length > 0, () => 'URL path for http must not be null, undefined or ' +
            'empty.');
        // An array path means [topology URL, weights URL].
        if (Array.isArray(path)) {
            assert(path.length === 2, () => 'URL paths for http must have a length of 2, ' +
                `(actual length is ${path.length}).`);
        }
        this.path = path;
        // save() owns the request body (FormData); reject a pre-set one.
        if (loadOptions.requestInit != null &&
            loadOptions.requestInit.body != null) {
            throw new Error('requestInit is expected to have no pre-existing body, but has one.');
        }
        this.requestInit = loadOptions.requestInit || {};
    }
    /**
     * Sends the model artifacts to `this.path` as multipart/form-data.
     * Binary model topology is not supported. Resolves with the server
     * response on HTTP 2xx; throws otherwise.
     */
    async save(modelArtifacts) {
        if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
            throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' +
                'in binary formats yet.');
        }
        const init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit);
        init.body = new FormData();
        // Single-group manifest pointing at the one concatenated weights file.
        const weightsManifest = [{
                paths: ['./model.weights.bin'],
                weights: modelArtifacts.weightSpecs,
            }];
        const modelTopologyAndWeightManifest = {
            modelTopology: modelArtifacts.modelTopology,
            format: modelArtifacts.format,
            generatedBy: modelArtifacts.generatedBy,
            convertedBy: modelArtifacts.convertedBy,
            userDefinedMetadata: modelArtifacts.userDefinedMetadata,
            weightsManifest
        };
        init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json');
        if (modelArtifacts.weightData != null) {
            init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin');
        }
        const response = await this.fetch(this.path, init);
        if (response.ok) {
            return {
                modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
                responses: [response],
            };
        }
        else {
            throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ` +
                `${response.status}.`);
        }
    }
    /**
     * Load model artifacts via HTTP request(s).
     *
     * See the documentation to `tf.io.http` for details on the saved
     * artifacts.
     *
     * @returns The loaded model artifacts (if loading succeeds).
     */
    async load() {
        const modelConfigRequest = await this.fetch(this.path, this.requestInit);
        if (!modelConfigRequest.ok) {
            throw new Error(`Request to ${this.path} failed with status code ` +
                `${modelConfigRequest.status}. Please verify this URL points to ` +
                `the model JSON of the model to load.`);
        }
        let modelConfig;
        try {
            modelConfig = await modelConfigRequest.json();
        }
        catch (e) {
            let message = `Failed to parse model JSON of response from ${this.path}.`;
            // TODO(nsthorat): Remove this after some time when we're comfortable that
            // .pb files are mostly gone.
            if (this.path.endsWith('.pb')) {
                message += ' Your path contains a .pb file extension. ' +
                    'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
                    'in favor of .json models. You can re-convert your Python ' +
                    'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
                    'or you can convert your.pb models with the \'pb2json\'' +
                    'NPM script in the tensorflow/tfjs-converter repository.';
            }
            else {
                message += ' Please make sure the server is serving valid ' +
                    'JSON for this request.';
            }
            throw new Error(message);
        }
        const modelTopology = modelConfig.modelTopology;
        const weightsManifest = modelConfig.weightsManifest;
        const generatedBy = modelConfig.generatedBy;
        const convertedBy = modelConfig.convertedBy;
        const format = modelConfig.format;
        const userDefinedMetadata = modelConfig.userDefinedMetadata;
        // We do not allow both modelTopology and weightsManifest to be missing.
        if (modelTopology == null && weightsManifest == null) {
            throw new Error(`The JSON from HTTP path ${this.path} contains neither model ` +
                `topology or manifest for weights.`);
        }
        let weightSpecs;
        let weightData;
        if (weightsManifest != null) {
            const results = await this.loadWeights(weightsManifest);
            [weightSpecs, weightData] = results;
        }
        const artifacts = {
            modelTopology,
            weightSpecs,
            weightData,
            userDefinedMetadata,
            generatedBy,
            convertedBy,
            format
        };
        const initializer = modelConfig.modelInitializer;
        if (initializer) {
            artifacts.modelInitializer = initializer;
        }
        return artifacts;
    }
    /**
     * Fetches all weight shards named in `weightsManifest` and returns
     * [weightSpecs, concatenated ArrayBuffer of all shard bytes].
     */
    async loadWeights(weightsManifest) {
        // For an array path, the second element is the base URL for weights.
        const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
        const [prefix, suffix] = parseUrl(weightPath);
        const pathPrefix = this.weightPathPrefix || prefix;
        const weightSpecs = [];
        for (const entry of weightsManifest) {
            weightSpecs.push(...entry.weights);
        }
        const fetchURLs = [];
        const urlPromises = [];
        for (const weightsGroup of weightsManifest) {
            for (const path of weightsGroup.paths) {
                if (this.weightUrlConverter != null) {
                    // Converter provided: all URLs come from it (possibly async).
                    urlPromises.push(this.weightUrlConverter(path));
                }
                else {
                    fetchURLs.push(pathPrefix + path + suffix);
                }
            }
        }
        if (this.weightUrlConverter) {
            fetchURLs.push(...await Promise.all(urlPromises));
        }
        const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
            requestInit: this.requestInit,
            fetchFunc: this.fetch,
            onProgress: this.onProgress
        });
        return [weightSpecs, concatenateArrayBuffers(buffers)];
    }
}
// Matches absolute http:// and https:// URLs; used by the routers below.
HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
/**
 * Splits a model URL into a directory-like prefix and a query-string suffix.
 *
 * The prefix is everything up to and including the last '/'; the suffix is
 * the search params that follow the last path segment (or '' if none).
 * ```
 * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
 * [prefix, suffix] = parseUrl(url)
 * // prefix = 'http://tfhub.dev/model/1/'
 * // suffix = '?tfjs-format=file'
 * ```
 * @param url the model url to be parsed.
 */
function parseUrl(url) {
    const lastSlash = url.lastIndexOf('/');
    const lastSearchParam = url.lastIndexOf('?');
    // Query params only count as a suffix when they follow the final '/'.
    const hasSuffix = lastSearchParam > lastSlash;
    const prefix = `${url.substring(0, lastSlash)}/`;
    const suffix = hasSuffix ? url.substring(lastSearchParam) : '';
    return [prefix, suffix];
}
// True when `url` begins with an http:// or https:// scheme.
function isHTTPScheme(url) {
    return HTTPRequest.URL_SCHEME_REGEX.test(url);
}
// Router that yields an HTTP IOHandler for absolute http(s) URLs, or null
// when the URL scheme does not match or no fetch implementation is available.
const httpRouter = (url, loadOptions) => {
    if (typeof fetch === 'undefined' &&
        (loadOptions == null || loadOptions.fetchFunc == null)) {
        // `http` uses `fetch` or `node-fetch`, if one wants to use it in
        // an environment that is not the browser or node they have to setup a
        // global fetch polyfill.
        return null;
    }
    // Every element of an array path must be an http(s) URL for this router
    // to claim it.
    const isHTTP = Array.isArray(url) ?
        url.every(urlItem => isHTTPScheme(urlItem)) :
        isHTTPScheme(url);
    return isHTTP ? http(url, loadOptions) : null;
};
- IORouterRegistry.registerSaveRouter(httpRouter);
- IORouterRegistry.registerLoadRouter(httpRouter);
- /**
- * Creates an IOHandler subtype that sends model artifacts to HTTP server.
- *
- * An HTTP request of the `multipart/form-data` mime type will be sent to the
- * `path` URL. The form data includes artifacts that represent the topology
- * and/or weights of the model. In the case of Keras-style `tf.Model`, two
- * blobs (files) exist in form-data:
- * - A JSON file consisting of `modelTopology` and `weightsManifest`.
- * - A binary weights file consisting of the concatenated weight values.
- * These files are in the same format as the one generated by
- * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
- *
- * The following code snippet exemplifies the client-side code that uses this
- * function:
- *
- * ```js
- * const model = tf.sequential();
- * model.add(
- * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
- *
- * const saveResult = await model.save(tf.io.http(
- * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
- * console.log(saveResult);
- * ```
- *
- * If the default `POST` method is to be used, without any custom parameters
- * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
- *
- * ```js
- * const saveResult = await model.save('http://model-server:5000/upload');
- * ```
- *
- * The following GitHub Gist
- * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
- * implements a server based on [flask](https://github.com/pallets/flask) that
 * can receive the request. Upon receiving the model artifacts via the request,
 * this particular server reconstitutes instances of [Keras
 * Models](https://keras.io/models/model/) in memory.
- *
- *
- * @param path A URL path to the model.
- * Can be an absolute HTTP path (e.g.,
 *   'http://localhost:8000/model-upload') or a relative path (e.g.,
- * './model-upload').
- * @param requestInit Request configurations to be used when sending
- * HTTP request to server using `fetch`. It can contain fields such as
- * `method`, `credentials`, `headers`, `mode`, etc. See
- * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
- * for more information. `requestInit` must not have a body, because the
- * body will be set by TensorFlow.js. File blobs representing the model
- * topology (filename: 'model.json') and the weights of the model (filename:
- * 'model.weights.bin') will be appended to the body. If `requestInit` has a
- * `body`, an Error will be thrown.
- * @param loadOptions Optional configuration for the loading. It includes the
- * following fields:
- * - weightPathPrefix Optional, this specifies the path prefix for weight
- * files, by default this is calculated from the path param.
- * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
- * the `fetch` from node-fetch can be used here.
- * - onProgress Optional, progress callback function, fired periodically
- * before the load is completed.
- * @returns An instance of `IOHandler`.
- *
- * @doc {
- * heading: 'Models',
- * subheading: 'Loading',
- * namespace: 'io',
- * ignoreCI: true
- * }
- */
// Factory for an HTTP(S)-backed IOHandler; full contract documented in the
// JSDoc immediately above.
function http(path, loadOptions) {
    return new HTTPRequest(path, loadOptions);
}
/**
 * Deprecated. Use `tf.io.http`.
 * @deprecated Retained only as a backward-compatible alias for `http`.
 * @param path
 * @param loadOptions
 */
function browserHTTPRequest(path, loadOptions) {
    return http(path, loadOptions);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * IOHandler implementation that "loads" by returning model artifacts already
 * held in memory; used by `tf.io.fromMemory`.
 */
class PassthroughLoader {
    /** @param modelArtifacts The in-memory artifacts to serve on load(). */
    constructor(modelArtifacts) {
        this.modelArtifacts = modelArtifacts;
    }
    /** Resolves with the artifacts supplied at construction, unchanged. */
    async load() {
        return this.modelArtifacts;
    }
}
/**
 * IOHandler implementation that delegates saving to a caller-provided handler
 * function; used by `tf.io.withSaveHandler`.
 */
class PassthroughSaver {
    /** @param saveHandler Function invoked with the artifacts on save(). */
    constructor(saveHandler) {
        this.saveHandler = saveHandler;
    }
    /** Forwards `modelArtifacts` to the wrapped handler; returns its result. */
    async save(modelArtifacts) {
        return this.saveHandler(modelArtifacts);
    }
}
- /**
- * Creates an IOHandler that loads model artifacts from memory.
- *
- * When used in conjunction with `tf.loadLayersModel`, an instance of
- * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
- *
- * ```js
- * const model = await tf.loadLayersModel(tf.io.fromMemory(
- * modelTopology, weightSpecs, weightData));
- * ```
- *
- * @param modelArtifacts a object containing model topology (i.e., parsed from
- * the JSON format).
- * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
- * names, shapes, types, and quantization of the weight data.
- * @param weightData A single `ArrayBuffer` containing the weight data,
- * concatenated in the order described by the weightSpecs.
- * @param trainingConfig Model training configuration. Optional.
- *
- * @returns A passthrough `IOHandler` that simply loads the provided data.
- */
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
    // Warning text shared by both deprecated call paths below.
    const deprecationWarning = 'Please call tf.io.fromMemory() with only one argument. ' +
        'The argument should be of type ModelArtifacts. ' +
        'The multi-argument signature of tf.io.fromMemory() has been ' +
        'deprecated and will be removed in a future release.';
    if (arguments.length !== 1) {
        // Legacy support.
        // TODO(cais): Remove this deprecated API.
        console.warn(deprecationWarning);
        return new PassthroughLoader({
            modelTopology: modelArtifacts,
            weightSpecs,
            weightData,
            trainingConfig
        });
    }
    // Single-argument call: the argument is either a full ModelArtifacts
    // object or (legacy) a bare model topology.
    const isModelArtifacts = modelArtifacts.modelTopology != null ||
        modelArtifacts.weightSpecs != null;
    if (isModelArtifacts) {
        return new PassthroughLoader(modelArtifacts);
    }
    // Legacy support: with only modelTopology.
    // TODO(cais): Remove this deprecated API.
    console.warn(deprecationWarning);
    return new PassthroughLoader({ modelTopology: modelArtifacts });
}
/**
 * Creates an IOHandler that passes saved model artifacts to a callback.
 *
 * ```js
 * function handleSave(artifacts) {
 *   // ... do something with the artifacts ...
 *   return {modelArtifactsInfo: {...}, ...};
 * }
 *
 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
 * ```
 *
 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
 *     `SaveResult`.
 * @returns A passthrough `IOHandler` whose `save()` delegates to `saveHandler`.
 */
function withSaveHandler(saveHandler) {
    return new PassthroughSaver(saveHandler);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
// Public `tf.io` namespace object: aggregates the IO handlers, weight codecs,
// router registration functions and model-management helpers defined above.
var io = /*#__PURE__*/Object.freeze({
    __proto__: null,
    browserFiles: browserFiles,
    browserHTTPRequest: browserHTTPRequest,
    concatenateArrayBuffers: concatenateArrayBuffers,
    decodeWeights: decodeWeights,
    encodeWeights: encodeWeights,
    fromMemory: fromMemory,
    getLoadHandlers: getLoadHandlers,
    getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
    getSaveHandlers: getSaveHandlers,
    http: http,
    isHTTPScheme: isHTTPScheme,
    loadWeights: loadWeights,
    registerLoadRouter: registerLoadRouter,
    registerSaveRouter: registerSaveRouter,
    weightsLoaderFactory: weightsLoaderFactory,
    withSaveHandler: withSaveHandler,
    copyModel: copyModel,
    listModels: listModels,
    moveModel: moveModel,
    removeModel: removeModel
});
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reshapes a `tf.Tensor` to a given shape.
- *
- * Given an input tensor, returns a new tensor with the same values as the
- * input tensor with shape `shape`.
- *
- * If one component of shape is the special value -1, the size of that
- * dimension is computed so that the total size remains constant. In
- * particular, a shape of [-1] flattens into 1-D. At most one component of
- * shape can be -1.
- *
- * If shape is 1-D or higher, then the operation returns a tensor with shape
- * shape filled with the values of tensor. In this case, the number of
- * elements implied by shape must be the same as the number of elements in
- * tensor.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * x.reshape([2, 2]).print();
- * ```
- *
- * @param x The input tensor to be reshaped.
- * @param shape An array of integers defining the output tensor shape.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
// Kernel-dispatching implementation backing `tf.reshape` (see JSDoc above).
function reshape_(x, shape) {
    const $x = convertToTensor(x, 'x', 'reshape', null);
    const forwardFunc = (backend, save) => {
        // Resolve any -1 entry in `shape` against the tensor's total size.
        shape = inferFromImplicitShape(shape, $x.size);
        assert($x.size === sizeFromShape(shape), () => 'new shape and old shape must have the same number of elements.');
        save([$x]);
        return backend.reshape($x, shape);
    };
    // Note: `attrs` captures the caller-provided shape before the forward
    // function resolves implicit dimensions, matching kernel expectations.
    const inputs = { x: $x };
    const attrs = { shape };
    return ENGINE.runKernelFunc(forwardFunc, inputs, null /* grad */, Reshape, attrs);
}
const reshape = op({ reshape_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the dot product of two matrices, A * B. These must be matrices.
- *
- * ```js
- * const a = tf.tensor2d([1, 2], [1, 2]);
- * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * a.matMul(b).print(); // or tf.matMul(a, b)
- * ```
- * @param a First matrix in dot product operation.
- * @param b Second matrix in dot product operation.
- * @param transposeA If true, `a` is transposed before multiplication.
- * @param transposeB If true, `b` is transposed before multiplication.
- *
- * @doc {heading: 'Operations', subheading: 'Matrices'}
- */
// Kernel-dispatching implementation backing `tf.matMul` (see JSDoc above).
function matMul_(a, b, transposeA = false, transposeB = false) {
    let $a = convertToTensor(a, 'a', 'matMul');
    let $b = convertToTensor(b, 'b', 'matMul');
    // Upcast so both operands share a common dtype before multiplying.
    [$a, $b] = makeTypesMatch($a, $b);
    assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, () => `Error in matMul: inputs must have the same rank of at least 2, ` +
        `got ranks ${$a.rank} and ${$b.rank}.`);
    const forward = (backend, save) => {
        save([$a, $b]);
        // Inner dims must agree; outer dims become the output's last two dims.
        // A transpose flag swaps which of the last two axes is "inner".
        const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
        const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
        const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
        const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
        // All leading (batch) dimensions must match exactly.
        const outerDimsA = $a.shape.slice(0, -2);
        const outerDimsB = $b.shape.slice(0, -2);
        assert(arraysEqual(outerDimsA, outerDimsB), () => `Error in matMul: outer dimensions (${outerDimsA}) and (` +
            `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` +
            `${$b.shape} must match.`);
        assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
            `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
            `${$b.shape} and transposeA=${transposeA}` +
            ` and transposeB=${transposeB} must match.`);
        const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
        // Collapse all batch dims into one leading dim for batchMatMul.
        const batchDimA = sizeFromShape(outerDimsA);
        const batchDimB = sizeFromShape(outerDimsB);
        const a3D = transposeA ?
            reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
            reshape($a, [batchDimA, outerShapeA, innerShapeA]);
        const b3D = transposeB ?
            reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
            reshape($b, [batchDimB, innerShapeB, outerShapeB]);
        const res3d = backend.batchMatMul(a3D, b3D, transposeA, transposeB);
        // Restore the original batch dimensions on the result.
        return reshape(res3d, outShape);
    };
    const inputs = { a: $a, b: $b };
    const attrs = { transposeA, transposeB };
    return ENGINE.runKernelFunc(forward, inputs, null /* grad */, BatchMatMul, attrs);
}
const matMul = op({ matMul_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
- * value `onValue` (defaults to 1), while all other locations take value
- * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
- * `R+1` with the last axis of size `depth`.
- *
- * ```js
- * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
- * ```
- *
- * @param indices `tf.Tensor` of indices with dtype `int32`.
- * @param depth The depth of the one hot dimension.
- * @param onValue A number used to fill in the output when the index matches
- * the location.
- * @param offValue A number used to fill in the output when the index does
- * not match the location.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function oneHot_(indices, depth, onValue = 1, offValue = 0) {
    // `depth` is the size of the new trailing one-hot axis; depth < 2 would
    // make every row degenerate, so it is rejected up front.
    if (depth < 2) {
        throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
    }
    const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
    // Output shape is indices.shape + [depth] (rank R -> rank R+1).
    const outShape = [...$indices.shape, depth];
    const forward = (backend, save) => {
        save([$indices]);
        // The backend oneHot works on a flat index vector; reshape the
        // result back to the full output shape afterwards.
        return reshape(backend.oneHot(reshape($indices, [$indices.size]), depth, onValue, offValue), outShape);
    };
    const inputs = { indices: $indices };
    const attrs = { depth, onValue, offValue };
    return ENGINE.runKernelFunc(forward, inputs, null /* grad */, OneHot, attrs);
}
// Public op wrapper: adds op scoping/naming around oneHot_.
const oneHot = op({ oneHot_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
- *
- * The returned `tf.Tensor`'s dimension `i` will correspond to the input
- * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
- * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
- * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
- *
- * ```js
- * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
- *
- * a.transpose().print(); // or tf.transpose(a)
- * ```
- *
- * @param x The tensor to transpose.
- * @param perm The permutation of the dimensions of a.
- *
- * @doc {heading: 'Operations', subheading: 'Matrices'}
- */
/**
 * Permutes the dimensions of `x` according to `perm`. When `perm` is
 * omitted, all axes are reversed (a regular matrix transpose for 2-D).
 */
function transpose_(x, perm) {
    const $x = convertToTensor(x, 'x', 'transpose');
    if (perm == null) {
        // Default permutation: [rank-1, ..., 1, 0].
        perm = [];
        for (let axis = $x.rank - 1; axis >= 0; axis--) {
            perm.push(axis);
        }
    }
    assert($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` +
        `must match length of perm ${perm}.`);
    for (const axis of perm) {
        assert(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +
            ` but got ${perm}`);
    }
    // Scalars and vectors are unchanged by any permutation.
    if ($x.rank <= 1) {
        return $x.clone();
    }
    const inputs = { x: $x };
    const attrs = { perm };
    return ENGINE.runKernelFunc(backend => backend.transpose($x, perm), inputs, null /* gradient */, Transpose, attrs);
}
// Public op wrapper: adds op scoping/naming around transpose_.
const transpose = op({ transpose_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the confusion matrix from true labels and predicted labels.
- *
- * ```js
- * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
- * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
- * const numClasses = 3;
- * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
- * out.print();
- * // Expected output matrix:
- * // [[2, 0, 0],
- * // [0, 1, 1],
- * // [0, 0, 1]]
- * ```
- *
- * @param labels The target labels, assumed to be 0-based integers
- * for the classes. The shape is `[numExamples]`, where
- * `numExamples` is the number of examples included.
- * @param predictions The predicted classes, assumed to be
- * 0-based integers for the classes. Must have the same shape as `labels`.
- * @param numClasses Number of all classes, as an integer.
- * Its value must be larger than the largest element in `labels` and
- * `predictions`.
- * @returns The confusion matrix as a int32-type 2D tensor. The value at
- * row `r` and column `c` is the number of times examples of actual class
- * `r` were predicted as class `c`.
- *
- * @doc {heading: 'Operations', subheading: 'Evaluation'}
- */
/**
 * Computes the confusion matrix from 1-D integer labels and predictions.
 * Entry (r, c) counts examples of actual class r predicted as class c.
 */
function confusionMatrix_(labels, predictions, numClasses) {
    const labelsT = convertToTensor(labels, 'labels', 'confusionMatrix');
    const predictionsT = convertToTensor(predictions, 'predictions', 'confusionMatrix');
    assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), () => `If provided, numClasses must be a positive integer, ` +
        `but got ${numClasses}`);
    assert(labelsT.rank === 1, () => `Expected the rank of labels to be 1, but got ${labelsT.rank}`);
    assert(predictionsT.rank === 1, () => `Expected the rank of predictions to be 1, ` +
        `but got ${predictionsT.rank}`);
    assert(labelsT.shape[0] === predictionsT.shape[0], () => `Mismatch in the number of examples: ` +
        `${labelsT.shape[0]} vs. ${predictionsT.shape[0]}. ` +
        `Labels and predictions should have the same number of elements.`);
    assert(numClasses > 0 && Number.isInteger(numClasses), () => `numClasses is required to be a positive integer, but got ` +
        `${numClasses}`);
    // TODO(cais): In the future, if oneHot supports tensors inputs for
    // `numClasses`, `confusionMatrix` can make `numClasses` optional.
    // One-hot encode both vectors, then count co-occurrences with a matmul:
    // labelsOneHot^T @ predictionsOneHot gives per-(actual, predicted) counts.
    const oneHotLabels = oneHot(cast(labelsT, 'int32'), numClasses);
    const oneHotPredictions = oneHot(cast(predictionsT, 'int32'), numClasses);
    const product = matMul(transpose(oneHotLabels), oneHotPredictions);
    return cast(product, 'int32');
}
// Public op wrapper: adds op scoping/naming around confusionMatrix_.
const confusionMatrix = op({ confusionMatrix_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
// Frozen namespace object bundling the evaluation helpers above
// (exposed to consumers of this module as a `math` namespace).
var math = /*#__PURE__*/Object.freeze({
    __proto__: null,
    confusionMatrix: confusionMatrix
});
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor3d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. If not provided, it is inferred from
- * `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
/**
 * Creates a rank-3 Tensor from nested arrays, a flat array, or a TypedArray.
 * A flat `values` argument requires an explicit 3-element `shape`.
 */
function tensor3d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 3) {
        throw new Error('tensor3d() requires shape to have three numbers');
    }
    const shapeOfValues = inferShape(values, dtype);
    const inferredRank = shapeOfValues.length;
    // Nested input must already be 3-D; flat input is rank 1.
    if (inferredRank !== 3 && inferredRank !== 1) {
        throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
    }
    // A flat array carries no shape information of its own.
    if (inferredRank === 1 && shape == null) {
        throw new Error('tensor3d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, shapeOfValues, dtype);
}
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
// Lazily-created shared 2-D canvas context, used by fromPixels_ to
// rasterize HTMLImageElement / HTMLVideoElement inputs.
let fromPixels2DContext;
- /**
- * Creates a `tf.Tensor` from an image.
- *
- * ```js
- * const image = new ImageData(1, 1);
- * image.data[0] = 100;
- * image.data[1] = 150;
- * image.data[2] = 200;
- * image.data[3] = 255;
- *
- * tf.browser.fromPixels(image).print();
- * ```
- *
- * @param pixels The input image to construct the tensor from. The
- * supported image types are all 4-channel. You can also pass in an image
- * object with following attributes:
- * `{data: Uint8Array; width: number; height: number}`
- * @param numChannels The number of channels of the output tensor. A
- * numChannels value less than 4 allows you to ignore channels. Defaults to
- * 3 (ignores alpha channel of input image).
- *
- * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
- */
function fromPixels_(pixels, numChannels = 3) {
    // Sanity checks.
    if (numChannels > 4) {
        throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
    }
    if (pixels == null) {
        throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
    }
    // Classify the input source; exactly one of these flags becomes true.
    let isPixelData = false;
    let isImageData = false;
    let isVideo = false;
    let isImage = false;
    let isCanvasLike = false;
    if (pixels.data instanceof Uint8Array) {
        isPixelData = true;
    }
    else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {
        isImageData = true;
    }
    else if (typeof (HTMLVideoElement) !== 'undefined' &&
        pixels instanceof HTMLVideoElement) {
        isVideo = true;
    }
    else if (typeof (HTMLImageElement) !== 'undefined' &&
        pixels instanceof HTMLImageElement) {
        isImage = true;
        // tslint:disable-next-line: no-any
    }
    else if (pixels.getContext != null) {
        // Duck-typed canvas (covers OffscreenCanvas in a web worker).
        isCanvasLike = true;
    }
    else {
        throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
            `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
            `in browser, or OffscreenCanvas, ImageData in webworker` +
            ` or {data: Uint32Array, width: number, height: number}, ` +
            `but was ${pixels.constructor.name}`);
    }
    if (isVideo) {
        // readyState >= HAVE_CURRENT_DATA means at least one frame is decoded.
        const HAVE_CURRENT_DATA_READY_STATE = 2;
        if (isVideo &&
            pixels.readyState <
                HAVE_CURRENT_DATA_READY_STATE) {
            throw new Error('The video element has not loaded data yet. Please wait for ' +
                '`loadeddata` event on the element.');
        }
    }
    // If the current backend has 'FromPixels' registered, it has a more
    // efficient way of handling pixel uploads, so we call that.
    const kernel = getKernel(FromPixels, ENGINE.backendName);
    if (kernel != null) {
        const inputs = { pixels };
        const attrs = { numChannels };
        return ENGINE.runKernel(FromPixels, inputs, attrs);
    }
    // Generic fallback: extract RGBA bytes via a 2-D canvas context.
    const [width, height] = isVideo ?
        [
            pixels.videoWidth,
            pixels.videoHeight
        ] :
        [pixels.width, pixels.height];
    let vals;
    if (isCanvasLike) {
        vals =
            // tslint:disable-next-line:no-any
            pixels.getContext('2d').getImageData(0, 0, width, height).data;
    }
    else if (isImageData || isPixelData) {
        // Raw pixel buffers already expose their RGBA data directly.
        vals = pixels.data;
    }
    else if (isImage || isVideo) {
        // Rasterize the image / current video frame through a shared
        // offscreen canvas (created once and reused across calls).
        if (fromPixels2DContext == null) {
            fromPixels2DContext = document.createElement('canvas').getContext('2d');
        }
        fromPixels2DContext.canvas.width = width;
        fromPixels2DContext.canvas.height = height;
        fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
        vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
    }
    let values;
    if (numChannels === 4) {
        // Keep all four RGBA channels as-is.
        values = new Int32Array(vals);
    }
    else {
        // Copy only the first `numChannels` of each 4-byte RGBA pixel
        // (e.g. drop alpha when numChannels === 3).
        const numPixels = width * height;
        values = new Int32Array(numPixels * numChannels);
        for (let i = 0; i < numPixels; i++) {
            for (let channel = 0; channel < numChannels; ++channel) {
                values[i * numChannels + channel] = vals[i * 4 + channel];
            }
        }
    }
    const outShape = [height, width, numChannels];
    return tensor3d(values, outShape, 'int32');
}
- /**
- * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
- * canvas.
- *
- * When the dtype of the input is 'float32', we assume values in the range
- * [0-1]. Otherwise, when input is 'int32', we assume values in the range
- * [0-255].
- *
- * Returns a promise that resolves when the canvas has been drawn to.
- *
- * @param img A rank-2 or rank-3 tensor. If rank-2, draws grayscale. If
- * rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
- * grayscale. When depth of 3, we draw with the first three components of
- * the depth dimension corresponding to r, g, b and alpha = 1. When depth of
- * 4, all four components of the depth dimension correspond to r, g, b, a.
- * @param canvas The canvas to draw to.
- *
- * @doc {heading: 'Browser', namespace: 'browser'}
- */
async function toPixels(img, canvas) {
    let $img = convertToTensor(img, 'img', 'toPixels');
    if (!(img instanceof Tensor)) {
        // Assume int32 if user passed a native array.
        const originalImgTensor = $img;
        $img = cast(originalImgTensor, 'int32');
        originalImgTensor.dispose();
    }
    if ($img.rank !== 2 && $img.rank !== 3) {
        throw new Error(`toPixels only supports rank 2 or 3 tensors, got rank ${$img.rank}.`);
    }
    const [height, width] = $img.shape.slice(0, 2);
    // Rank-2 input is treated as single-channel (grayscale).
    const depth = $img.rank === 2 ? 1 : $img.shape[2];
    if (depth > 4 || depth === 2) {
        throw new Error(`toPixels only supports depth of size ` +
            `1, 3 or 4 but got ${depth}`);
    }
    if ($img.dtype !== 'float32' && $img.dtype !== 'int32') {
        throw new Error(`Unsupported type for toPixels: ${$img.dtype}.` +
            ` Please use float32 or int32 tensors.`);
    }
    const data = await $img.data();
    // float32 values in [0, 1] are scaled to bytes; int32 are already bytes.
    const multiplier = $img.dtype === 'float32' ? 255 : 1;
    const bytes = new Uint8ClampedArray(width * height * 4);
    for (let i = 0; i < height * width; ++i) {
        // Default pixel is opaque black; channels present in the input
        // overwrite the corresponding slots.
        const rgba = [0, 0, 0, 255];
        for (let d = 0; d < depth; d++) {
            const value = data[i * depth + d];
            if ($img.dtype === 'float32') {
                if (value < 0 || value > 1) {
                    throw new Error(`Tensor values for a float32 Tensor must be in the ` +
                        `range [0 - 1] but encountered ${value}.`);
                }
            }
            else if ($img.dtype === 'int32') {
                if (value < 0 || value > 255) {
                    throw new Error(`Tensor values for a int32 Tensor must be in the ` +
                        `range [0 - 255] but encountered ${value}.`);
                }
            }
            if (depth === 1) {
                // Grayscale: replicate the single channel into r, g and b.
                rgba[0] = value * multiplier;
                rgba[1] = value * multiplier;
                rgba[2] = value * multiplier;
            }
            else {
                rgba[d] = value * multiplier;
            }
        }
        const j = i * 4;
        bytes[j + 0] = Math.round(rgba[0]);
        bytes[j + 1] = Math.round(rgba[1]);
        bytes[j + 2] = Math.round(rgba[2]);
        bytes[j + 3] = Math.round(rgba[3]);
    }
    if (canvas != null) {
        canvas.width = width;
        canvas.height = height;
        const ctx = canvas.getContext('2d');
        const imageData = new ImageData(bytes, width, height);
        ctx.putImageData(imageData, 0, 0);
    }
    // Dispose the intermediate tensor created by the int32 cast above.
    if ($img !== img) {
        $img.dispose();
    }
    return bytes;
}
// Public op wrapper: adds op scoping/naming around fromPixels_.
const fromPixels = op({ fromPixels_ });

// Frozen namespace object bundling the browser pixel I/O helpers
// (exposed to consumers of this module as a `browser` namespace).
var browser = /*#__PURE__*/Object.freeze({
    __proto__: null,
    toPixels: toPixels,
    fromPixels: fromPixels
});
-
- /**
- * Validate gather nd inputs.
- *
- * @param tensor The tensor contains the source values.
- * @param indices The tensor contains the indices to slice the source.
- *
- * @returns [resultShape, numUpdates, sliceSize, strides]
- */
/**
 * Validate gather-nd inputs and derive the shape bookkeeping needed by the
 * kernel.
 *
 * @param tensor The tensor containing the source values.
 * @param indices The tensor containing the indices to slice the source.
 * @returns [resultShape, numUpdates, sliceSize, strides]
 */
function prepareAndValidate(tensor, indices) {
    if (tensor.rank < 1) {
        throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
            ` but the rank was ${tensor.rank}.`);
    }
    if (indices.rank < 1) {
        throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
            ` but the rank was ${indices.rank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error('tf.gatherND() expects the indices to be int32 type,' +
            ` but the dtype was ${indices.dtype}.`);
    }
    if (indices.shape[indices.rank - 1] > tensor.rank) {
        throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
            `${indices.shape[indices.rank - 1]} vs. ${tensor.rank}`);
    }
    if (tensor.size === 0) {
        throw new Error('Requested more than 0 entries, but input is empty.' +
            ` Input shape: ${tensor.shape}.`);
    }
    const indicesShape = indices.shape;
    // Each index tuple has `sliceRank` components.
    const sliceRank = indicesShape[indicesShape.length - 1];
    // Number of slices gathered = product of all but the last indices dim.
    const nResult = indicesShape.slice(0, -1).reduce((product, dim) => product * dim, 1);
    const inputShape = tensor.shape;
    // The result shape is
    // indices.shape[:-1] + params.shape[indices.shape[-1]:]
    const resultShape = indicesShape.slice(0, -1);
    let sliceSize = 1;
    for (let i = sliceRank; i < tensor.rank; ++i) {
        sliceSize *= inputShape[i];
        resultShape.push(inputShape[i]);
    }
    // Strides expressed in units of whole slices, truncated to sliceRank.
    const strides = [...computeStrides(tensor.shape).map((stride) => stride / sliceSize), 1]
        .slice(0, sliceRank);
    return [resultShape, nResult, sliceSize, strides];
}
-
// Frozen namespace object bundling the gather-nd shape helpers above.
var gather_nd_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    prepareAndValidate: prepareAndValidate
});
-
- /**
- * Check whether updates.shape = indices.shape[:batchDim] +
- * shape[sliceDim:]
- *
- * @param x The input tensor.
- */
/**
 * Check whether updates.shape = indices.shape[:batchDim] +
 * shape[sliceDim:]; throws a descriptive Error on any mismatch.
 *
 * @param shape The shape of the output tensor.
 * @param indices The tensor holding the indices for the update values.
 * @param updates The tensor holding the update values.
 */
function validateUpdateShape(shape, indices, updates) {
    // Each index tuple has `sliceDim` components; the leading `batchDim`
    // dims of `updates` must line up with the leading dims of `indices`.
    const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
    const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
    const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
        `shape[sliceDim:], got updates.shape: ${updates.shape}` +
        `, indices.shape: ${indices.shape}, shape: ${shape}` +
        `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
    if (updates.rank < batchDim) {
        throw new Error(shapeError + ` update.rank < ${batchDim}. `);
    }
    if (shape.length < sliceDim + (updates.rank - batchDim)) {
        throw new Error(shapeError +
            ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
    }
    if (updates.rank !== batchDim + shape.length - sliceDim) {
        throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);
    }
    for (let d = 0; d < batchDim; ++d) {
        if (updates.shape[d] !== indices.shape[d]) {
            throw new Error(shapeError +
                ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
        }
    }
    for (let d = 0; d < updates.rank - batchDim; ++d) {
        if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
            // Bug fix: the message previously reported shape[d + batchDim],
            // which names the wrong output dimension (and value) whenever
            // batchDim != sliceDim; the comparison itself uses d + sliceDim.
            throw new Error(shapeError +
                ` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + sliceDim}] (${shape[d + sliceDim]})`);
        }
    }
}
- /**
- * Validate scatter nd inputs.
- *
- * @param update The tensor contains the update values.
- * @param indices The tensor contains the indices for the update values.
- * @param shape The shape of the output tensor.
- */
/**
 * Validate scatter nd inputs.
 *
 * Throws when `indices` or `updates` have rank < 1, when `indices` is not
 * int32, when the output shape is empty, or when the shapes are mutually
 * inconsistent (that last check is delegated to `validateUpdateShape`).
 *
 * @param updates The tensor contains the update values.
 * @param indices The tensor contains the indices for the update values.
 * @param shape The shape of the output tensor.
 */
function validateInput(updates, indices, shape) {
    if (indices.rank < 1) {
        throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
            ` but the rank was ${indices.rank}.`);
    }
    if (updates.rank < 1) {
        throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
            ` but the rank was ${updates.rank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
    }
    if (shape.length < 1) {
        throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
    }
    // NOTE: a former `shape.length === 0` branch (checking indices/updates
    // sizes for an empty output) was unreachable dead code, because the
    // `shape.length < 1` check above already throws for empty shapes; it has
    // been removed without any behavior change.
    validateUpdateShape(shape, indices, updates);
}
- /**
- * Calculate the shape information for the output.
- *
- * @param update The tensor contains the update values.
- * @param indices The tensor contains the indices for the update values.
- * @param shape The shape of the output tensor.
- *
- * @returns ScatterShapeInfo
- */
/**
 * Calculate the shape information needed to apply a scatter-nd update.
 *
 * @param updates The tensor containing the update values.
 * @param indices The tensor containing the indices for the update values.
 * @param shape The shape of the output tensor.
 * @returns ScatterShapeInfo
 */
function calculateShapes(updates, indices, shape) {
    // Rank of each index tuple: the last dim of `indices` when it is >1-D,
    // otherwise every element is a scalar index.
    const indicesRank = indices.shape.length;
    const sliceRank = indicesRank > 1 ? indices.shape[indicesRank - 1] : 1;
    // Elements copied per update slice: the product of the trailing output
    // dims not addressed by the indices. Lets the kernel copy whole slices
    // at a time over flattened tensors.
    const sliceSize = shape
        .slice(sliceRank)
        .reduce((product, dim) => product * dim, 1);
    // Guard against a degenerate sliceRank of 0 when dividing below.
    const safeSliceDim = sliceRank < 1 ? 1 : sliceRank;
    const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
    const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
    const outputSize = sizeFromShape(shape);
    return { sliceRank, numUpdates, sliceSize, strides, outputSize };
}
-
// Frozen namespace object bundling the scatter-nd shape helpers above.
var scatter_nd_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    validateUpdateShape: validateUpdateShape,
    validateInput: validateInput,
    calculateShapes: calculateShapes
});
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Asserts that `begin` and `size` describe a slice that fits inside
 * `input`: both have one entry per input axis, and every axis range stays
 * within the input extent.
 */
function assertParamsValid(input, begin, size) {
    const inputRank = input.shape.length;
    assert(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` +
        `match the rank of the array (${inputRank}).`);
    assert(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` +
        `match the rank of the array (${inputRank}).`);
    // Per-axis bound check: begin + size may not overflow the input shape.
    for (let axis = 0; axis < inputRank; ++axis) {
        assert(begin[axis] + size[axis] <= input.shape[axis], () => `Error in slice${inputRank}D: begin[${axis}] + size[${axis}] ` +
            `(${begin[axis] + size[axis]}) would overflow input.shape[${axis}] (${input.shape[axis]})`);
    }
}
- /** Converts a binary mask to an array of axes. Used in stridedSlice(). */
- function maskToAxes(mask) {
- const axes = [];
- let axis = 0;
- while (mask > 0) {
- if (mask & 1) {
- axes.push(axis);
- }
- mask /= 2;
- axis++;
- }
- return axes;
- }
- /** Computes the output shape given the strided slice params. */
- function computeOutShape(begin, end, strides) {
- const size = [];
- for (let axis = 0; axis < begin.length; axis++) {
- size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
- }
- return size;
- }
- // Creates full selection at the elided dimensions. If the dimension matches
- // the ellipsis mask, override the current stride value. Otherwise, insert.
- function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
- const newStrides = [...strides];
- for (let i = newStrides.length; i < inputShape.length; i++) {
- newStrides.push(1);
- }
- for (let i = 0; i < numElidedAxes; i++) {
- if (i === 0) {
- newStrides[ellipsisInsertionIndex] = 1;
- }
- else {
- newStrides.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */);
- newStrides.pop();
- }
- }
- return newStrides;
- }
// Maps an axis in the elided (expanded) coordinate space back to the axis
// of the original, pre-ellipsis spec. Axes at or before the ellipsis are
// unaffected; later axes shift back by the number of inserted dims.
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
    return normalizedAxis <= ellipsisInsertionIndex ?
        normalizedAxis :
        normalizedAxis - (numElidedAxes - 1);
}
// Returns the consecutive run of axes covered by the ellipsis:
// [ellipsisInsertionIndex, ellipsisInsertionIndex + numElidedAxes).
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
    return Array.from({ length: numElidedAxes }, (_, i) => ellipsisInsertionIndex + i);
}
- // Normalize the start, end and strides.
- function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
- const inputRank = inputShape.length;
- let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank);
- if (ellipsisAxes.length && numInterpolatedAxes > 0) {
- const fullIndex = ellipsisAxes[0];
- // The ellipsis applies to the masked index as well as any dimensions
- // that are interpolated.
- const numElidedAxes = numInterpolatedAxes + 1;
- normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
- normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
- normalizedStrides =
- stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
- }
- else {
- for (let axis = 0; axis < inputRank; axis++) {
- normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
- normalizedEnd[axis] =
- stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
- normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
- }
- }
- return {
- begin: normalizedBegin,
- end: normalizedEnd,
- strides: normalizedStrides
- };
- }
- // Creates full selection at the elided dimensions. If the dimension matches
- // the ellipsis mask, override the current start value. Otherwise, insert.
- function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
- const newIndices = [...inputShape];
- const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
- for (let axis = 0; axis < newIndices.length; axis++) {
- if (elidedAxes.indexOf(axis) > -1) {
- newIndices[axis] = 0;
- }
- else {
- const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
- let originalValue = originalBegin[originalAxis];
- if (beginMask & 1 << originalAxis) {
- originalValue = 0;
- }
- newIndices[axis] = originalValue;
- }
- }
- return newIndices;
- }
- // Creates full selection at the elided dimensions. If the dimension matches
- // the ellipsis mask, override the current stop value. Otherwise, insert.
- function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
- const newIndices = [...inputShape];
- const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
- for (let axis = 0; axis < newIndices.length; axis++) {
- if (elidedAxes.indexOf(axis) > -1) {
- newIndices[axis] = Number.MAX_SAFE_INTEGER;
- }
- else {
- const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
- let originalValue = originalEnd[originalAxis];
- if (endMask & 1 << originalAxis) {
- originalValue = Number.MAX_SAFE_INTEGER;
- }
- newIndices[axis] = originalValue;
- }
- }
- for (let i = 0; i < newIndices.length; i++) {
- // Handle negative indices
- const axisSize = inputShape[i];
- if (newIndices[i] < 0) {
- newIndices[i] += axisSize;
- }
- newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
- }
- return newIndices;
- }
// Resolves the effective stride for one axis: ellipsis-masked axes and
// unspecified strides default to 1.
function stridesForAxis(strides, axis, ellipsisMask) {
    const stride = strides[axis];
    if ((ellipsisMask & (1 << axis)) || stride == null) {
        return 1;
    }
    return stride;
}
// Resolves the effective (inclusive) start index for one axis of a strided
// slice, honoring the begin/ellipsis masks and negative indexing.
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
    let start = startIndices[axis];
    const stride = strides[axis] || 1;
    // Masked or unspecified starts pick the extreme end that matches the
    // iteration direction; the clamp below folds it into the valid range.
    // (MIN/MAX_SAFE_INTEGER mirror StopForAxis for symmetry.)
    const ignoreStart = (beginMask & (1 << axis)) || (ellipsisMask & (1 << axis)) || start == null;
    if (ignoreStart) {
        start = stride > 0 ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER;
    }
    // Negative indices count back from the end of the axis.
    const axisSize = inputShape[axis];
    if (start < 0) {
        start += axisSize;
    }
    // Clamp into [0, axisSize - 1]; the start index is inclusive.
    return clamp(0, start, axisSize - 1);
}
// Resolves the effective (exclusive) stop index for one axis of a strided
// slice, honoring the end/ellipsis masks and negative indexing.
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
    let stop = stopIndices[axis];
    const stride = strides[axis] || 1;
    // Masked or unspecified stops pick the extreme end that matches the
    // iteration direction; clamping below folds it into the valid range.
    const ignoreStop = (endMask & (1 << axis)) || (ellipsisMask & (1 << axis)) || stop == null;
    if (ignoreStop) {
        stop = stride > 0 ? Number.MAX_SAFE_INTEGER : Number.MIN_SAFE_INTEGER;
    }
    // Negative indices count back from the end of the axis.
    const axisSize = inputShape[axis];
    if (stop < 0) {
        stop += axisSize;
    }
    // The stop index points one past the last element, so the clamp range
    // depends on direction: forward may reach axisSize, backward may
    // reach -1.
    return stride > 0 ? clamp(0, stop, axisSize) : clamp(-1, stop, axisSize - 1);
}
/**
 * Returns true if the slice occupies a contiguous set of elements in the
 * 'flat' space. (Name keeps the historical "continous" spelling — it is
 * part of the exported interface.)
 */
function isSliceContinous(shape, begin, size) {
    // Index of the first axis with size > 1 (size.length if none).
    let firstNonOneAxis = size.findIndex((s) => s > 1);
    if (firstNonOneAxis === -1) {
        firstNonOneAxis = size.length;
    }
    // Every later axis must be taken whole, starting at 0.
    for (let axis = firstNonOneAxis + 1; axis < size.length; axis++) {
        if (begin[axis] > 0 || size[axis] !== shape[axis]) {
            return false;
        }
    }
    return true;
}
/**
 * Computes the flat-buffer offset of the slice origin `begin` given the
 * per-axis `strides`. The innermost axis contributes its begin index
 * directly; note the historical fallback of 1 (not 0) when `begin` is
 * empty — preserved as-is.
 */
function computeFlatOffset(begin, strides) {
    let offset = begin.length > 0 ? begin[begin.length - 1] : 1;
    for (let axis = 0; axis + 1 < begin.length; axis++) {
        offset += begin[axis] * strides[axis];
    }
    return offset;
}
/**
 * Normalizes the user-facing slice() arguments into full-rank arrays.
 *
 * `begin` may be a number (applies to axis 0, zeros elsewhere) or an array
 * shorter than the rank (padded with zeros). `size` may be null (take
 * everything), a number (axis 0), or a short array (padded with -1); each
 * -1 entry is resolved to "to the end of that axis". Negative begin
 * indices and size values other than -1 are rejected via assert().
 */
function parseSliceParams(x, begin, size) {
    const rank = x.shape.length;
    let begin_;
    if (typeof begin === 'number') {
        begin_ = [begin, ...new Array(rank - 1).fill(0)];
    } else if (begin.length < rank) {
        begin_ = begin.concat(new Array(rank - begin.length).fill(0));
    } else {
        begin_ = begin.slice();
    }
    begin_.forEach(d => {
        assert(d !== -1, () => 'slice() does not support negative begin indexing.');
    });
    let size_;
    if (size == null) {
        size_ = new Array(rank).fill(-1);
    } else if (typeof size === 'number') {
        size_ = [size, ...new Array(rank - 1).fill(-1)];
    } else if (size.length < rank) {
        size_ = size.concat(new Array(rank - size.length).fill(-1));
    } else {
        size_ = size;
    }
    size_ = size_.map((d, i) => {
        if (d >= 0) {
            return d;
        }
        assert(d === -1, () => `Negative size values should be exactly -1 but got ` +
            `${d} for the slice() size at index ${i}.`);
        // -1 means "everything from begin to the end of this axis".
        return x.shape[i] - begin_[i];
    });
    return [begin_, size_];
}
-
// Frozen namespace object exported as `slice_util`; aggregates the slice
// helper functions defined above.
var slice_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    assertParamsValid: assertParamsValid,
    maskToAxes: maskToAxes,
    computeOutShape: computeOutShape,
    stridesWithElidedDims: stridesWithElidedDims,
    getNormalizedAxes: getNormalizedAxes,
    startIndicesWithElidedDims: startIndicesWithElidedDims,
    stopIndicesWithElidedDims: stopIndicesWithElidedDims,
    stridesForAxis: stridesForAxis,
    startForAxis: startForAxis,
    stopForAxis: stopForAxis,
    isSliceContinous: isSliceContinous,
    computeFlatOffset: computeFlatOffset,
    parseSliceParams: parseSliceParams
});
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Serializable defines the serialization contract.
 *
 * TFJS requires serializable classes to return their className when asked
 * to avoid issues with minification.
 */
class Serializable {
    /**
     * Returns the class name used in serialization contexts.
     *
     * Reads the static `className` property rather than `constructor.name`
     * so the value is robust against minification; some non-leaf classes
     * (e.g. initializers.VarianceScaling) also rely on this indirection.
     */
    getClassName() {
        const ctor = this.constructor;
        return ctor.className;
    }
    /**
     * Creates an instance of T from a ConfigDict. Works for most
     * descendants of Serializable; a few need special handling.
     * @param cls A Constructor for the class to instantiate.
     * @param config The Configuration for the object.
     */
    /** @nocollapse */
    static fromConfig(cls, config) {
        return new cls(config);
    }
}
/**
 * Singleton registry mapping string class names to
 * [constructor, fromConfig] pairs.
 *
 * Used during (de)serialization from the cross-language JSON format, which
 * requires the class name in the serialization format to match the class
 * names as used in Python, should it exist.
 */
class SerializationMap {
    constructor() {
        // className -> [constructor, fromConfig factory]
        this.classNameMap = {};
    }
    /** Lazily creates and returns the singleton instance of the map. */
    static getMap() {
        if (SerializationMap.instance == null) {
            SerializationMap.instance = new SerializationMap();
        }
        return SerializationMap.instance;
    }
    /** Registers `cls` as serializable under its static `className`. */
    static register(cls) {
        const map = SerializationMap.getMap().classNameMap;
        map[cls.className] = [cls, cls.fromConfig];
    }
}
/**
 * Register a class with the serialization map of TensorFlow.js.
 *
 * This is often used for registering custom Layers, so they can be
 * serialized and deserialized.
 *
 * Example:
 *
 * ```js
 * class MyCustomLayer extends tf.layers.Layer {
 *   static className = 'MyCustomLayer';
 *
 *   constructor(config) {
 *     super(config);
 *   }
 * }
 * tf.serialization.registerClass(MyCustomLayer);
 * ```
 *
 * @param cls The class to be registered. It must have a public static member
 *   called `className` defined and the value must be a non-empty string.
 *
 * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
 */
function registerClass(cls) {
    const { className } = cls;
    // Validate the static className before handing the class to the map.
    assert(className != null, () => `Class being registered does not have the static className ` +
        `property defined.`);
    assert(typeof className === 'string', () => `className is required to be a string, but got type ` +
        typeof className);
    assert(className.length > 0, () => `Class being registered has an empty-string as its className, ` +
        `which is disallowed.`);
    SerializationMap.register(cls);
}
-
// Frozen namespace object exported as `serialization`; aggregates the
// serialization API defined above.
var serialization = /*#__PURE__*/Object.freeze({
    __proto__: null,
    Serializable: Serializable,
    SerializationMap: SerializationMap,
    registerClass: registerClass
});
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
// Default comparison tolerances used by testEpsilon() for backends that
// report 32-bit and 16-bit float precision respectively.
const TEST_EPSILON_FLOAT32 = 1e-3;
const TEST_EPSILON_FLOAT16 = 1e-1;
/**
 * Asserts `actual` and `expected` are element-wise close within `epsilon`.
 * When no epsilon is given, the default depends on the active backend's
 * float precision (see testEpsilon()).
 */
function expectArraysClose(actual, expected, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    return expectArraysPredicate(actual, expected, (a, b) => areClose(a, b, eps));
}
/**
 * Picks the comparison tolerance from the active backend: the tight
 * float32 epsilon for 32-bit backends, the loose float16 one otherwise.
 */
function testEpsilon() {
    const precision = ENGINE.backend.floatPrecision();
    return precision === 32 ? TEST_EPSILON_FLOAT32 : TEST_EPSILON_FLOAT16;
}
/**
 * Core comparison driver for the expect* helpers: checks constructor type,
 * shape, length, then every flattened element with `predicate`, throwing a
 * descriptive Error on the first mismatch.
 */
function expectArraysPredicate(actual, expected, predicate) {
    // Enforce matching constructor types unless exactly one side is a
    // typed array (mixing a typed array with a plain array is allowed).
    const actualTyped = isTypedArray(actual);
    const expectedTyped = isTypedArray(expected);
    if (actualTyped === expectedTyped) {
        const aType = actual.constructor.name;
        const bType = expected.constructor.name;
        if (aType !== bType) {
            throw new Error(`Arrays are of different type. Actual: ${aType}. ` +
                `Expected: ${bType}`);
        }
    }
    if (Array.isArray(actual) && Array.isArray(expected)) {
        const actualShape = inferShape(actual);
        const expectedShape = inferShape(expected);
        if (!arraysEqual(actualShape, expectedShape)) {
            throw new Error(`Arrays have different shapes. ` +
                `Actual: [${actualShape}]. Expected: [${expectedShape}]`);
        }
    }
    const actualFlat = actualTyped ? actual : flatten(actual);
    const expectedFlat = expectedTyped ? expected : flatten(expected);
    if (actualFlat.length !== expectedFlat.length) {
        throw new Error(`Arrays have different lengths actual: ${actualFlat.length} vs ` +
            `expected: ${expectedFlat.length}.\n` +
            `Actual: ${actualFlat}.\n` +
            `Expected: ${expectedFlat}.`);
    }
    for (let i = 0; i < expectedFlat.length; ++i) {
        const a = actualFlat[i];
        const e = expectedFlat[i];
        if (!predicate(a, e)) {
            throw new Error(`Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.\n` +
                `Actual: ${actualFlat}.\n` +
                `Expected: ${expectedFlat}.`);
        }
    }
}
/**
 * Asserts the promise returned by `fn` rejects: resolution reports a test
 * failure via done.fail(), rejection completes the test via done().
 */
function expectPromiseToFail(fn, done) {
    const onResolved = () => done.fail();
    const onRejected = () => done();
    fn().then(onResolved, onRejected);
}
/**
 * Asserts element-wise equality between `actual` and `expected`.
 *
 * Scalar expectations (string/number/boolean) are wrapped in a one-element
 * array so they can be compared against array-valued actuals. String data
 * is compared with loose equality; numeric data with exact closeness
 * (epsilon = 0).
 */
function expectArraysEqual(actual, expected) {
    // Wrap scalar expectations so both branches compare array-likes.
    const exp = typeof expected === 'string' || typeof expected === 'number' ||
        typeof expected === 'boolean' ?
        [expected] :
        expected;
    if (isString(actual) || isString(actual[0]) ||
        isString(expected) || isString(expected[0])) {
        // tslint:disable-next-line: triple-equals
        return expectArraysPredicate(actual, exp, (a, b) => a == b);
    }
    // Bug fix: use the scalar-wrapped `exp` here too. The original passed
    // the raw `expected`, so a numeric scalar expectation hit the
    // constructor-type mismatch inside expectArraysPredicate and threw.
    return expectArraysPredicate(actual, exp, (a, b) => areClose(a, b, 0));
}
/**
 * Asserts two numbers are within `epsilon` of each other, defaulting the
 * tolerance to the backend-dependent testEpsilon().
 */
function expectNumbersClose(a, e, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    if (!areClose(a, e, eps)) {
        throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);
    }
}
/**
 * Returns whether `a` and `e` differ by at most `epsilon`. Two non-finite
 * values (including NaN vs NaN, or ±Infinity pairs) count as close; a
 * single NaN or non-finite side does not.
 */
function areClose(a, e, epsilon) {
    if (!isFinite(a) && !isFinite(e)) {
        return true;
    }
    if (isNaN(a) || isNaN(e)) {
        return false;
    }
    return Math.abs(a - e) <= epsilon;
}
/**
 * Asserts every element of `actual` lies within [low, high], throwing on
 * the first out-of-range value.
 */
function expectValuesInRange(actual, low, high) {
    for (const v of actual) {
        if (v < low || v > high) {
            throw new Error(`Value out of range:${v} low: ${low}, high: ${high}`);
        }
    }
}
/**
 * Compares two ArrayBuffers for equality. Safari & Jasmine don't like
 * comparing ArrayBuffers directly, so both are viewed as Float32Arrays.
 */
function expectArrayBuffersEqual(actual, expected) {
    const actualView = new Float32Array(actual);
    const expectedView = new Float32Array(expected);
    expect(actualView).toEqual(expectedView);
}
-
// Frozen namespace object exported as `test_util`.
// NOTE(review): TEST_EPSILON_FLOAT32 is defined above but not exported
// here — confirm whether that omission is intentional.
var test_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
    expectArraysClose: expectArraysClose,
    testEpsilon: testEpsilon,
    expectPromiseToFail: expectPromiseToFail,
    expectArraysEqual: expectArraysEqual,
    expectNumbersClose: expectNumbersClose,
    expectValuesInRange: expectValuesInRange,
    expectArrayBuffersEqual: expectArrayBuffersEqual
});
-
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Placeholder version string — presumably stamped with the real package
// version by the build; confirm before relying on it.
const version = '0.0.0';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Enables production mode which disables correctness checks in favor of
- * performance.
- *
- * @doc {heading: 'Environment'}
- */
function enableProdMode() {
    // Sets the global 'PROD' environment flag.
    env().set('PROD', true);
}
- /**
- * Enables debug mode which will log information about all executed kernels:
- * the elapsed time of the kernel execution, as well as the rank, shape, and
- * size of the output tensor.
- *
- * Debug mode will significantly slow down your application as it will
- * download the result of every operation to the CPU. This should not be used in
- * production. Debug mode does not affect the timing information of the kernel
- * execution as we do not measure download time in the kernel execution time.
- *
- * See also: `tf.profile`, `tf.memory`.
- *
- * @doc {heading: 'Environment'}
- */
function enableDebugMode() {
    // Sets the global 'DEBUG' environment flag.
    env().set('DEBUG', true);
}
/** Globally disables deprecation warnings */
function disableDeprecationWarnings() {
    // Clears the flag read by deprecationWarn(), then logs once.
    env().set('DEPRECATION_WARNINGS_ENABLED', false);
    console.warn(`TensorFlow.js deprecation warnings have been disabled.`);
}
/** Warn users about deprecated functionality. */
function deprecationWarn(msg) {
    // No-op when warnings were disabled via disableDeprecationWarnings().
    if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
        console.warn(msg + ' You can disable deprecation warnings with ' +
            'tf.disableDeprecationWarnings().');
    }
}
// Register deprecationWarn as the deprecation handler via
// setDeprecationWarningFn (defined elsewhere in this bundle).
setDeprecationWarningFn(deprecationWarn);
- /**
- * Dispose all variables kept in backend engine.
- *
- * @doc {heading: 'Environment'}
- */
function disposeVariables() {
    // Delegates to the global engine.
    ENGINE.disposeVariables();
}
- /**
- * It returns the global engine that keeps track of all tensors and backends.
- *
- * @doc {heading: 'Environment'}
- */
function engine() {
    // Exposes the global engine singleton.
    return ENGINE;
}
- /**
- * Returns memory info at the current time in the program. The result is an
- * object with the following properties:
- *
- * - `numBytes`: Number of bytes allocated (undisposed) at this time.
- * - `numTensors`: Number of unique tensors allocated.
- * - `numDataBuffers`: Number of unique data buffers allocated
- * (undisposed) at this time, which is ≤ the number of tensors
- * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
- * data buffer with `a`).
- * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
- * `unreliable` is true.
- * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
- * `unreliable` is true.
- *
- * WebGL Properties:
- * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
- * this time.
- *
- * @doc {heading: 'Performance', subheading: 'Memory'}
- */
function memory() {
    // Delegates to the global engine's memory accounting.
    return ENGINE.memory();
}
- /**
- * Executes the provided function `f()` and returns a promise that resolves
- * with information about the function's memory use:
- * - `newBytes`: the number of new bytes allocated
- * - `newTensors`: the number of new tensors created
- * - `peakBytes`: the peak number of bytes allocated
- * - `kernels`: an array of objects for each kernel involved that reports
- * their input and output shapes, number of bytes used, and number of new
- * tensors created.
- *
- * ```js
- * const profile = await tf.profile(() => {
- * const x = tf.tensor1d([1, 2, 3]);
- * let x2 = x.square();
- * x2.dispose();
- * x2 = x.square();
- * x2.dispose();
- * return x;
- * });
- *
- * console.log(`newBytes: ${profile.newBytes}`);
- * console.log(`newTensors: ${profile.newTensors}`);
- * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
- * k.totalBytesSnapshot)}`);
- * ```
- *
- *
- * @doc {heading: 'Performance', subheading: 'Profile'}
- */
function profile(f) {
    // Delegates profiling of `f` to the global engine.
    return ENGINE.profile(f);
}
- /**
- * Executes the provided function `fn` and after it is executed, cleans up all
- * intermediate tensors allocated by `fn` except those returned by `fn`.
- * `fn` must not return a Promise (async functions not allowed). The returned
- * result can be a complex object.
- *
- * Using this method helps avoid memory leaks. In general, wrap calls to
- * operations in `tf.tidy` for automatic memory cleanup.
- *
- * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
- * dispose variables, please use `tf.disposeVariables` or call dispose()
- * directly on variables.
- *
- * ```js
- * // y = 2 ^ 2 + 1
- * const y = tf.tidy(() => {
- * // a, b, and one will be cleaned up when the tidy ends.
- * const one = tf.scalar(1);
- * const a = tf.scalar(2);
- * const b = a.square();
- *
- * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
- *
- * // The value returned inside the tidy function will return
- * // through the tidy, in this case to the variable y.
- * return b.add(one);
- * });
- *
- * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
- * y.print();
- * ```
- *
- * @param nameOrFn The name of the closure, or the function to execute.
- * If a name is provided, the 2nd argument should be the function.
- * If debug mode is on, the timing and the memory usage of the function
- * will be tracked and displayed on the console using the provided name.
- * @param fn The function to execute.
- *
- * @doc {heading: 'Performance', subheading: 'Memory'}
- */
function tidy(nameOrFn, fn) {
    // Delegates scoping and intermediate-tensor cleanup to the engine.
    return ENGINE.tidy(nameOrFn, fn);
}
- /**
- * Disposes any `tf.Tensor`s found within the provided object.
- *
- * @param container an object that may be a `tf.Tensor` or may directly
- * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
- * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
- * happens. In general it is safe to pass any object here, except that
- * `Promise`s are not supported.
- *
- * @doc {heading: 'Performance', subheading: 'Memory'}
- */
/**
 * Disposes every tf.Tensor found inside `container` (a tensor, an array of
 * tensors, or an object holding tensors); non-tensor content is ignored by
 * getTensorsInContainer.
 */
function dispose(container) {
    for (const tensor of getTensorsInContainer(container)) {
        tensor.dispose();
    }
}
- /**
- * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
- * automatically.
- *
- * ```js
- * let b;
- * const y = tf.tidy(() => {
- * const one = tf.scalar(1);
- * const a = tf.scalar(2);
- *
- * // b will not be cleaned up by the tidy. a and one will be cleaned up
- * // when the tidy ends.
- * b = tf.keep(a.square());
- *
- * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
- *
- * // The value returned inside the tidy function will return
- * // through the tidy, in this case to the variable y.
- * return b.add(one);
- * });
- *
- * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
- * console.log('y:');
- * y.print();
- * console.log('b:');
- * b.print();
- * ```
- *
- * @param result The tensor to keep from being disposed.
- *
- * @doc {heading: 'Performance', subheading: 'Memory'}
- */
function keep(result) {
    // Marks `result` as kept so the enclosing tidy() won't dispose it.
    return ENGINE.keep(result);
}
- /**
- * Executes `f()` and returns a promise that resolves with timing
- * information.
- *
- * The result is an object with the following properties:
- *
- * - `wallMs`: Wall execution time.
- * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
- * WebGL backend and the query timer extension is not available, this will
- * return an error object.
- * - On `WebGL` The following additional properties exist:
- * - `uploadWaitMs`: CPU blocking time on texture uploads.
- * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
- *
- * ```js
- * const x = tf.randomNormal([20, 20]);
- * const time = await tf.time(() => x.matMul(x));
- *
- * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
- * ```
- *
- * @param f The function to execute and time.
- *
- * @doc {heading: 'Performance', subheading: 'Timing'}
- */
function time(f) {
    // Delegates timing of `f` to the global engine.
    return ENGINE.time(f);
}
- /**
- * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
- * executing operations on those tensors. Returns a promise that resolves
- * to a boolean if the backend initialization was successful.
- *
- * Note this disposes the current backend, if any, as well as any tensors
- * associated with it. A new backend is initialized, even if it is of the
- * same type as the previous one.
- *
- * @param backendName The name of the backend. Currently supports
- * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
- * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
- *
- * @doc {heading: 'Backends'}
- */
function setBackend(backendName) {
    // Returns a promise resolving to whether initialization succeeded.
    return ENGINE.setBackend(backendName);
}
- /**
- * Returns a promise that resolves when the currently selected backend (or the
- * highest priority one) has initialized. Await this promise when you are using
- * a backend that has async initialization.
- *
- * @doc {heading: 'Backends'}
- */
function ready() {
    // Promise resolving when the selected backend has initialized.
    return ENGINE.ready();
}
- /**
- * Returns the current backend name (cpu, webgl, etc). The backend is
- * responsible for creating tensors and executing operations on those tensors.
- *
- * @doc {heading: 'Backends'}
- */
function getBackend() {
    // Current backend name, e.g. 'cpu' or 'webgl'.
    return ENGINE.backendName;
}
- /**
- * Removes a backend and the registered factory.
- *
- * @doc {heading: 'Backends'}
- */
function removeBackend(name) {
    // Removes both the backend instance and its registered factory.
    ENGINE.removeBackend(name);
}
- /**
- * Finds the backend registered under the provided name. Returns null if the
- * name is not in the registry, or the registration hasn't finished yet.
- */
function findBackend(name) {
    // Null when unknown or when registration hasn't finished yet.
    return ENGINE.findBackend(name);
}
- /**
- * Finds the backend factory registered under the provided name. Returns a
- * function that produces a new backend when called. Returns null if the name
- * is not in the registry.
- */
function findBackendFactory(name) {
    // Returns the factory function registered under `name`, or null.
    return ENGINE.findBackendFactory(name);
}
- /**
- * Registers a global backend. The registration should happen when importing
- * a module file (e.g. when importing `backend_webgl.ts`), and is used for
- * modular builds (e.g. custom tfjs bundle with only webgl support).
- *
- * @param factory The backend factory function. When called, it should
- * return a backend instance, or a promise of an instance.
- * @param priority The priority of the backend (higher = more important).
- * In case multiple backends are registered, the priority is used to find
- * the best backend. Defaults to 1.
- * @return False if there is already a registered backend under this name, true
- * if not.
- *
- * @doc {heading: 'Backends'}
- */
function registerBackend(name, factory, priority = 1) {
    // Delegates to the engine's registry; false if `name` already exists.
    return ENGINE.registerBackend(name, factory, priority);
}
- /**
- * Gets the current backend. If no backends have been initialized, this will
- * attempt to initialize the best backend. Will throw an error if the highest
- * priority backend has async initialization, in which case, you should call
- * 'await tf.ready()' before running other code.
- *
- * @doc {heading: 'Backends'}
- */
function backend() {
    // Accessor for the engine's current backend instance.
    return ENGINE.backend;
}
- /**
- * Sets the global platform.
- *
- * @param platformName The name of this platform.
- * @param platform A platform implementation.
- */
function setPlatform(platformName, platform) {
    // Installs the platform implementation on the global environment.
    env().setPlatform(platformName, platform);
}
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
- *
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.tensor1d([10, 20, 30, 40]);
- *
- * a.add(b).print(); // or tf.add(a, b)
- * ```
- *
- * ```js
- * // Broadcast add a with b.
- * const a = tf.scalar(5);
- * const b = tf.tensor1d([10, 20, 30, 40]);
- *
- * a.add(b).print(); // or tf.add(a, b)
- * ```
- * @param a The first `tf.Tensor` to add.
- * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
/**
 * Element-wise A + B with broadcasting. Converts both operands to tensors,
 * unifies their dtypes, then dispatches the Add kernel through the engine.
 */
function add_(a, b) {
    const [$a, $b] = makeTypesMatch(convertToTensor(a, 'a', 'add'), convertToTensor(b, 'b', 'add'));
    const forwardFunc = (backend, save) => {
        const sum = backend.add($a, $b);
        save([$a, $b]);
        return sum;
    };
    return ENGINE.runKernelFunc(forwardFunc, { a: $a, b: $b }, null /* gradient */, Add);
}
// Public `add` op: wraps add_ via the op() factory (defined elsewhere in
// this bundle).
const add$1 = op({ add_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- * The result is rounded with floor function.
- *
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.floorDiv(b).print(); // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
/**
 * Element-wise floor(A / B) with broadcasting. Converts both operands to
 * tensors, unifies their dtypes, then dispatches the FloorDiv kernel.
 */
function floorDiv_(a, b) {
    const [$a, $b] = makeTypesMatch(convertToTensor(a, 'a', 'floorDiv'), convertToTensor(b, 'b', 'floorDiv'));
    const forwardFunc = (backend, save) => {
        const quotient = backend.floorDiv($a, $b);
        save([$a, $b]);
        return quotient;
    };
    return ENGINE.runKernelFunc(forwardFunc, { a: $a, b: $b }, null /* gradient */, FloorDiv);
}
// Public `floorDiv` op: wraps floorDiv_ via the op() factory.
const floorDiv = op({ floorDiv_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.div(b).print(); // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.div(b).print(); // or tf.div(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
/**
 * Element-wise A / B with broadcasting. int32/int32 division is floor
 * division (delegates to floorDiv); everything else runs the real-divide
 * Div kernel.
 */
function div_(a, b) {
    const [$a, $b] = makeTypesMatch(convertToTensor(a, 'a', 'div'), convertToTensor(b, 'b', 'div'));
    if ($a.dtype === 'int32' && $b.dtype === 'int32') {
        return floorDiv($a, $b);
    }
    const forwardFunc = (backend, save) => {
        const quotient = backend.realDivide($a, $b);
        save([$a, $b]);
        return quotient;
    };
    return ENGINE.runKernelFunc(forwardFunc, { a: $a, b: $b }, null /* gradient */, Div, {} /* attrs */);
}
// Public `div` op: wraps div_ via the op() factory.
const div = op({ div_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
- *
- * We also expose `tf.mulStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.tensor1d([2, 3, 4, 5]);
- *
- * a.mul(b).print(); // or tf.mul(a, b)
- * ```
- *
- * ```js
- * // Broadcast mul a with b.
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.scalar(5);
- *
- * a.mul(b).print(); // or tf.mul(a, b)
- * ```
- * @param a The first tensor to multiply.
- * @param b The second tensor to multiply. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
/**
 * Element-wise A * B with broadcasting. Converts both operands to tensors,
 * unifies their dtypes, then dispatches the Multiply kernel.
 */
function mul_(a, b) {
    const [$a, $b] = makeTypesMatch(convertToTensor(a, 'a', 'mul'), convertToTensor(b, 'b', 'mul'));
    const forwardFunc = (backend, save) => {
        const product = backend.multiply($a, $b);
        save([$a, $b]);
        return product;
    };
    return ENGINE.runKernelFunc(forwardFunc, { a: $a, b: $b }, null /* gradient */, Multiply);
}
// Public `mul` op: wraps mul_ via the op() factory.
const mul = op({ mul_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes absolute value element-wise: `abs(x)`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.abs().print(); // or tf.abs(x)
- * ```
- * @param x The input `tf.Tensor`.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
/**
 * Element-wise absolute value. complex64 inputs use the dedicated
 * complexAbs kernel; all other dtypes use abs.
 */
function abs_(x) {
    const $x = convertToTensor(x, 'x', 'abs');
    const forwardFunc = (backend, save) => {
        save([$x]);
        return $x.dtype === 'complex64' ? backend.complexAbs($x) : backend.abs($x);
    };
    return ENGINE.runKernelFunc(forwardFunc, { x: $x }, null /* grad */, Abs);
}
- const abs = op({ abs_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.acos().print(); // or tf.acos(x)
- * ```
- * @param x The input tensor.
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function acos_(x) {
- const $x = convertToTensor(x, 'x', 'acos');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.acos($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Acos);
- }
- const acos = op({ acos_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
- * `acosh(x)`
- *
- * ```js
- * const x = tf.tensor1d([10, 1, 3, 5.7]);
- *
- * x.acosh().print(); // or tf.acosh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function acosh_(x) {
- const $x = convertToTensor(x, 'x', 'acosh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.acosh($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Acosh);
- }
- const acosh = op({ acosh_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * const c = tf.tensor1d([5, 6]);
- *
- * tf.addN([a, b, c]).print();
- * ```
- * @param tensors A list of tensors with the same shape and dtype.
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function addN_(tensors) {
- assert(Array.isArray(tensors), () => 'The argument passed to tf.addN() must be a list of tensors');
- assert(tensors.length >= 1, () => `Must pass at least one tensor to tf.addN(), but got ` +
- `${tensors.length}`);
- const $tensors = tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'addN'));
- const firstTensor = $tensors[0];
- $tensors.forEach(t => {
- if (t.dtype !== firstTensor.dtype) {
- throw new Error('All tensors passed to tf.addN() must have the same dtype');
- }
- });
- $tensors.forEach(t => {
- if (!arraysEqual(t.shape, firstTensor.shape)) {
- throw new Error('All tensors passed to tf.addN() must have the same shape');
- }
- });
- const forward = (backend, save) => {
- const res = backend.addN($tensors);
- save($tensors);
- return res;
- };
- const inputs = $tensors;
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, AddN);
- }
- const addN = op({ addN_ });
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns true if the axis specifies the inner most dimensions of the
- * array.
- */
- function axesAreInnerMostDims(axes, rank) {
- for (let i = 0; i < axes.length; ++i) {
- if (axes[axes.length - i - 1] !== rank - 1 - i) {
- return false;
- }
- }
- return true;
- }
- function combineLocations(outputLoc, reduceLoc, axes) {
- const rank = outputLoc.length + reduceLoc.length;
- const loc = [];
- let outIdx = 0;
- let reduceIdx = 0;
- for (let dim = 0; dim < rank; dim++) {
- if (axes.indexOf(dim) === -1) {
- loc.push(outputLoc[outIdx++]);
- }
- else {
- loc.push(reduceLoc[reduceIdx++]);
- }
- }
- return loc;
- }
- function computeOutAndReduceShapes(aShape, axes) {
- const outShape = [];
- const rank = aShape.length;
- for (let dim = 0; dim < rank; dim++) {
- if (axes.indexOf(dim) === -1) {
- outShape.push(aShape[dim]);
- }
- }
- const reduceShape = axes.map(dim => aShape[dim]);
- return [outShape, reduceShape];
- }
- function expandShapeToKeepDim(shape, axes) {
- const reduceSubShape = axes.map(x => 1);
- return combineLocations(shape, reduceSubShape, axes);
- }
- function assertAxesAreInnerMostDims(msg, axes, rank) {
- assert(axesAreInnerMostDims(axes, rank), () => `${msg} supports only inner-most axes for now. ` +
- `Got axes ${axes} and rank-${rank} input.`);
- }
- /**
- * Returns the axes permutation to be used with `tf.transpose`, if such
- * permutation is necessary. Otherwise it returns null. This method is used by
- * operations that operate only on inner-most axes.
- */
- function getAxesPermutation(axes, rank) {
- if (axesAreInnerMostDims(axes, rank)) {
- return null;
- }
- const result = [];
- for (let i = 0; i < rank; ++i) {
- if (axes.indexOf(i) === -1) {
- result.push(i);
- }
- }
- axes.forEach(axis => result.push(axis));
- return result;
- }
- /** Returns the axes permutation that undoes the original permutation. */
- function getUndoAxesPermutation(axes) {
- return axes.map((axis, i) => [i, axis])
- .sort((a, b) => a[1] - b[1])
- .map(x => x[0]);
- }
- function getInnerMostAxes(numAxes, rank) {
- const res = [];
- for (let i = rank - numAxes; i < rank; ++i) {
- res.push(i);
- }
- return res;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the logical and of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If `axes` has no entries, all dimensions are reduced, and an
- * `tf.Tensor` with a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 1, 1], 'bool');
- *
- * x.all().print(); // or tf.all(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
- *
- * const axis = 1;
- * x.all(axis).print(); // or tf.all(x, axis)
- * ```
- *
- * @param x The input tensor. Must be of dtype bool.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function all_(x, axis = null, keepDims = false) {
- let $x = convertToTensor(x, 'x', 'all', 'bool');
- const forward = (backend) => {
- const origAxes = parseAxisParam(axis, $x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- if (permutedAxes != null) {
- $x = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, $x.rank);
- }
- const res = backend.all($x, axes);
- if (keepDims) {
- const newShape = expandShapeToKeepDim(res.shape, origAxes);
- return reshape(res, newShape);
- }
- return res;
- };
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, All, attrs);
- }
- const all = op({ all_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the logical or of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If `axes` has no entries, all dimensions are reduced, and an
- * `tf.Tensor` with a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 1, 1], 'bool');
- *
- * x.any().print(); // or tf.any(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
- *
- * const axis = 1;
- * x.any(axis).print(); // or tf.any(x, axis)
- * ```
- *
- * @param x The input tensor. Must be of dtype bool.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function any_(x, axis = null, keepDims = false) {
- let $x = convertToTensor(x, 'x', 'any', 'bool');
- const forward = (backend) => {
- const origAxes = parseAxisParam(axis, $x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- if (permutedAxes != null) {
- $x = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, $x.rank);
- }
- const res = backend.any($x, axes);
- if (keepDims) {
- const newShape = expandShapeToKeepDim(res.shape, origAxes);
- return reshape(res, newShape);
- }
- return res;
- };
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Any, attrs);
- }
- // tslint:disable-next-line:variable-name
- const any = op({ any_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the indices of the maximum values along an `axis`.
- *
- * The result has the same shape as `input` with the dimension along `axis`
- * removed.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.argMax().print(); // or tf.argMax(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
- *
- * const axis = 1;
- * x.argMax(axis).print(); // or tf.argMax(x, axis)
- * ```
- *
- * @param x The input tensor.
- * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function argMax_(x, axis = 0) {
- let $x = convertToTensor(x, 'x', 'argMax');
- const forward = (backend, save) => {
- save([$x]);
- let axes = parseAxisParam(axis, $x.shape);
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- if (permutedAxes != null) {
- $x = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, $x.rank);
- }
- return backend.argMax($x, axes[0]);
- };
- const inputs = { x: $x };
- const attrs = { axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, ArgMax, attrs);
- }
- const argMax = op({ argMax_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the indices of the minimum values along an `axis`.
- *
- * The result has the same shape as `input` with the dimension along `axis`
- * removed.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.argMin().print(); // or tf.argMin(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
- *
- * const axis = 1;
- * x.argMin(axis).print(); // or tf.argMin(x, axis)
- * ```
- *
- * @param x The input tensor.
- * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function argMin_(x, axis = 0) {
- let $x = convertToTensor(x, 'x', 'argMin');
- const forward = (backend, save) => {
- save([$x]);
- if (axis == null) {
- axis = 0;
- }
- let axes = parseAxisParam(axis, $x.shape);
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- if (permutedAxes != null) {
- $x = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, $x.rank);
- }
- return backend.argMin($x, axes[0]);
- };
- const inputs = { x: $x };
- const attrs = { axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, ArgMin, attrs);
- }
- const argMin = op({ argMin_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.asin().print(); // or tf.asin(x)
- * ```
- * @param x The input tensor.
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function asin_(x) {
- const $x = convertToTensor(x, 'x', 'asin');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.asin($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Asin);
- }
- const asin = op({ asin_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
- * `asinh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.asinh().print(); // or tf.asinh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function asinh_(x) {
- const $x = convertToTensor(x, 'x', 'asinh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.asinh($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Asinh);
- }
- const asinh = op({ asinh_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.atan().print(); // or tf.atan(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function atan_(x) {
- const $x = convertToTensor(x, 'x', 'atan');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.atan($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Atan);
- }
- const atan = op({ atan_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
- * Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
- * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
- *
- * tf.atan2(a, b).print()
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function atan2_(a, b) {
- let $a = convertToTensor(a, 'a', 'atan2');
- let $b = convertToTensor(b, 'b', 'atan2');
- [$a, $b] = makeTypesMatch($a, $b);
- const forward = (backend, save) => {
- const res = backend.atan2($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Atan2);
- }
- const atan2 = op({ atan2_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
- * `atanh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, .1, -.1, .7]);
- *
- * x.atanh().print(); // or tf.atanh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function atanh_(x) {
- const $x = convertToTensor(x, 'x', 'atanh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.atanh($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Atanh);
- }
- const atanh = op({ atanh_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- *
- * @param inputShape Input tensor shape is of the following dimensions:
- * `[batch, height, width, inChannels]`.
- * @param filterShape The filter shape is of the following dimensions:
- * `[filterHeight, filterWidth, depth]`.
- * @param strides The strides of the sliding window for each dimension of the
- * input tensor: `[strideHeight, strideWidth]`.
- * If `strides` is a single number,
- * then `strideHeight == strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1*1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat The data format of the input and output data.
- * Defaults to 'NHWC'.
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
- * Defaults to `[1, 1]`. If `dilations` is a single number, then
- * `dilationHeight == dilationWidth`.
- */
- function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) {
- // `computerConv2DInfo` require filterShape to be in the dimension of:
- // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have
- // outDepth, it should have the same depth as the input.
- // Input shape: [batch, height, width, inChannels]
- const inputChannels = inputShape[3];
- const $filterShape = [...filterShape, inputChannels];
- const $dataFormat = convertConv2DDataFormat(dataFormat);
- return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null /* roundingMode */, null /* depthWise */, $dataFormat);
- }
- function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') {
- const [filterHeight, filterWidth] = parseTupleParam(filterSize);
- let filterShape;
- if (dataFormat === 'channelsLast') {
- filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
- }
- else if (dataFormat === 'channelsFirst') {
- filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
- }
- else {
- throw new Error(`Unknown dataFormat ${dataFormat}`);
- }
- return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
- }
- /**
- * Computes the information for a forward pass of a pooling3D operation.
- */
- function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') {
- const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);
- let filterShape;
- let $dataFormat;
- if (dataFormat === 'NDHWC') {
- $dataFormat = 'channelsLast';
- filterShape =
- [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
- }
- else if (dataFormat === 'NCDHW') {
- $dataFormat = 'channelsFirst';
- filterShape =
- [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
- }
- else {
- throw new Error(`Unknown dataFormat ${dataFormat}`);
- }
- return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
- }
- /**
- * Computes the information for a forward pass of a convolution/pooling
- * operation.
- */
- function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') {
- let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];
- if (dataFormat === 'channelsLast') {
- [batchSize, inHeight, inWidth, inChannels] = inShape;
- }
- else if (dataFormat === 'channelsFirst') {
- [batchSize, inChannels, inHeight, inWidth] = inShape;
- }
- else {
- throw new Error(`Unknown dataFormat ${dataFormat}`);
- }
- const [filterHeight, filterWidth, , filterChannels] = filterShape;
- const [strideHeight, strideWidth] = parseTupleParam(strides);
- const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
- const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
- const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
- const { padInfo, outHeight, outWidth } = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
- const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
- let outShape;
- if (dataFormat === 'channelsFirst') {
- outShape = [batchSize, outChannels, outHeight, outWidth];
- }
- else if (dataFormat === 'channelsLast') {
- outShape = [batchSize, outHeight, outWidth, outChannels];
- }
- return {
- batchSize,
- dataFormat,
- inHeight,
- inWidth,
- inChannels,
- outHeight,
- outWidth,
- outChannels,
- padInfo,
- strideHeight,
- strideWidth,
- filterHeight,
- filterWidth,
- effectiveFilterHeight,
- effectiveFilterWidth,
- dilationHeight,
- dilationWidth,
- inShape,
- outShape,
- filterShape
- };
- }
- /**
- * Computes the information for a forward pass of a 3D convolution/pooling
- * operation.
- */
- function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) {
- let [batchSize, inDepth, inHeight, inWidth, inChannels] = [-1, -1, -1, -1, -1];
- if (dataFormat === 'channelsLast') {
- [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
- }
- else if (dataFormat === 'channelsFirst') {
- [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
- }
- else {
- throw new Error(`Unknown dataFormat ${dataFormat}`);
- }
- const [filterDepth, filterHeight, filterWidth, , filterChannels] = filterShape;
- const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
- const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations);
- const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
- const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
- const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
- const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
- const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
- let outShape;
- if (dataFormat === 'channelsFirst') {
- outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
- }
- else if (dataFormat === 'channelsLast') {
- outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
- }
- return {
- batchSize,
- dataFormat,
- inDepth,
- inHeight,
- inWidth,
- inChannels,
- outDepth,
- outHeight,
- outWidth,
- outChannels,
- padInfo,
- strideDepth,
- strideHeight,
- strideWidth,
- filterDepth,
- filterHeight,
- filterWidth,
- effectiveFilterDepth,
- effectiveFilterHeight,
- effectiveFilterWidth,
- dilationDepth,
- dilationHeight,
- dilationWidth,
- inShape,
- outShape,
- filterShape
- };
- }
- function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
- if (zeroPad == null) {
- zeroPad = computeDefaultPad(inShape, fieldSize, stride);
- }
- const inputRows = inShape[0];
- const inputCols = inShape[1];
- const outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputRows), () => `The output # of rows (${outputRows}) must be an integer. ` +
- `Change the stride and/or zero pad parameters`);
- const outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputCols), () => `The output # of columns (${outputCols}) must be an integer. ` +
- `Change the stride and/or zero pad parameters`);
- return [outputRows, outputCols];
- }
- function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
- if (zeroPad == null) {
- zeroPad = computeDefaultPad(inShape, fieldSize, stride);
- }
- const inputDepth = inShape[0];
- const inputRows = inShape[1];
- const inputCols = inShape[2];
- const outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputDepths), () => `The output # of depths (${outputDepths}) must be an integer. ` +
- `Change the stride and/or zero pad parameters`);
- const outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputRows), () => `The output # of rows (${outputRows}) must be an integer. ` +
- `Change the stride and/or zero pad parameters`);
- const outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputCols), () => `The output # of columns (${outputCols}) must be an integer. ` +
- `Change the stride and/or zero pad parameters`);
- return [outputDepths, outputRows, outputCols, outChannels];
- }
- function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
- const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
- return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
- }
- function parseTupleParam(param) {
- if (typeof param === 'number') {
- return [param, param, param];
- }
- if (param.length === 2) {
- return [param[0], param[1], 1];
- }
- return param;
- }
- function parse3TupleParam(param) {
- return typeof param === 'number' ? [param, param, param] : param;
- }
- /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
- * Atrous convolution is equivalent to standard convolution with upsampled
- * filters with effective_filter_height =
- * filter_height + (filter_height - 1) * (dilation - 1)
- * and effective_filter_width =
- * filter_width + (filter_width - 1) * (dilation - 1),
- * produced by inserting dilation - 1 zeros along consecutive elements across
- * the filters' spatial dimensions.
- * When there is a dilation, this converts a filter dimension to the
- * effective filter dimension, so it can be used in a standard convolution.
- */
- function getEffectiveFilterSize(filterSize, dilation) {
- if (dilation <= 1) {
- return filterSize;
- }
- return filterSize + (filterSize - 1) * (dilation - 1);
- }
- function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
- let padInfo;
- let outHeight;
- let outWidth;
- if (typeof pad === 'number') {
- const padType = (pad === 0) ? 'VALID' : 'NUMBER';
- padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
- const outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
- outHeight = outShape[0];
- outWidth = outShape[1];
- }
- else if (pad === 'same') {
- outHeight = Math.ceil(inHeight / strideHeight);
- outWidth = Math.ceil(inWidth / strideWidth);
- const padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
- const padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
- const top = Math.floor(padAlongHeight / 2);
- const bottom = padAlongHeight - top;
- const left = Math.floor(padAlongWidth / 2);
- const right = padAlongWidth - left;
- padInfo = { top, bottom, left, right, type: 'SAME' };
- }
- else if (pad === 'valid') {
- padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
- outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
- outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
- }
- else if (typeof pad === 'object') {
- const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
- const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
- const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
- const right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
- const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
- 'VALID' :
- 'EXPLICIT';
- padInfo = { top, bottom, left, right, type: padType };
- outHeight = conditionalRound((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode);
- outWidth = conditionalRound((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
- }
- else {
- throw Error(`Unknown padding parameter: ${pad}`);
- }
- return { padInfo, outHeight, outWidth };
- }
- function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
- let padInfo;
- let outDepth;
- let outHeight;
- let outWidth;
- if (typeof pad === 'number') {
- const padType = (pad === 0) ? 'VALID' : 'NUMBER';
- padInfo = {
- top: pad,
- bottom: pad,
- left: pad,
- right: pad,
- front: pad,
- back: pad,
- type: padType
- };
- const outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
- outDepth = outShape[0];
- outHeight = outShape[1];
- outWidth = outShape[2];
- }
- else if (pad === 'same') {
- outDepth = Math.ceil(inDepth / strideDepth);
- outHeight = Math.ceil(inHeight / strideHeight);
- outWidth = Math.ceil(inWidth / strideWidth);
- const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
- const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
- const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
- const front = Math.floor(padAlongDepth / 2);
- const back = padAlongDepth - front;
- const top = Math.floor(padAlongHeight / 2);
- const bottom = padAlongHeight - top;
- const left = Math.floor(padAlongWidth / 2);
- const right = padAlongWidth - left;
- padInfo = { top, bottom, left, right, front, back, type: 'SAME' };
- }
- else if (pad === 'valid') {
- padInfo = {
- top: 0,
- bottom: 0,
- left: 0,
- right: 0,
- front: 0,
- back: 0,
- type: 'VALID'
- };
- outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
- outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
- outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
- }
- else {
- throw Error(`Unknown padding parameter: ${pad}`);
- }
- return { padInfo, outDepth, outHeight, outWidth };
- }
- /**
- * Rounds a value depending on the rounding mode
- * @param value
- * @param roundingMode
- */
- function conditionalRound(value, roundingMode) {
- if (!roundingMode) {
- return value;
- }
- switch (roundingMode) {
- case 'round':
- // used for Caffe Conv
- return Math.round(value);
- case 'ceil':
- // used for Caffe Pool
- return Math.ceil(value);
- case 'floor':
- return Math.floor(value);
- default:
- throw new Error(`Unknown roundingMode ${roundingMode}`);
- }
- }
- function tupleValuesAreOne(param) {
- const [dimA, dimB, dimC] = parseTupleParam(param);
- return dimA === 1 && dimB === 1 && dimC === 1;
- }
- function eitherStridesOrDilationsAreOne(strides, dilations) {
- return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
- }
- /**
- * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
- * 'channelsLast'|'channelsFirst'
- * @param dataFormat in 'NHWC'|'NCHW' mode
- * @return dataFormat in 'channelsLast'|'channelsFirst' mode
- * @throws unknown dataFormat
- */
- function convertConv2DDataFormat(dataFormat) {
- if (dataFormat === 'NHWC') {
- return 'channelsLast';
- }
- else if (dataFormat === 'NCHW') {
- return 'channelsFirst';
- }
- else {
- throw new Error(`Unknown dataFormat ${dataFormat}`);
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the 2D average pooling of an image.
- *
- * @param x The input tensor, of rank 4 or rank 3 of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param pad The type of padding algorithm:
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- */
- function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'avgPool', 'float32');
- const dilations = 1;
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in avgPool: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- const convInfo = computePool2DInfo(x4D.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
- save([x4D]);
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- return x4D.clone();
- }
- return backend.avgPool(x4D, convInfo);
- };
- const inputs = { x: x4D };
- const attrs = { filterSize, strides, pad, dimRoundingMode };
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, AvgPool, attrs);
- res = cast(res, $x.dtype);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const avgPool = op({ avgPool_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the 3D average pooling.
- *
- * ```js
- * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
- * const result = tf.avgPool3d(x, 2, 1, 'valid');
- * result.print();
- * ```
- *
- * @param x The input tensor, of rank 5 or rank 4 of shape
- * `[batch, depth, height, width, inChannels]`.
- * @param filterSize The filter size:
- * `[filterDepth, filterHeight, filterWidth]`.
- * If `filterSize` is a single number,
- * then `filterDepth == filterHeight == filterWidth`.
- * @param strides The strides of the pooling:
- * `[strideDepth, strideHeight, strideWidth]`.
- * If `strides` is a single number,
- * then `strideDepth == strideHeight == strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1*1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
- * "NDHWC". Specify the data format of the input and output data. With the
- * default format "NDHWC", the data is stored in the order of: [batch,
- * depth, height, width, channels]. Only "NDHWC" is currently supported.
- * @param dilations Deprecated, this field will be gone in v3.0.0.
- * The dilation rates:
- * `[dilationDepth, dilationHeight, dilationWidth]`
- * in which we sample input values across the depth, height and width
- * dimensions in dilated pooling.
- * Defaults to `[1, 1, 1]`. If `dilations` is a single number,
- * then `dilationDepth == dilationHeight == dilationWidth`.
- * If it is greater than 1, then all values of `strides` must be 1.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat = 'NDHWC', dilations) {
- if (dilations == null) {
- dilations = [1, 1, 1];
- }
- else {
- deprecationWarn('dilations is deprecated, this field will be gone in ' +
- 'v3.0.0.');
- }
- const $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
- let x5D = $x;
- let reshapedTo5D = false;
- if ($x.rank === 4) {
- reshapedTo5D = true;
- x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
- }
- assert(x5D.rank === 5, () => `Error in avgPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
- assert(dataFormat === 'NDHWC', () => `Error in avgPool3d: Only NDHWC is currently supported, ` +
- `but got dataFormat of ${dataFormat}`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool3d: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in avgPool3d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- if (dilations == null) {
- dilations = [1, 1, 1];
- }
- const convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
- save([x5D]);
- return backend.avgPool3d(x5D, convInfo);
- };
- const inputs = { x: x5D };
- const attrs = { filterSize, strides, pad, dimRoundingMode, dataFormat, dilations };
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, AvgPool3D, attrs);
- res = cast(res, x5D.dtype);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const avgPool3d = op({ avgPool3d_ });
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function assertParamsConsistent(shapes, axis) {
- const rank = shapes[0].length;
- shapes.forEach((shape, i) => {
- assert(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
- `as the rank of the rest (${rank})`);
- });
- assert(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);
- const firstShape = shapes[0];
- shapes.forEach((shape, i) => {
- for (let r = 0; r < rank; r++) {
- assert((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
- `does not match the shape of the rest (${firstShape}) ` +
- `along the non-concatenated axis ${i}.`);
- }
- });
- }
- function computeOutShape$1(shapes, axis) {
- const outputShape = shapes[0].slice();
- for (let i = 1; i < shapes.length; i++) {
- outputShape[axis] += shapes[i][axis];
- }
- return outputShape;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Concatenates a list of `tf.Tensor`s along a given axis.
- *
- * The tensors ranks and types must match, and their sizes must match in all
- * dimensions except `axis`.
- *
- * Also available are stricter rank-specific methods that assert that
- * `tensors` are of the given rank:
- * - `tf.concat1d`
- * - `tf.concat2d`
- * - `tf.concat3d`
- * - `tf.concat4d`
- *
- * Except `tf.concat1d` (which does not have axis param), all methods have
- * same signature as this method.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * a.concat(b).print(); // or a.concat(b)
- * ```
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * const c = tf.tensor1d([5, 6]);
- * tf.concat([a, b, c]).print();
- * ```
- *
- * ```js
- * const a = tf.tensor2d([[1, 2], [10, 20]]);
- * const b = tf.tensor2d([[3, 4], [30, 40]]);
- * const axis = 1;
- * tf.concat([a, b], axis).print();
- * ```
- * @param tensors A list of tensors to concatenate.
- * @param axis The axis to concate along. Defaults to 0 (the first dim).
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function concat_(tensors, axis = 0) {
- assert(tensors.length >= 1, () => 'Pass at least one tensor to concat');
- let $tensors = convertToTensorArray(tensors, 'tensors', 'concat');
- if ($tensors[0].dtype === 'complex64') {
- $tensors.forEach(tensor => {
- if (tensor.dtype !== 'complex64') {
- throw new Error(`Cannot concatenate complex64 tensors with a tensor
- with dtype ${tensor.dtype}. `);
- }
- });
- }
- const forward = (backend, save) => {
- const $axis = parseAxisParam(axis, $tensors[0].shape)[0];
- const outShape = computeOutShape$1($tensors.map(t => t.shape), $axis);
- if (sizeFromShape(outShape) === 0) {
- return tensor([], outShape);
- }
- // Keep only non-empty tensors (ignore tensors with 0 in their shape).
- $tensors = $tensors.filter(t => t.size > 0);
- if ($tensors.length === 1) {
- return $tensors[0];
- }
- const shapes = $tensors.map(t => t.shape);
- assertParamsConsistent(shapes, $axis);
- const res = backend.concat($tensors, $axis);
- save($tensors);
- return res;
- };
- const inputs = $tensors;
- const attr = { axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Concat, attr);
- }
- const concat = op({ concat_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes sigmoid element-wise, `1 / (1 + exp(-x))`
- *
- * ```js
- * const x = tf.tensor1d([0, -1, 2, -3]);
- *
- * x.sigmoid().print(); // or tf.sigmoid(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function sigmoid_(x) {
- const $x = convertToTensor(x, 'x', 'sigmoid');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.sigmoid($x);
- save([res]);
- return res;
- }, inputs, null /* grad */, Sigmoid);
- }
- const sigmoid = op({ sigmoid_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
- * and is of size `size`.
- *
- * Also available are stricter rank-specific methods with the same signature
- * as this method that assert that `x` is of the given rank:
- * - `tf.slice1d`
- * - `tf.slice2d`
- * - `tf.slice3d`
- * - `tf.slice4d`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- *
- * x.slice([1], [2]).print();
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * x.slice([1, 0], [1, 2]).print();
- * ```
- * @param x The input `tf.Tensor` to slice from.
- * @param begin The coordinates to start the slice from. The length can be
- * less than the rank of x - the rest of the axes will have implicit 0 as
- * start. Can also be a single number, in which case it specifies the
- * first axis.
- * @param size The size of the slice. The length can be less than the rank of
- * x - the rest of the axes will have implicit -1. A value of -1 requests
- * the rest of the dimensions in the axis. Can also be a single number,
- * in which case it specifies the size of the first axis.
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function slice_(x, begin, size) {
- const $x = convertToTensor(x, 'x', 'slice');
- if ($x.rank === 0) {
- throw new Error('Slicing scalar is not possible');
- }
- const forward = (backend, save) => {
- const [begin_, size_] = parseSliceParams($x, begin, size);
- assertParamsValid($x, begin_, size_);
- save([$x]);
- return backend.slice($x, begin_, size_);
- };
- const inputs = { x: $x };
- const attrs = { begin, size };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Slice, attrs);
- }
- const slice = op({ slice_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, 70]);
- *
- * x.tanh().print(); // or tf.tanh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function tanh_(x) {
- const $x = convertToTensor(x, 'x', 'tanh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const y = backend.tanh($x);
- save([y]);
- return y;
- }, inputs, null /* grad */, Tanh);
- }
- const tanh$1 = op({ tanh_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the next state and output of a BasicLSTMCell.
- *
- * Returns `[newC, newH]`.
- *
- * Derived from tf.contrib.rnn.BasicLSTMCell.
- *
- * @param forgetBias Forget bias for the cell.
- * @param lstmKernel The weights for the cell.
- * @param lstmBias The bias for the cell.
- * @param data The input to the cell.
- * @param c Previous cell state.
- * @param h Previous cell output.
- *
- * @doc {heading: 'Operations', subheading: 'RNN'}
- */
- function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
- const $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
- const $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
- const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
- const $data = convertToTensor(data, 'data', 'basicLSTMCell');
- const $c = convertToTensor(c, 'c', 'basicLSTMCell');
- const $h = convertToTensor(h, 'h', 'basicLSTMCell');
- const combined = concat([$data, $h], 1);
- const weighted = matMul(combined, $lstmKernel);
- const res = add$1(weighted, $lstmBias);
- // i = input_gate, j = new_input, f = forget_gate, o = output_gate
- const batchSize = res.shape[0];
- const sliceCols = res.shape[1] / 4;
- const sliceSize = [batchSize, sliceCols];
- const i = slice(res, [0, 0], sliceSize);
- const j = slice(res, [0, sliceCols], sliceSize);
- const f = slice(res, [0, sliceCols * 2], sliceSize);
- const o = slice(res, [0, sliceCols * 3], sliceSize);
- const newC = add$1(mul(sigmoid(i), tanh$1(j)), mul($c, sigmoid(add$1($forgetBias, f))));
- const newH = mul(tanh$1(newC), sigmoid(o));
- return [newC, newH];
- }
- const basicLSTMCell = op({ basicLSTMCell_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
- * shape `blockShape + [batch]`, interleaves these blocks back into the grid
- * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
- * the same rank as the input. The spatial dimensions of this intermediate
- * result are then optionally cropped according to `crops` to produce the
- * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
- * description.
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
- * const blockShape = [2, 2];
- * const crops = [[0, 0], [0, 0]];
- *
- * x.batchToSpaceND(blockShape, crops).print();
- * ```
- *
- * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
- * remainingShape`, where spatialShape has `M` dimensions.
- * @param blockShape A 1-D array. Must have shape `[M]`, all values must
- * be >= 1.
- * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
- * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
- * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
- * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
- *
- * This operation is equivalent to the following steps:
- *
- * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
- * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
- * x.shape[N-1]]`
- *
- * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /
- * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
- * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
- *
- * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
- * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
- * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
- *
- * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
- * according to `crops` to produce the output of shape: `[batch /
- * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
- * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
- * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function batchToSpaceND_(x, blockShape, crops) {
- const $x = convertToTensor(x, 'x', 'batchToSpaceND');
- const prod = blockShape.reduce((a, b) => a * b);
- assert($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`);
- assert(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);
- assert($x.shape[0] % prod === 0, () => `input tensor batch is ${$x.shape[0]} but is not divisible by the product of ` +
- `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
- const forward = backend => {
- return backend.batchToSpaceND($x, blockShape, crops);
- };
- const inputs = { x: $x };
- const attrs = { blockShape, crops };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, BatchToSpaceND, attrs);
- }
- const batchToSpaceND = op({ batchToSpaceND_ });
-
- function xAs4D(x) {
- let x4D;
- if (x.rank === 0 || x.rank === 1) {
- x4D = reshape(x, [1, 1, 1, x.size]);
- }
- else if (x.rank === 2) {
- x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
- }
- else if (x.rank === 3) {
- x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
- }
- else {
- x4D = x;
- }
- return x4D;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Batch normalization.
- *
- * As described in
- * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
- *
- * Mean, variance, scale, and offset can be of two shapes:
- * - The same shape as the input.
- * - In the common case, the depth dimension is the last dimension of x, so
- * the values would be an `tf.Tensor1D` of shape [depth].
- *
- * Also available are stricter rank-specific methods with the same signature
- * as this method that assert that parameters passed are of given rank
- * - `tf.batchNorm2d`
- * - `tf.batchNorm3d`
- * - `tf.batchNorm4d`
- *
- * @param x The input Tensor.
- * @param mean A mean Tensor.
- * @param variance A variance Tensor.
- * @param offset An offset Tensor.
- * @param scale A scale Tensor.
- * @param varianceEpsilon A small float number to avoid dividing by 0.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
- if (varianceEpsilon == null) {
- varianceEpsilon = 0.001;
- }
- const $x = convertToTensor(x, 'x', 'batchNorm');
- const $mean = convertToTensor(mean, 'mean', 'batchNorm');
- const $variance = convertToTensor(variance, 'variance', 'batchNorm');
- let $scale;
- if (scale != null) {
- $scale = convertToTensor(scale, 'scale', 'batchNorm');
- }
- let $offset;
- if (offset != null) {
- $offset = convertToTensor(offset, 'offset', 'batchNorm');
- }
- assert($mean.rank === $variance.rank, () => 'Batch normalization gradient requires mean and variance to have ' +
- 'equal ranks.');
- assert($offset == null || $mean.rank === $offset.rank, () => 'Batch normalization gradient requires mean and offset to have ' +
- 'equal ranks.');
- assert($scale == null || $mean.rank === $scale.rank, () => 'Batch normalization gradient requires mean and scale to have ' +
- 'equal ranks.');
- const x4D = xAs4D($x);
- const forward = (backend, save) => {
- save([x4D, $mean, $variance, $scale]);
- return backend.batchNorm(x4D, as1DOr4D($mean), as1DOr4D($variance), as1DOr4D($offset), as1DOr4D($scale), varianceEpsilon);
- };
- const inputs = {
- x: x4D,
- scale: $scale,
- offset: $offset,
- mean: $mean,
- variance: $variance
- };
- const attrs = { varianceEpsilon };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* gradient */, FusedBatchNorm, attrs);
- return reshape(res, $x.shape);
- }
- function as1DOr4D(x) {
- if (x == null) {
- return null;
- }
- if (x.rank === 0) {
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return reshape(x, [x.size]);
- }
- else if (x.rank === 1) {
- return x;
- }
- else if (x.rank === 2) {
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return reshape(x, [1, 1, x.shape[0], x.shape[1]]);
- }
- else if (x.rank === 3) {
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
- }
- return x;
- }
- const batchNorm = op({ batchNorm_ });
-
- /**
- * Batch normalization, strictly for 2D. For the more relaxed version, see
- * `tf.batchNorm`.
- *
- * @param x The input Tensor.
- * @param mean A mean Tensor.
- * @param variance A variance Tensor.
- * @param offset An offset Tensor.
- * @param scale A scale Tensor.
- * @param varianceEpsilon A small float number to avoid dividing by 0.
- */
- function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
- const $x = convertToTensor(x, 'x', 'batchNorm');
- const $mean = convertToTensor(mean, 'mean', 'batchNorm');
- const $variance = convertToTensor(variance, 'variance', 'batchNorm');
- let $scale;
- if (scale != null) {
- $scale = convertToTensor(scale, 'scale', 'batchNorm');
- }
- let $offset;
- if (offset != null) {
- $offset = convertToTensor(offset, 'offset', 'batchNorm');
- }
- assert($x.rank === 2, () => `Error in batchNorm2D: x must be rank 2 but got rank ` +
- `${$x.rank}.`);
- assert($mean.rank === 2 || $mean.rank === 1, () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but ` +
- `got rank ${$mean.rank}.`);
- assert($variance.rank === 2 || $variance.rank === 1, () => `Error in batchNorm2D: variance must be rank 2 or rank 1 ` +
- `but got rank ${$variance.rank}.`);
- if ($scale != null) {
- assert($scale.rank === 2 || $scale.rank === 1, () => `Error in batchNorm2D: scale must be rank 2 or rank 1 ` +
- `but got rank ${$scale.rank}.`);
- }
- if ($offset != null) {
- assert($offset.rank === 2 || $offset.rank === 1, () => `Error in batchNorm2D: offset must be rank 2 or rank 1 ` +
- `but got rank ${$offset.rank}.`);
- }
- return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
- }
- const batchNorm2d = op({ batchNorm2d_ });
-
- /**
- * Batch normalization, strictly for 3D. For the more relaxed version, see
- * `tf.batchNorm`.
- *
- * @param x The input Tensor.
- * @param mean A mean Tensor.
- * @param variance A variance Tensor.
- * @param offset An offset Tensor.
- * @param scale A scale Tensor.
- * @param varianceEpsilon A small float number to avoid dividing by 0.
- */
- function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
- const $x = convertToTensor(x, 'x', 'batchNorm');
- const $mean = convertToTensor(mean, 'mean', 'batchNorm');
- const $variance = convertToTensor(variance, 'variance', 'batchNorm');
- let $scale;
- if (scale != null) {
- $scale = convertToTensor(scale, 'scale', 'batchNorm');
- }
- let $offset;
- if (offset != null) {
- $offset = convertToTensor(offset, 'offset', 'batchNorm');
- }
- assert($x.rank === 3, () => `Error in batchNorm3D: x must be rank 3 but got rank ` +
- `${$x.rank}.`);
- assert($mean.rank === 3 || $mean.rank === 1, () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but ` +
- `got rank ${$mean.rank}.`);
- assert($variance.rank === 3 || $variance.rank === 1, () => `Error in batchNorm3D: variance must be rank 3 or rank 1 ` +
- `but got rank ${$variance.rank}.`);
- if ($scale != null) {
- assert($scale.rank === 3 || $scale.rank === 1, () => `Error in batchNorm3D: scale must be rank 3 or rank 1 ` +
- `but got rank ${$scale.rank}.`);
- }
- if ($offset != null) {
- assert($offset.rank === 3 || $offset.rank === 1, () => `Error in batchNorm3D: offset must be rank 3 or rank 1 ` +
- `but got rank ${$offset.rank}.`);
- }
- return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
- }
- const batchNorm3d = op({ batchNorm3d_ });
-
- /**
- * Batch normalization, strictly for 4D. For the more relaxed version, see
- * `tf.batchNorm`.
- *
- * @param x The input Tensor.
- * @param mean A mean Tensor.
- * @param variance A variance Tensor.
- * @param offset An offset Tensor.
- * @param scale A scale Tensor.
- * @param varianceEpsilon A small float number to avoid dividing by 0.
- */
- function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
- const $x = convertToTensor(x, 'x', 'batchNorm');
- const $mean = convertToTensor(mean, 'mean', 'batchNorm');
- const $variance = convertToTensor(variance, 'variance', 'batchNorm');
- let $scale;
- if (scale != null) {
- $scale = convertToTensor(scale, 'scale', 'batchNorm');
- }
- let $offset;
- if (offset != null) {
- $offset = convertToTensor(offset, 'offset', 'batchNorm');
- }
- assert($x.rank === 4, () => `Error in batchNorm4D: x must be rank 4 but got rank ` +
- `${$x.rank}.`);
- assert($mean.rank === 4 || $mean.rank === 1, () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but ` +
- `got rank ${$mean.rank}.`);
- assert($variance.rank === 4 || $variance.rank === 1, () => `Error in batchNorm4D: variance must be rank 4 or rank 1 ` +
- `but got rank ${$variance.rank}.`);
- if ($scale != null) {
- assert($scale.rank === 4 || $scale.rank === 1, () => `Error in batchNorm4D: scale must be rank 4 or rank 1 ` +
- `but got rank ${$scale.rank}.`);
- }
- if ($offset != null) {
- assert($offset.rank === 4 || $offset.rank === 1, () => `Error in batchNorm4D: offset must be rank 4 or rank 1 ` +
- `but got rank ${$offset.rank}.`);
- }
- return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
- }
- const batchNorm4d = op({ batchNorm4d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Broadcast an array to a compatible shape NumPy-style.
- *
- * The tensor's shape is compared to the broadcast shape from end to beginning.
- * Ones are prepended to the tensor's shape until is has the same length as
- * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
- * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
- * the input tensor is tiled N times along that axis (using tf.tile).
- *
- * @param input The tensor that is to be broadcasted.
- * @param shape The input is to be broadcast to this shape.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function broadcastTo_(x, shape) {
- let input = convertToTensor(x, 'broadcastTo', 'x');
- const xShape = input.shape;
- if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
- throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
- }
- if (shape.length < input.rank) {
- throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
- }
- if (shape.length > input.rank) {
- const newShape = input.shape.slice();
- while (newShape.length < shape.length) {
- newShape.unshift(1);
- }
- input = reshape(input, newShape);
- }
- const inputShape = input.shape;
- const reps = Array.from(shape);
- for (let i = shape.length - 1; i >= 0; i--) {
- if (inputShape[i] === shape[i]) {
- reps[i] = 1;
- }
- else if (input.shape[i] !== 1) {
- throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
- }
- }
- const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
- if (axes.length === 0) {
- return clone(input);
- }
- const forward = (backend) => backend.tile(input, reps);
- const inputs = { x: input };
- const attrs = { shape, inputShape };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, BroadcastTo, attrs);
- }
- const broadcastTo = op({ broadcastTo_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.ceil().print(); // or tf.ceil(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function ceil_(x) {
- const $x = convertToTensor(x, 'x', 'ceil');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.ceil($x), inputs, null /* grad */, Ceil);
- }
- const ceil = op({ ceil_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
- * ```
- * @param x The input tensor.
- * @param clipValueMin Lower-bound of range to be clipped to.
- * @param clipValueMax Upper-bound of range to be clipped to.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function clipByValue_(x, clipValueMin, clipValueMax) {
- const $x = convertToTensor(x, 'x', 'clipByValue');
- assert((clipValueMin <= clipValueMax), () => `Error in clip: min (${clipValueMin}) must be ` +
- `less than or equal to max (${clipValueMax}).`);
- const inputs = { x: $x };
- const attrs = { clipValueMin, clipValueMax };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.clip($x, clipValueMin, clipValueMax);
- save([$x]);
- return res;
- }, inputs, null /* grad */, ClipByValue, attrs);
- }
- const clipByValue = op({ clipByValue_ });
-
- /**
- * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details.
- *
- * For example, if:
- * A: shape(3) = |r1, g1, b1|
- * B: shape(2) = |r2, g2|
- * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
- *
- * @param tensors A list of`tf.Tensor`s to concatenate.
- * @return The concatenated array.
- */
- function concat1d_(tensors) {
- return concat(tensors, 0 /* axis */);
- }
- const concat1d = op({ concat1d_ });
-
- /**
- * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details.
- *
- * For example, if:
- * A: shape(2, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- *
- * B: shape(2, 3) = | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * C = tf.concat2d([A, B], axis)
- *
- * if axis = 0:
- * C: shape(4, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- * | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * if axis = 1:
- * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- *
- * @param tensors A list of `tf.Tensor`s to concatenate.
- * @param axis The axis to concatenate along.
- * @return The concatenated array.
- */
- function concat2d_(tensors, axis) {
- return concat(tensors, axis);
- }
- const concat2d = op({ concat2d_ });
-
- /**
- * Concatenates a list of `tf.Tensor3D`s along an axis.
- * See `concat` for details.
- *
- * For example, if:
- * A: shape(2, 1, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- *
- * B: shape(2, 1, 3) = | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * C = tf.concat3d([A, B], axis)
- *
- * if axis = 0:
- * C: shape(4, 1, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- * | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * if axis = 1:
- * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- * if axis = 2:
- * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- * @param tensors A list of`tf.Tensor`s to concatenate.
- * @param axis The axis to concate along.
- * @return The concatenated array.
- */
- function concat3d_(tensors, axis) {
- return concat(tensors, axis);
- }
- const concat3d = op({ concat3d_ });
-
- /**
- * Concatenates a list of `tf.Tensor4D`s along an axis.
- * See `concat` for details.
- *
- * @param tensors A list of `tf.Tensor`s to concatenate.
- * @param axis The axis to concate along.
- * @return The concatenated array.
- */
- function concat4d_(tensors, axis) {
- return concat(tensors, axis);
- }
- const concat4d = op({ concat4d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes a 2D convolution over the input x.
- *
- * @param x The input tensor, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
- * assumed.
- * @param filter The filter, rank 4, of shape
- * `[filterHeight, filterWidth, inDepth, outDepth]`.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels].
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'conv2d');
- const $filter = convertToTensor(filter, 'filter', 'conv2d');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);
- assert($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` +
- `${$filter.rank}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in conv2d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
- assert(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` +
- `input depth for filter ${$filter.shape[2]}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const forward = (backend, save) => {
- const $dataFormat = convertConv2DDataFormat(dataFormat);
- const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
- const res = backend.conv2d(x4D, $filter, convInfo);
- save([x4D, $filter]);
- return res;
- };
- const inputs = { x: x4D, filter: $filter };
- const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Conv2D, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const conv2d = op({ conv2d_ });
-
- /**
- * Computes a 1D convolution over the input x.
- *
- * @param x The input tensor, of rank 3 or rank 2, of shape
- * `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.
- * @param filter The filter, rank 3, of shape
- * `[filterWidth, inDepth, outDepth]`.
- * @param stride The number of entries by which the filter is moved right at
- * each step.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
- * the data is stored in the order of [batch, in_width, in_channels]. Only
- * "NWC" is currently supported.
- * @param dilation The dilation rate in which we sample input values in
- * atrous convolution. Defaults to `1`. If it is greater than 1, then
- * stride must be `1`.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function conv1d_(x, filter, stride, pad, dataFormat = 'NWC', dilation = 1, dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'conv1d');
- const $filter = convertToTensor(filter, 'filter', 'conv1d');
- let x3D = $x;
- let reshapedTo3D = false;
- if ($x.rank === 2) {
- reshapedTo3D = true;
- x3D = reshape($x, [1, $x.shape[0], $x.shape[1]]);
- }
- assert(x3D.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);
- assert($filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ` +
- `${$filter.rank}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in conv1d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- assert(x3D.shape[2] === $filter.shape[1], () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
- `input depth for filter ${$filter.shape[1]}.`);
- assert(eitherStridesOrDilationsAreOne(stride, dilation), () => 'Error in conv1D: Either stride or dilation must be 1. ' +
- `Got stride ${stride} and dilation '${dilation}'`);
- assert(dataFormat === 'NWC', () => `Error in conv1d: got dataFormat of ${dataFormat} but only NWC is currently supported.`);
- const filter4D = reshape($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
- const input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
- const strides = [1, stride];
- const dilations = [1, dilation];
- const conv2dDataFormat = 'NHWC';
- const res = conv2d(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode);
- if (reshapedTo3D) {
- return reshape(res, [res.shape[2], res.shape[3]]);
- }
- return reshape(res, [res.shape[0], res.shape[2], res.shape[3]]);
- }
- const conv1d = op({ conv1d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the derivative of the input of a 2D convolution.
- *
- * @param xShape The shape of the input: [batch, height, width, inDepth].
- * If length of 3, batch of 1 is assumed.
- * @param dy The derivative of the output, of rank 4 or rank 3 of shape
- * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
- * assumed.
- * @param filter The filter, rank 4, of shape
- * `[filterHeight, filterWidth, inDepth, outDepth]`.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`.
- * @param pad The type of padding algorithm used:
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels].
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- */
- function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
- assert(xShape.length === dy.rank, () => `Length of inShape ` +
- `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
- let xShape4D = xShape;
- let dy4D = dy;
- let reshapedTo4D = false;
- if (dy.rank === 3) {
- reshapedTo4D = true;
- dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
- xShape4D = [1, xShape[0], xShape[1], xShape[2]];
- }
- assert(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` +
- `${xShape4D.length}.`);
- assert(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` +
- `rank ${dy4D.rank}`);
- assert(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` +
- `rank ${filter.rank}`);
- const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
- const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
- assert(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +
- `match input depth for filter ${filter.shape[2]}.`);
- assert(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
- `match output depth for filter ${filter.shape[3]}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in conv2dDerInput: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- const dilations = 1;
- const $dataFormat = convertConv2DDataFormat(dataFormat);
- const convInfo = computeConv2DInfo(xShape4D, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
- const res = backend.conv2dDerInput(dy4D, filter, convInfo);
- save([dy4D, filter]);
- return res;
- };
- const inputs = { dy: dy4D, filter };
- const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Conv2DBackpropInput, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const conv2DBackpropInput = op({ conv2DBackpropInput_ });
-
- /**
- * Computes the transposed 2D convolution of an image, also known as a
- * deconvolution.
- *
- * @param x The input image, of rank 4 or rank 3, of shape
- * `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
- * @param filter The filter, rank 4, of shape
- * `[filterHeight, filterWidth, outDepth, inDepth]`.
- * `inDepth` must match `inDepth` in `x`.
- * @param outputShape Output shape, of rank 4 or rank 3:
- * `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
- * @param strides The strides of the original convolution:
- * `[strideHeight, strideWidth]`.
- * @param pad The type of padding algorithm used in the non-transpose version
- * of the op.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'conv2dTranspose');
- const $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');
- return conv2DBackpropInput(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);
- }
- const conv2dTranspose = op({ conv2dTranspose_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes a 3D convolution over the input x.
- *
- * @param x The input tensor, of rank 5 or rank 4, of shape
- * `[batch, depth, height, width, channels]`. If rank 4,
- * batch of 1 is assumed.
- * @param filter The filter, rank 5, of shape
- * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
- * inChannels must match between input and filter.
- * @param strides The strides of the convolution: `[strideDepth, strideHeight,
- * strideWidth]`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
- * "NDHWC". Specify the data format of the input and output data. With the
- * default format "NDHWC", the data is stored in the order of: [batch,
- * depth, height, width, channels]. Only "NDHWC" is currently supported.
- * @param dilations The dilation rates: `[dilationDepth, dilationHeight,
- * dilationWidth]` in which we sample input values across the height
- * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
- * If `dilations` is a single number, then
- * `dilationDepth == dilationHeight == dilationWidth`. If it is greater
- * than 1, then all values of `strides` must be 1.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function conv3d_(x, filter, strides, pad, dataFormat = 'NDHWC', dilations = [1, 1, 1]) {
- const $x = convertToTensor(x, 'x', 'conv3d');
- const $filter = convertToTensor(filter, 'filter', 'conv3d');
- let x5D = $x;
- let reshapedTo5D = false;
- if ($x.rank === 4) {
- reshapedTo5D = true;
- x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
- }
- assert(x5D.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${x5D.rank}.`);
- assert($filter.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ` +
- `${$filter.rank}.`);
- assert(x5D.shape[4] === $filter.shape[3], () => `Error in conv3d: depth of input (${x5D.shape[4]}) must match ` +
- `input depth for filter ${$filter.shape[3]}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv3D: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- assert(dataFormat === 'NDHWC', () => `Error in conv3d: got dataFormat of ${dataFormat} but only NDHWC is currently supported.`);
- const forward = (backend, save) => {
- const convInfo = computeConv3DInfo(x5D.shape, $filter.shape, strides, dilations, pad);
- const res = backend.conv3d(x5D, $filter, convInfo);
- save([x5D, $filter]);
- return res;
- };
- const inputs = { x: x5D, filter: $filter };
- const attrs = { strides, pad, dataFormat, dilations };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Conv3D, attrs);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const conv3d = op({ conv3d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the derivative of the input of a 3D convolution.
- *
- * @param xShape The shape of the input: [batch, depth, height, width,
- * in_channels]. If length of 4, batch of 1 is assumed.
- * @param dy The derivative of the output, of rank 5 or rank 4 of shape
- * `[batch, outDepth, outHeight, outWidth, in_channels]`.
- * If rank 4, batch of 1 is assumed.
- * @param filter The filter, rank 5, of shape
- * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
- * @param strides The strides of the convolution: `[strideDepth, strideHeight,
- * strideWidth]`.
- * @param pad The type of padding algorithm used:
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- */
- function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
- assert(xShape.length === dy.rank, () => `Length of inShape ` +
- `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
- let xShape5D = xShape;
- let dy5D = dy;
- let reshapedTo5D = false;
- if (dy.rank === 4) {
- reshapedTo5D = true;
- dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
- xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
- }
- const inDepth = xShape5D[4];
- const outDepth = dy5D.shape[4];
- assert(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
- `${xShape5D.length}.`);
- assert(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
- `rank ${dy5D.rank}`);
- assert(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
- `rank ${filter.rank}`);
- assert(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
- `match input depth for filter ${filter.shape[3]}.`);
- assert(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
- `match output depth for filter ${filter.shape[4]}.`);
- const forward = backend => {
- const dilations = 1;
- const convInfo = computeConv3DInfo(xShape5D, filter.shape, strides, dilations, pad);
- return backend.conv3dDerInput(dy5D, filter, convInfo);
- };
- const inputs = { dy: dy5D };
- const attrs = { pad };
- const res = ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropInputV2, attrs);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const conv3DBackpropInput = op({ conv3DBackpropInput_ });
-
- /**
- * Computes the transposed 3D convolution of a volume, also known as a
- * deconvolution.
- *
- * @param x The input image, of rank 5 or rank 4, of shape
- * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
- * @param filter The filter, rank 4, of shape
- * `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
- * `inDepth` must match `inDepth` in `x`.
- * @param outputShape Output shape, of rank 5 or rank 4:
- * `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is
- * assumed.
- * @param strides The strides of the original convolution:
- * `[strideDepth, strideHeight, strideWidth]`.
- * @param pad The type of padding algorithm used in the non-transpose version
- * of the op.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function conv3dTranspose_(x, filter, outputShape, strides, pad) {
- const $x = convertToTensor(x, 'x', 'conv3dTranspose');
- const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
- return conv3DBackpropInput(outputShape, $x, $filter, strides, pad);
- }
- const conv3dTranspose = op({ conv3dTranspose_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.cos().print(); // or tf.cos(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function cos_(x) {
- const $x = convertToTensor(x, 'x', 'cos');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.cos($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Cos);
- }
- const cos = op({ cos_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.cosh().print(); // or tf.cosh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function cosh_(x) {
- const $x = convertToTensor(x, 'x', 'cosh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.cosh($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Cosh);
- }
- const cosh = op({ cosh_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the cumulative sum of a `tf.Tensor` along `axis`.
- *
- * ```js
- * const x = tf.tensor([1, 2, 3, 4]);
- * x.cumsum().print();
- * ```
- * ```js
- * const x = tf.tensor([[1, 2], [3, 4]]);
- * x.cumsum().print();
- * ```
- *
- * @param x The input tensor to be summed.
- * @param axis The axis along which to sum. Optional. Defaults to 0.
- * @param exclusive Whether to perform exclusive cumulative sum. Optional.
- * Defaults to false. If set to true then the sum of each tensor entry
- * does not include its own value, but only the values previous to it
- * along the specified axis.
- * @param reverse Whether to sum in the opposite direction. Optional.
- * Defaults to false.
- *
- * @doc {heading: 'Operations', subheading: 'Scan'}
- */
- function cumsum_(x, axis = 0, exclusive = false, reverse = false) {
- const $x = convertToTensor(x, 'x', 'cumsum');
- const forward = (backend, save) => {
- const permutation = getAxesPermutation([axis], $x.rank);
- let permutedX = $x;
- if (permutation != null) {
- permutedX = transpose($x, permutation);
- }
- const permutedAxis = getInnerMostAxes(1, $x.rank)[0];
- let value = backend.cumsum(permutedX, permutedAxis, exclusive, reverse);
- save([$x]);
- if (permutation != null) {
- const reversePermutation = getUndoAxesPermutation(permutation);
- value = transpose(value, reversePermutation);
- }
- return value;
- };
- const inputs = { x: $x };
- const attrs = { axis, exclusive, reverse };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Cumsum, attrs);
- }
- const cumsum = op({ cumsum_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Rearranges data from depth into blocks of spatial data. More specifically,
- * this op outputs a copy of the input tensor where values from the `depth`
- * dimension are moved in spatial blocks to the `height` and `width` dimensions.
- * The attr `blockSize` indicates the input block size and how the data is
- * moved.
- *
- * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
- * into non-overlapping blocks of size `blockSize x blockSize`
- *
- * - The width the output tensor is `inputWidth * blockSize`, whereas the
- * height is `inputHeight * blockSize`
- *
- * - The Y, X coordinates within each block of the output image are determined
- * by the high order component of the input channel index
- *
- * - The depth of the input tensor must be divisible by `blockSize *
- * blockSize`
- *
- * The `dataFormat` attr specifies the layout of the input and output tensors
- * with the following options: "NHWC": [ `batch, height, width, channels` ]
- * "NCHW": [ `batch, channels, height, width` ]
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
- * const blockSize = 2;
- * const dataFormat = "NHWC";
- *
- * tf.depthToSpace(x, blockSize, dataFormat).print();
- * ```
- *
- * @param x The input tensor of rank 4
- * @param blockSIze An `int` that is `>= 2`. The size of the spatial block
- * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function depthToSpace_(x, blockSize, dataFormat = 'NHWC') {
- const $x = convertToTensor(x, 'x', 'depthToSpace');
- const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
- const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
- const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
- assert(inputHeight * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
- ${inputHeight} and ${blockSize} for depthToSpace with input shape
- ${$x.shape}`);
- assert(inputWidth * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
- ${inputWidth} and ${blockSize} for depthToSpace with input shape
- ${$x.shape}`);
- assert((inputDepth % (blockSize * blockSize) === 0), () => `Dimension size must be evenly divisible by ${blockSize * blockSize} but is ${inputDepth} for depthToSpace with input shape ${$x.shape}`);
- const forward = backend => backend.depthToSpace($x, blockSize, dataFormat);
- const inputs = { x: $x };
- const attrs = { blockSize, dataFormat };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, DepthToSpace, attrs);
- }
- const depthToSpace = op({ depthToSpace_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Depthwise 2D convolution.
- *
- * Given a 4D `input` array and a `filter` array of shape
- * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
- * `inChannels` convolutional filters of depth 1, this op applies a
- * different filter to each input channel (expanding from 1 channel to
- * `channelMultiplier` channels for each), then concatenates the results
- * together. The output has `inChannels * channelMultiplier` channels.
- *
- * See
- * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
- * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
- * for more details.
- *
- * @param x The input tensor, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
- * assumed.
- * @param filter The filter tensor, rank 4, of shape
- * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`. If strides is a single number, then `strideHeight ==
- * strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels]. Only "NHWC" is currently supported.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function depthwiseConv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'depthwiseConv2d');
- const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got ` +
- `rank ${x4D.rank}.`);
- assert($filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +
- `${$filter.rank}.`);
- assert(x4D.shape[3] === $filter.shape[2], () => `Error in depthwiseConv2d: number of input channels ` +
- `(${x4D.shape[3]}) must match the inChannels dimension in ` +
- `filter ${$filter.shape[2]}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in depthwiseConv2d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- if (dilations == null) {
- dilations = [1, 1];
- }
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
- `1. Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
- const res = backend.depthwiseConv2D(x4D, $filter, convInfo);
- save([x4D, $filter]);
- return res;
- };
- const inputs = { x: x4D, filter: $filter };
- const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, DepthwiseConv2dNative, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const depthwiseConv2d = op({ depthwiseConv2d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns a diagonal tensor with a given diagonal values.
- *
- * Given a diagonal, this operation returns a tensor with the diagonal and
- * everything else padded with zeros.
- *
- * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
- * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- *
- * tf.diag(x).print()
- * ```
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
- *
- * tf.diag(x).print()
- * ```
- * @param x The input tensor.
- */
- function diag_(x) {
- const $x = convertToTensor(x, 'x', 'diag');
- const forward = backend => {
- const flat = reshape($x, [$x.size]);
- const result = backend.diag(flat);
- const outShape = [...x.shape, ...x.shape];
- return reshape(result, outShape);
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Diag);
- }
- const diag = op({ diag_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the grayscale dilation over the input `x`.
- *
- * @param x The input tensor, rank 3 or rank 4 of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param filter The filter tensor, rank 3, of shape
- * `[filterHeight, filterWidth, depth]`.
- * @param strides The strides of the sliding window for each dimension of the
- * input tensor: `[strideHeight, strideWidth]`.
- * If `strides` is a single number,
- * then `strideHeight == strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1*1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat Specify the data format of the input and output data.
- * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels].
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
- * is a single number, then `dilationHeight == dilationWidth`. If it is
- * greater than 1, then all values of `strides` must be 1.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function dilation2d_(x, filter, strides, pad, dilations = [1, 1], dataFormat = 'NHWC') {
- const $x = convertToTensor(x, 'x', 'dilation2d');
- const $filter = convertToTensor(filter, 'filter', 'dilation2d');
- assert($x.rank === 3 || $x.rank === 4, () => `Error in dilation2d: input must be rank 3 or 4, but got rank ` +
- `${$x.rank}.`);
- assert($filter.rank === 3, () => `Error in dilation2d: filter must be rank 3, but got rank ` +
- `${$filter.rank}.`);
- assert(dataFormat === 'NHWC', () => `Error in dilation2d: Only NHWC is currently supported, ` +
- `but got dataFormat of ${dataFormat}`);
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- reshapedTo4D = true;
- }
- const inputs = { x: x4D, filter: $filter };
- const attrs = { strides, pad, dilations };
- const res = ENGINE.runKernel(Dilation2D, inputs, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const dilation2d = op({ dilation2d_ });
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the dimensions in the input shape that are broadcasted to
- * produce the provided output shape.
- *
- * The returned dimensions are 0-indexed and sorted. An example:
- * inShape = [4, 1, 3]
- * outShape = [5, 4, 3, 3]
- * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
- */
- function getBroadcastDims(inShape, outShape) {
- const inRank = inShape.length;
- const dims = [];
- for (let i = 0; i < inRank; i++) {
- const dim = inRank - 1 - i;
- const a = inShape[dim] || 1;
- const b = outShape[outShape.length - 1 - i] || 1;
- if (b > 1 && a === 1) {
- dims.unshift(dim);
- }
- }
- return dims;
- }
- /**
- * Returns the axes in the output space that should be reduced to produce
- * the input space.
- */
- function getReductionAxes(inShape, outShape) {
- const result = [];
- for (let i = 0; i < outShape.length; i++) {
- const inDim = inShape[inShape.length - i - 1];
- const outAxis = outShape.length - i - 1;
- const outDim = outShape[outAxis];
- if (inDim == null || (inDim === 1 && outDim > 1)) {
- result.unshift(outAxis);
- }
- }
- return result;
- }
- function assertAndGetBroadcastShape(shapeA, shapeB) {
- const result = [];
- const l = Math.max(shapeA.length, shapeB.length);
- for (let i = 0; i < l; i++) {
- let a = shapeA[shapeA.length - i - 1];
- if (a == null) {
- a = 1;
- }
- let b = shapeB[shapeB.length - i - 1];
- if (b == null) {
- b = 1;
- }
- if (a === 1) {
- result.unshift(b);
- }
- else if (b === 1) {
- result.unshift(a);
- }
- else if (a !== b) {
- const errMsg = `Operands could not be broadcast together with shapes ` +
- `${shapeA} and ${shapeB}.`;
- throw Error(errMsg);
- }
- else {
- result.unshift(a);
- }
- }
- return result;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a == b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([2, 2, 2]);
- *
- * a.equal(b).print();
- * ```
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function equal_(a, b) {
- let $a = convertToTensor(a, 'a', 'equal');
- let $b = convertToTensor(b, 'b', 'equal');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = backend => backend.equal($a, $b);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null, Equal);
- }
- const equal = op({ equal_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the elements, either `a` or `b` depending on the `condition`.
- *
- * If the condition is true, select from `a`, otherwise select from `b`.
- *
- * ```js
- * const cond = tf.tensor1d([false, false, true], 'bool');
- * const a = tf.tensor1d([1 , 2, 3]);
- * const b = tf.tensor1d([-1, -2, -3]);
- *
- * a.where(cond, b).print();
- * ```
- *
- * @param condition The input condition. Must be of dtype bool.
- * @param a If `condition` is rank 1, `a` may have a higher rank but
- * its first dimension must match the size of `condition`.
- * @param b A tensor with the same dtype as `a` and with shape that is
- * compatible with `a`.
- * @return A tensor with same dtype as `a` and `b`, and shape that is
- * broadcastable from `a` and `b`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function where_(condition, a, b) {
- const $a = convertToTensor(a, 'a', 'where');
- const $b = convertToTensor(b, 'b', 'where');
- const $condition = convertToTensor(condition, 'condition', 'where', 'bool');
- // TODO: move this logic to forward function when the broadcastTo op is
- // implemented in WASM.
- // Find the broadcastable shape for $a and $b.
- const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- const $broadcastedA = broadcastTo($a, broadcastShape);
- const $broadcastedB = broadcastTo($b, broadcastShape);
- if ($condition.rank === 1) {
- // If condition rank is 1, then the first dimension must match the size of
- // condition.
- assert($condition.shape[0] === $a.shape[0], () => 'The first dimension of `a` must match the size of `condition`.');
- }
- if ($condition.rank !== 1) {
- // A must have the same shape as condition.
- assertShapesMatch($condition.shape, $broadcastedB.shape, 'Error in where: ');
- }
- const forward = (backend, save) => {
- const res = backend.select($condition, $broadcastedA, $broadcastedB);
- save([$condition]);
- return res;
- };
- const inputs = {
- condition: $condition,
- t: $broadcastedA,
- e: $broadcastedB
- };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, SelectV2);
- }
- const where = op({ where_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
- * given tensor.
- *
- * ```js
- * const x = tf.tensor([1, 2]);
- * tf.zerosLike(x).print();
- * ```
- *
- * @param x The tensor of required shape.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function zerosLike_(x) {
- const $x = convertToTensor(x, 'x', 'zerosLike');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.zerosLike($x), inputs, null /* grad */, ZerosLike);
- }
- const zerosLike = op({ zerosLike_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
- * if denominator is 0.
- *
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- * const c = tf.tensor1d([0, 0, 0, 0]);
- *
- * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
- * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- * const c = tf.scalar(0);
- *
- * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
- * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function divNoNan_(a, b) {
- // TODO: Make this into its own kernel.
- let $a = convertToTensor(a, 'a', 'div');
- let $b = convertToTensor(b, 'b', 'div');
- [$a, $b] = makeTypesMatch($a, $b);
- const divResult = div($a, $b);
- const zeros = zerosLike(divResult);
- const bEqualsZero = equal($b, zeros);
- return where(bEqualsZero, zeros, divResult);
- }
- const divNoNan = op({ divNoNan_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor2d([[1, 2], [3, 4]]);
- * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
- *
- * a.dot(b).print(); // or tf.dot(a, b)
- * b.dot(a).print();
- * b.dot(c).print();
- * ```
- * @param t1 The first tensor in the dot operation.
- * @param t2 The second tensor in the dot operation.
- *
- * @doc {heading: 'Operations', subheading: 'Matrices'}
- */
- function dot_(t1, t2) {
- const $t1 = convertToTensor(t1, 't1', 'dot');
- const $t2 = convertToTensor(t2, 't2', 'dot');
- assert(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
- `${$t1.rank} and ${$t2.rank}.`);
- const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
- const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);
- assert(t1Inner === t2Inner, () => `Error in dot: inner dimensions of inputs must match, but got ` +
- `${t1Inner} and ${t2Inner}.`);
- if ($t1.rank === 1 && $t2.rank === 1) {
- const t12D = reshape($t1, [1, -1]);
- const t22D = reshape($t2, [-1, 1]);
- const t1t2 = matMul(t12D, t22D);
- return reshape(t1t2, []);
- }
- else if ($t1.rank === 1 && $t2.rank === 2) {
- const t12D = reshape($t1, [1, -1]);
- const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
- const t1t2 = matMul(t12D, t22D);
- return reshape(t1t2, [t1t2.size]);
- }
- else if ($t1.rank === 2 && $t2.rank === 1) {
- const t22D = reshape($t2, [-1, 1]);
- const t1t2 = matMul($t1, t22D);
- return reshape(t1t2, [t1t2.size]);
- }
- else {
- const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
- const t1t2 = matMul($t1, t22D);
- return t1t2;
- }
- }
- const dot = op({ dot_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes exponential linear element-wise: `x > 0 ? e ^ x - 1 : 0`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 1, -3, 2]);
- *
- * x.elu().print(); // or tf.elu(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function elu_(x) {
- const $x = convertToTensor(x, 'x', 'elu');
- const forward = (backend, save) => {
- const y = backend.elu($x);
- save([y]);
- return y;
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Elu);
- }
- const elu = op({ elu_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes gause error function of the input `tf.Tensor` element-wise:
- * `erf(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, .1, -.1, .7]);
- *
- * x.erf().print(); // or tf.erf(x);
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function erf_(x) {
- let $x = convertToTensor(x, 'x', 'erf');
- assert($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.');
- if ($x.dtype === 'int32') {
- $x = cast($x, 'float32');
- }
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.erf($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Erf);
- }
- const erf = op({ erf_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, -3]);
- *
- * x.exp().print(); // or tf.exp(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function exp_(x) {
- const $x = convertToTensor(x, 'x', 'exp');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.exp($x);
- save([res]);
- return res;
- }, inputs, null /* grad */, Exp);
- }
- const exp = op({ exp_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
- * into the tensor's shape.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * const axis = 1;
- * x.expandDims(axis).print();
- * ```
- *
- * @param x The input tensor whose dimensions to be expanded.
- * @param axis The dimension index at which to insert shape of `1`. Defaults
- * to 0 (the first dimension).
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function expandDims_(x, axis = 0) {
- const parseAs = null;
- const $x = convertToTensor(x, 'x', 'expandDims', parseAs);
- assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');
- const newShape = $x.shape.slice();
- if (axis < 0) {
- // Negative value is counted from the tail of rank.
- assert(-($x.rank + 1) <= axis, () => `Axis must be in the interval [${-($x.rank + 1)}, ${$x.rank}]`);
- axis = $x.rank + axis + 1;
- }
- newShape.splice(axis, 0, 1);
- return reshape($x, newShape);
- }
- const expandDims = op({ expandDims_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes exponential of the input `tf.Tensor` minus one element-wise.
- * `e ^ x - 1`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, -3]);
- *
- * x.expm1().print(); // or tf.expm1(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function expm1_(x) {
- const $x = convertToTensor(x, 'x', 'expm1');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.expm1($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Expm1);
- }
- const expm1 = op({ expm1_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Construct a tensor by repeating it the number of times given by reps.
- *
- * This operation creates a new tensor by replicating `input` `reps`
- * times. The output tensor's i'th dimension has `input.shape[i] *
- * reps[i]` elements, and the values of `input` are replicated
- * `reps[i]` times along the i'th dimension. For example, tiling
- * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- *
- * a.tile([2]).print(); // or a.tile([2])
- * ```
- *
- * ```js
- * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * a.tile([1, 2]).print(); // or a.tile([1, 2])
- * ```
- * @param x The tensor to tile.
- * @param reps Determines the number of replications per dimension.
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function tile_(x, reps) {
- const parseAs = null;
- const $x = convertToTensor(x, 'x', 'tile', parseAs);
- assert($x.rank === reps.length, () => `Error in transpose: rank of input ${$x.rank} ` +
- `must match length of reps ${reps}.`);
- const forward = (backend, save) => {
- const res = backend.tile($x, reps);
- save([$x]);
- return res;
- };
- const inputsToSave = [$x];
- const inputs = { x: $x };
- const attrs = { reps };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Tile, attrs, inputsToSave);
- }
- const tile = op({ tile_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Create an identity matrix.
- *
- * @param numRows Number of rows.
- * @param numColumns Number of columns. Defaults to `numRows`.
- * @param batchShape If provided, will add the batch shape to the beginning
- * of the shape of the returned `tf.Tensor` by repeating the identity
- * matrix.
- * @param dtype Data type.
- * @returns Identity matrix of the specified size and data type, possibly
- * with batch repetition if `batchShape` is specified.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
- if (numColumns == null) {
- numColumns = numRows;
- }
- const buff = buffer([numRows, numColumns], dtype);
- const n = numRows <= numColumns ? numRows : numColumns;
- for (let i = 0; i < n; ++i) {
- buff.set(1, i, i);
- }
- const out = reshape(buff.toTensor(), [numRows, numColumns]);
- if (batchShape == null) {
- return out;
- }
- else {
- if (batchShape.length === 1) {
- return tile(expandDims(out, 0), [batchShape[0], 1, 1]);
- }
- else if (batchShape.length === 2) {
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return tile(expandDims(expandDims(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
- }
- else if (batchShape.length === 3) {
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [
- batchShape[0], batchShape[1], batchShape[2], 1, 1
- ]);
- }
- else {
- throw new Error(`eye() currently supports only 1D and 2D ` +
- // tslint:disable-next-line:no-any
- `batchShapes, but received ${batchShape.length}D.`);
- }
- }
- }
- const eye = op({ eye_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` filled with a scalar value.
- *
- * ```js
- * tf.fill([2, 2], 4).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param value The scalar value to fill the tensor with.
- * @param dtype The type of an element in the resulting tensor. Defaults to
- * 'float'.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function fill(shape, value, dtype) {
- const attrs = { shape, value, dtype };
- return ENGINE.runKernelFunc(backend => backend.fill(shape, value, dtype), {}, null, Fill, attrs);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.floor().print(); // or tf.floor(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function floor_(x) {
- const $x = convertToTensor(x, 'x', 'floor');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.floor($x), inputs, null /* grad */, Floor);
- }
- const floor = op({ floor_ });
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const PARALLELIZE_THRESHOLD = 30;
- function computeOptimalWindowSize(inSize) {
- if (inSize <= PARALLELIZE_THRESHOLD) {
- return inSize;
- }
- return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function segOpComputeOptimalWindowSize(inSize, numSegments) {
- let done = false;
- let res;
- if (inSize <= PARALLELIZE_THRESHOLD) {
- res = inSize;
- done = true;
- }
- else {
- res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
- }
- while (!done) {
- if (res > numSegments || res === inSize) {
- done = true;
- }
- else {
- res = nearestDivisor(inSize, res + 1);
- }
- }
- return res;
- }
- function computeOutShape$2(aShape, axis, numSegments) {
- const outShape = [];
- const rank = aShape.length;
- for (let dim = 0; dim < rank; dim++) {
- if (dim !== axis) {
- outShape.push(aShape[dim]);
- }
- else {
- outShape.push(numSegments);
- }
- }
- return outShape;
- }
- function collectGatherOpShapeInfo(x, indices, axis) {
- const dimSize = x.shape[axis];
- const outputShape = [];
- let batchSize = 1;
- let sliceSize = 1;
- for (let i = 0; i < axis; i++) {
- outputShape.push(x.shape[i]);
- batchSize *= x.shape[i];
- }
- for (let i = 0; i < indices.rank; i++) {
- outputShape.push(indices.shape[i]);
- }
- for (let i = axis + 1; i < x.rank; i++) {
- outputShape.push(x.shape[i]);
- sliceSize *= x.shape[i];
- }
- return { batchSize, sliceSize, dimSize, outputShape };
- }
-
- var segment_util = /*#__PURE__*/Object.freeze({
- __proto__: null,
- segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize,
- computeOutShape: computeOutShape$2,
- collectGatherOpShapeInfo: collectGatherOpShapeInfo
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gather slices from tensor `x`'s axis `axis` according to `indices`.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * const indices = tf.tensor1d([1, 3, 3], 'int32');
- *
- * x.gather(indices).print();
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const indices = tf.tensor1d([1, 1, 0], 'int32');
- *
- * x.gather(indices).print();
- * ```
- * @param x The input tensor whose slices to be gathered.
- * @param indices The indices of the values to extract.
- * @param axis The axis over which to select values. Defaults to 0.
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function gather_(x, indices, axis = 0) {
- const $x = convertToTensor(x, 'x', 'gather');
- const $indices = convertToTensor(indices, 'indices', 'gather', 'int32');
- const inputs = { x: $x, indices: $indices };
- const attrs = { axis };
- const forward = (backend, save) => {
- const parsedAxis = parseAxisParam(axis, $x.shape)[0];
- const shapeInfo = collectGatherOpShapeInfo($x, $indices, parsedAxis);
- const res = backend.gather($x, reshape($indices, [$indices.size]), parsedAxis);
- save([$x, $indices]);
- return reshape(res, shapeInfo.outputShape);
- };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, GatherV2, attrs);
- }
- const gather = op({ gather_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a > b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([2, 2, 2]);
- *
- * a.greater(b).print();
- * ```
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function greater_(a, b) {
- let $a = convertToTensor(a, 'a', 'greater');
- let $b = convertToTensor(b, 'b', 'greater');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = backend => backend.greater($a, $b);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Greater);
- }
- const greater = op({ greater_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a >= b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([2, 2, 2]);
- *
- * a.greaterEqual(b).print();
- * ```
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function greaterEqual_(a, b) {
- let $a = convertToTensor(a, 'a', 'greaterEqual');
- let $b = convertToTensor(b, 'b', 'greaterEqual');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = (backend, save) => {
- const res = backend.greaterEqual($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, GreaterEqual);
- }
- const greaterEqual = op({ greaterEqual_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the imaginary part of a complex (or real) tensor.
- *
- * Given a tensor input, this operation returns a tensor of type float that is
- * the imaginary part of each element in input considered as a complex number.
- * If input is real, a tensor of all zeros is returned.
- *
- * ```js
- * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
- * tf.imag(x).print();
- * ```
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function imag_(input) {
- const $input = convertToTensor(input, 'input', 'imag');
- const forward = (backend) => {
- return backend.imag($input);
- };
- const inputs = { input: $input };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Imag);
- }
- const imag = op({ imag_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns which elements of x are finite.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isFinite().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function isFinite_(x) {
- const $x = convertToTensor(x, 'x', 'isFinite');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend) => backend.isFinite($x), inputs, null /* grad */, IsFinite);
- }
- const isFinite$1 = op({ isFinite_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns which elements of x are Infinity or -Infinity.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isInf().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function isInf_(x) {
- const $x = convertToTensor(x, 'x', 'isInf');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend) => backend.isInf($x), inputs, null /* grad */, IsInf);
- }
- const isInf = op({ isInf_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * RReturns which elements of x are NaN.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isNaN().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function isNaN_(x) {
- const $x = convertToTensor(x, 'x', 'isNaN');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.isNaN($x), inputs, null /* grad */, IsNan);
- }
- const isNaN$1 = op({ isNaN_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the max of a and b (`a > b ? a : b`) element-wise.
- * Supports broadcasting.
- *
- * We also expose `tf.maximumStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.maximum(b).print(); // or tf.maximum(a, b)
- * ```
- *
- * ```js
- * // Broadcast maximum a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.maximum(b).print(); // or tf.maximum(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function maximum_(a, b) {
- let $a = convertToTensor(a, 'a', 'maximum');
- let $b = convertToTensor(b, 'b', 'maximum');
- [$a, $b] = makeTypesMatch($a, $b);
- if ($a.dtype === 'bool') {
- $a = cast($a, 'int32');
- $b = cast($b, 'int32');
- }
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = (backend, save) => {
- const res = backend.maximum($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Maximum);
- }
- const maximum = op({ maximum_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.scalar` as it makes the code more readable.
- *
- * ```js
- * tf.scalar(3.14).print();
- * ```
- *
- * @param value The value of the scalar.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function scalar(value, dtype) {
- if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&
- dtype !== 'complex64') {
- throw new Error('Error creating a new Scalar: value must be a primitive ' +
- '(number|boolean|string)');
- }
- if (dtype === 'string' && isTypedArray(value) &&
- !(value instanceof Uint8Array)) {
- throw new Error('When making a scalar from encoded string, ' +
- 'the value must be `Uint8Array`.');
- }
- const shape = [];
- const inferredShape = [];
- return makeTensor(value, shape, inferredShape, dtype);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes leaky rectified linear element-wise.
- *
- * See
- * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
- * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
- * ```
- * @param x The input tensor.
- * @param alpha The scaling factor for negative values, defaults to 0.2.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function leakyRelu_(x, alpha = 0.2) {
- const $x = convertToTensor(x, 'x', 'leakyRelu');
- return maximum(mul(scalar(alpha), $x), $x);
- }
- const leakyRelu = op({ leakyRelu_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a < b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([2, 2, 2]);
- *
- * a.less(b).print();
- * ```
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function less_(a, b) {
- let $a = convertToTensor(a, 'a', 'less');
- let $b = convertToTensor(b, 'b', 'less');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = backend => backend.less($a, $b);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Less);
- }
- const less = op({ less_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a <= b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([2, 2, 2]);
- *
- * a.lessEqual(b).print();
- * ```
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function lessEqual_(a, b) {
- let $a = convertToTensor(a, 'a', 'lessEqual');
- let $b = convertToTensor(b, 'b', 'lessEqual');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = (backend, save) => {
- const res = backend.lessEqual($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, LessEqual);
- }
- const lessEqual = op({ lessEqual_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Return an evenly spaced sequence of numbers over the given interval.
- *
- * ```js
- * tf.linspace(0, 9, 10).print();
- * ```
- * @param start The start value of the sequence.
- * @param stop The end value of the sequence.
- * @param num The number of values to generate.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function linspace(start, stop, num) {
- if (num <= 0) {
- throw new Error('The number of values should be positive.');
- }
- const attrs = { start, stop, num };
- return ENGINE.runKernelFunc(backend => backend.linspace(start, stop, num), {} /* inputs */, null /* grad */, LinSpace, attrs);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Normalizes the activation of a local neighborhood across or within
- * channels.
- *
- * @param x The input tensor. The 4-D input tensor is treated as a 3-D array
- * of 1D vectors (along the last dimension), and each vector is
- * normalized independently.
- * @param depthRadius The number of adjacent channels in the 1D normalization
- * window.
- * @param bias A constant bias term for the basis.
- * @param alpha A scale factor, usually positive.
- * @param beta An exponent.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function localResponseNormalization_(x, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
- const $x = convertToTensor(x, 'x', 'localResponseNormalization');
- assert($x.rank === 4 || $x.rank === 3, () => `Error in localResponseNormalization: x must be rank 3 or 4 but got
- rank ${$x.rank}.`);
- assert(isInt(depthRadius), () => `Error in localResponseNormalization: depthRadius must be an ` +
- `integer but got depthRadius ${depthRadius}.`);
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- const forward = (backend, save) => {
- const y = backend.localResponseNormalization4D(x4D, depthRadius, bias, alpha, beta);
- save([x4D, y]);
- return y;
- };
- const inputs = { x: x4D };
- const attrs = { depthRadius, bias, alpha, beta };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, LRN, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- else {
- return res;
- }
- }
- const localResponseNormalization = op({ localResponseNormalization_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, Math.E]);
- *
- * x.log().print(); // or tf.log(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function log_(x) {
- const $x = convertToTensor(x, 'x', 'log');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.log($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Log);
- }
- const log = op({ log_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes natural logarithm of the input `tf.Tensor` plus one
- * element-wise: `ln(1 + x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, Math.E - 1]);
- *
- * x.log1p().print(); // or tf.log1p(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function log1p_(x) {
- const $x = convertToTensor(x, 'x', 'log1p');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.log1p($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Log1p);
- }
- const log1p = op({ log1p_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
- * gradient of `f(x)` with respect to `x`.
- *
- * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
- * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
- * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
- *
- * ```js
- * // f(x) = x ^ 2
- * const f = x => x.square();
- * // f'(x) = 2x
- * const g = tf.grad(f);
- *
- * const x = tf.tensor1d([2, 3]);
- * g(x).print();
- * ```
- *
- * ```js
- * // f(x) = x ^ 3
- * const f = x => x.pow(tf.scalar(3, 'int32'));
- * // f'(x) = 3x ^ 2
- * const g = tf.grad(f);
- * // f''(x) = 6x
- * const gg = tf.grad(g);
- *
- * const x = tf.tensor1d([2, 3]);
- * gg(x).print();
- * ```
- *
- * @param f The function f(x), to compute gradient for.
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function grad(f) {
- assert(isFunction(f), () => 'The f passed in grad(f) must be a function');
- return (x, dy) => {
- // x can be of any dtype, thus null as the last argument.
- const $x = convertToTensor(x, 'x', 'tf.grad', null);
- const $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null;
- return ENGINE.tidy(() => {
- const { value, grads } = ENGINE.gradients(() => f($x), [$x], $dy);
- if ($dy != null) {
- assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
- 'returned by f(x)');
- }
- checkGrads(grads);
- return grads[0];
- });
- };
- }
- /**
- * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
- * which gives an array of gradients of `f()` with respect to each input
- * [`x1`,`x2`,...].
- *
- * If `dy` is passed when calling `g()`, the gradient of
- * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
- * The provided `f` must take one or more tensors and return a single tensor
- * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
- *
- * ```js
- * // f(a, b) = a * b
- * const f = (a, b) => a.mul(b);
- * // df / da = b, df / db = a
- * const g = tf.grads(f);
- *
- * const a = tf.tensor1d([2, 3]);
- * const b = tf.tensor1d([-2, -3]);
- * const [da, db] = g([a, b]);
- * console.log('da');
- * da.print();
- * console.log('db');
- * db.print();
- * ```
- *
- * @param f The function `f(x1, x2,...)` to compute gradients for.
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function grads(f) {
- assert(isFunction(f), () => 'The f passed in grads(f) must be a function');
- return (args, dy) => {
- assert(Array.isArray(args), () => 'The args passed in grads(f)(args) must be an array ' +
- 'of `Tensor`s or `TensorLike`s');
- // args can be of any dtype, thus null as the last argument.
- const $args = convertToTensorArray(args, 'args', 'tf.grads', null);
- const $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null;
- return ENGINE.tidy(() => {
- const { value, grads } = ENGINE.gradients(() => f(...$args), $args, $dy);
- if ($dy != null) {
- assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
- 'match the shape returned by f([x1,...])');
- }
- checkGrads(grads);
- return grads;
- });
- };
- }
- /**
- * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
- * returns a metric you want to show.
- *
- * The result is a rich object with the following properties:
- * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
- * - value: The value returned by `f(x)`.
- *
- * ```js
- * // f(x) = x ^ 2
- * const f = x => x.square();
- * // f'(x) = 2x
- * const g = tf.valueAndGrad(f);
- *
- * const x = tf.tensor1d([2, 3]);
- * const {value, grad} = g(x);
- *
- * console.log('value');
- * value.print();
- * console.log('grad');
- * grad.print();
- * ```
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function valueAndGrad(f) {
- assert(isFunction(f), () => 'The f passed in valueAndGrad(f) must be a function');
- return (x, dy) => {
- assert(x instanceof Tensor, () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
- assert(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
- const { grads, value } = ENGINE.gradients(() => f(x), [x], dy);
- checkGrads(grads);
- return { grad: grads[0], value };
- };
- }
- /**
- * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
- * returns a metric you want to show.
- *
- * The result is a rich object with the following properties:
- * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
- * - value: The value returned by `f(x)`.
- *
- * ```js
- * // f(a, b) = a * b
- * const f = (a, b) => a.mul(b);
- * // df/da = b, df/db = a
- * const g = tf.valueAndGrads(f);
- *
- * const a = tf.tensor1d([2, 3]);
- * const b = tf.tensor1d([-2, -3]);
- * const {value, grads} = g([a, b]);
- *
- * const [da, db] = grads;
- *
- * console.log('value');
- * value.print();
- *
- * console.log('da');
- * da.print();
- * console.log('db');
- * db.print();
- * ```
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function valueAndGrads(f) {
- assert(isFunction(f), () => 'The f passed in valueAndGrads(f) must be a function');
- return (args, dy) => {
- assert(Array.isArray(args) && args.every(arg => arg instanceof Tensor), () => 'The args passed in valueAndGrads(f)(args) must be array of ' +
- 'tensors');
- assert(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
- const res = ENGINE.gradients(() => f(...args), args, dy);
- if (dy != null) {
- assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
- 'match the shape returned by f([x1,...])');
- }
- checkGrads(res.grads);
- return res;
- };
- }
- /**
- * Computes and returns the gradient of f(x) with respect to the list of
- * trainable variables provided by `varList`. If no list is provided, it
- * defaults to all trainable variables.
- *
- * ```js
- * const a = tf.variable(tf.tensor1d([3, 4]));
- * const b = tf.variable(tf.tensor1d([5, 6]));
- * const x = tf.tensor1d([1, 2]);
- *
- * // f(a, b) = a * x ^ 2 + b * x
- * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
- * // df/da = x ^ 2, df/db = x
- * const {value, grads} = tf.variableGrads(f);
- *
- * Object.keys(grads).forEach(varName => grads[varName].print());
- * ```
- *
- * @param f The function to execute. f() should return a scalar.
- * @param varList The list of variables to compute the gradients with respect
- * to. Defaults to all trainable variables.
- * @returns An object with the following keys and values:
- * - `value`: The value of the function `f`.
- * - `grads`: A map from the names of the variables to the gradients.
- * If the `varList` argument is provided explicitly and contains a subset of
- * non-trainable variables, this map in the return value will contain keys
- * that map the names of the non-trainable variables to `null`.
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function variableGrads(f, varList) {
- assert(isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
- assert(varList == null ||
- Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' +
- 'of variables');
- const specifiedVarList = varList != null;
- if (!specifiedVarList) {
- // Get all of the trainable variables.
- varList = [];
- for (const varName in ENGINE.registeredVariables) {
- varList.push(ENGINE.registeredVariables[varName]);
- }
- }
- const specifiedNonTrainable = specifiedVarList ? varList.filter(variable => !variable.trainable) : null;
- // Prune non-trainable variables.
- const originalVarCount = varList.length;
- varList = varList.filter(variable => variable.trainable);
- assert(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` +
- `be trainable, but none of the ${originalVarCount} variables is ` +
- `trainable.`);
- const allowNoGradients = true;
- const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients);
- assert(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' +
- 'the loss function y=f(x). Please make sure the operations that ' +
- 'use variables are inside the function f passed to minimize().');
- assert(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` +
- `returned a rank-${value.rank} tensor`);
- const namedGrads = {};
- varList.forEach((v, i) => {
- if (grads[i] != null) {
- namedGrads[v.name] = grads[i];
- }
- });
- if (specifiedNonTrainable != null) {
- // If varList is explicitly provided and contains non-trainable values,
- // add them to the returned gradients with `null` values.
- specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
- }
- return { value, grads: namedGrads };
- }
- /**
- * Overrides the gradient computation of a function `f`.
- *
- * Takes a function
- * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
- * and returns another function `g(...inputs)` which takes the same inputs as
- * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
- * with respect to each input of `f` are computed using `f().gradFunc`.
- *
- * The `save` function passsed to `f` should be used for saving tensors needed
- * in the gradient. And the `saved` passed to the `gradFunc` is a
- * `NamedTensorMap`, which contains those saved tensor.
- *
- * ```js
- * const customOp = tf.customGrad((x, save) => {
- * // Save x to make sure it's available later for the gradient.
- * save([x]);
- * // Override gradient of our custom x ^ 2 op to be dy * abs(x);
- * return {
- * value: x.square(),
- * // Note `saved.x` which points to the `x` we saved earlier.
- * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
- * };
- * });
- *
- * const x = tf.tensor1d([-1, -2, 3]);
- * const dx = tf.grad(x => customOp(x));
- *
- * console.log(`f(x):`);
- * customOp(x).print();
- * console.log(`f'(x):`);
- * dx(x).print();
- * ```
- *
- * @param f The function to evaluate in forward mode, which should return
- * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
- * returns the custom gradients of `f` with respect to its inputs.
- *
- * @doc {heading: 'Training', subheading: 'Gradients'}
- */
- function customGrad(f) {
- return ENGINE.customGrad(f);
- }
- function checkGrads(grads) {
- const numNullGradients = grads.filter(g => g == null).length;
- if (numNullGradients > 0) {
- throw new Error(`Cannot compute gradient of y=f(x) with respect to x. Make sure that
- the f you passed encloses all operations that lead from x to y.`);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes `-1 * x` element-wise.
- *
- * ```js
- * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
- *
- * x.neg().print(); // or tf.neg(x)
- * ```
- *
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function neg_(x) {
- const $x = convertToTensor(x, 'x', 'neg');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.neg($x), inputs, null /* grad */, Negate);
- }
- const neg = op({ neg_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.softplus().print(); // or tf.softplus(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function softplus_(x) {
- const $x = convertToTensor(x, 'x', 'softplus');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.softplus($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Softplus);
- }
- const softplus = op({ softplus_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes log sigmoid of the input `tf.Tensor` element-wise:
- * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.logSigmoid().print(); // or tf.logSigmoid(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function logSigmoid_(x) {
- const $x = convertToTensor(x, 'x', 'logSigmoid');
- // Use a custom gradient to maintain previous implementation.
- // There is no LogSigmoid kernel in TF so we can't use engine.runKernel
- // directly
- const customOp = customGrad((x) => {
- // TODO(yassogba) we can remove the chained softplus call here only
- // after backends have modualrized softplus at which point we can call
- // engine runKernel(..., Sotfplus, ...) directly.
- const value = neg(softplus(neg(x)));
- const gradFunc = (dy) => {
- const derX = mul(dy, sigmoid(neg(x)));
- return derX;
- };
- return { value, gradFunc };
- });
- return customOp($x);
- }
- const logSigmoid = op({ logSigmoid_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the maximum of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If `axes` has no entries, all dimensions are reduced, and an
- * `tf.Tensor` with a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.max().print(); // or tf.max(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.max(axis).print(); // or tf.max(x, axis)
- * ```
- *
- * @param x The input tensor.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function max_(x, axis = null, keepDims = false) {
- const $x = convertToTensor(x, 'x', 'max');
- const forward = (backend, save) => {
- const origAxes = parseAxisParam(axis, $x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- let maxInput = $x;
- if (permutedAxes != null) {
- maxInput = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, maxInput.rank);
- }
- const y = backend.max(maxInput, axes);
- if (permutedAxes != null) {
- maxInput.dispose();
- }
- let res = y;
- if (keepDims) {
- const expandedShape = expandShapeToKeepDim(res.shape, parseAxisParam(axis, $x.shape));
- res = reshape(res, expandedShape);
- y.dispose();
- }
- save([$x, res]);
- return res;
- };
- const inputs = { x: $x };
- const attrs = { reductionIndices: axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Max, attrs);
- }
- const max = op({ max_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([10, 20, 30, 40]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.sub(b).print(); // or tf.sub(a, b)
- * ```
- *
- * ```js
- * // Broadcast subtract a with b.
- * const a = tf.tensor1d([10, 20, 30, 40]);
- * const b = tf.scalar(5);
- *
- * a.sub(b).print(); // or tf.sub(a, b)
- * ```
- * @param a The first `tf.Tensor` to subtract from.
- * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
- * `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function sub_(a, b) {
- let $a = convertToTensor(a, 'a', 'sub');
- let $b = convertToTensor(b, 'b', 'sub');
- [$a, $b] = makeTypesMatch($a, $b);
- const forward = (backend, save) => {
- const res = backend.subtract($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Sub);
- }
- const sub = op({ sub_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the sum of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If axes has no entries, all dimensions are reduced, and a
- * `tf.Tensor` with a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.sum().print(); // or tf.sum(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.sum(axis).print(); // or tf.sum(x, axis)
- * ```
- *
- * @param x The input tensor to compute the sum over. If the dtype is `bool`
- * it will be converted to `int32` and the output dtype will be `int32`.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function sum_(x, axis = null, keepDims = false) {
- let $x = convertToTensor(x, 'x', 'sum');
- if ($x.dtype === 'bool') {
- $x = cast($x, 'int32');
- }
- const forward = (backend, save) => {
- save([$x]);
- const axes = parseAxisParam(axis, $x.shape);
- const permutation = getAxesPermutation(axes, $x.rank);
- let reductionAxes = axes;
- let permutedX = $x;
- if (permutation != null) {
- permutedX = transpose($x, permutation);
- reductionAxes = getInnerMostAxes(reductionAxes.length, $x.rank);
- }
- let value = backend.sum(permutedX, reductionAxes);
- if (keepDims) {
- const newShape = expandShapeToKeepDim(value.shape, axes);
- value = reshape(value, newShape);
- }
- return value;
- };
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Sum, attrs);
- }
- const sum$1 = op({ sum_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the log softmax.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * ```js
- * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * @param logits The logits array.
- * @param axis The dimension softmax would be performed on. Defaults to `-1`
- * which indicates the last dimension.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function logSoftmax_(logits, axis = -1) {
- const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
- if (axis === -1) {
- axis = $logits.rank - 1;
- }
- if (axis !== $logits.rank - 1) {
- throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
- `Logits was rank ${$logits.rank} and axis was ${axis}`);
- }
- const forward = (backend, save) => {
- const keepDims = true;
- const xMax = max(logits, axis, true);
- const shifted = sub(logits, xMax);
- const value = sub(cast(shifted, 'float32'), log(sum$1(exp(shifted), axis, keepDims)));
- save([value]);
- return value;
- };
- const inputs = { logits: $logits };
- const attrs = { axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, LogSoftmax, attrs);
- }
- const logSoftmax = op({ logSoftmax_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the log(sum(exp(elements across the reduction dimensions)).
- *
- * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
- * is true, the rank of the array is reduced by 1 for each entry in `axis`.
- * If `keepDims` is true, the reduced dimensions are retained with length 1.
- * If `axis` has no entries, all dimensions are reduced, and an array with a
- * single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.logSumExp().print(); // or tf.logSumExp(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis)
- * ```
- * @param x The input tensor.
- * @param axis The dimension(s) to reduce. If null (the default),
- * reduces all dimensions.
- * @param keepDims If true, retains reduced dimensions with length
- * of 1. Defaults to false.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function logSumExp_(x, axis = null, keepDims = false) {
- const $x = convertToTensor(x, 'x', 'logSumExp');
- const axes = parseAxisParam(axis, $x.shape);
- const xMax = max($x, axes, true /* keepDims */);
- const a = sub($x, xMax);
- const b = exp(a);
- const c = sum$1(b, axes);
- const d = log(c);
- const res = add$1(reshape(xMax, d.shape), d);
- if (keepDims) {
- const newShape = expandShapeToKeepDim(res.shape, axes);
- return reshape(res, newShape);
- }
- return res;
- }
- const logSumExp = op({ logSumExp_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([false, false, true, true], 'bool');
- * const b = tf.tensor1d([false, true, false, true], 'bool');
- *
- * a.logicalAnd(b).print();
- * ```
- *
- * @param a The first input tensor. Must be of dtype bool.
- * @param b The second input tensor. Must be of dtype bool.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function logicalAnd_(a, b) {
- const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
- const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(backend => backend.logicalAnd($a, $b), inputs, null /* grad */, LogicalAnd);
- }
- const logicalAnd = op({ logicalAnd_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of `NOT x` element-wise.
- *
- * ```js
- * const a = tf.tensor1d([false, true], 'bool');
- *
- * a.logicalNot().print();
- * ```
- *
- * @param x The input tensor. Must be of dtype 'bool'.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function logicalNot_(x) {
- const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.logicalNot($x), inputs, null /* grad */, LogicalNot);
- }
- const logicalNot = op({ logicalNot_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([false, false, true, true], 'bool');
- * const b = tf.tensor1d([false, true, false, true], 'bool');
- *
- * a.logicalOr(b).print();
- * ```
- * @param a The first input tensor. Must be of dtype bool.
- * @param b The second input tensor. Must be of dtype bool.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function logicalOr_(a, b) {
- const $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
- const $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(backend => backend.logicalOr($a, $b), inputs, null /* grad */, LogicalOr);
- }
- const logicalOr = op({ logicalOr_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([false, false, true, true], 'bool');
- * const b = tf.tensor1d([false, true, false, true], 'bool');
- *
- * a.logicalXor(b).print();
- * ```
- *
- * @param a The first input tensor. Must be of dtype bool.
- * @param b The second input tensor. Must be of dtype bool.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function logicalXor_(a, b) {
- const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
- const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
- assertAndGetBroadcastShape($a.shape, $b.shape);
- // x ^ y = (x | y) & ~(x & y)
- return logicalAnd(logicalOr(a, b), logicalNot(logicalAnd(a, b)));
- }
- const logicalXor = op({ logicalXor_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the 2D max pooling of an image.
- *
- * @param x The input tensor, of rank 4 or rank 3 of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- */
- function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
- const $x = convertToTensor(x, 'x', 'maxPool');
- const dilations = 1;
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x4D.rank}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in maxPool: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- const convInfo = computePool2DInfo(x4D.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
- let y;
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- y = x4D.clone();
- }
- else {
- y = backend.maxPool(x4D, convInfo);
- }
- save([x4D, y]);
- return y;
- };
- const inputs = { x: x4D };
- const attrs = { filterSize, strides, pad, dimRoundingMode };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, MaxPool, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const maxPool = op({ maxPool_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the 3D max pooling.
- *
- * ```js
- * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
- * const result = tf.maxPool3d(x, 2, 1, 'valid');
- * result.print();
- * ```
- *
- * @param x The input tensor, of rank 5 or rank 4 of shape
- * `[batch, depth, height, width, inChannels]`.
- * @param filterSize The filter size:
- * `[filterDepth, filterHeight, filterWidth]`.
- * If `filterSize` is a single number,
- * then `filterDepth == filterHeight == filterWidth`.
- * @param strides The strides of the pooling:
- * `[strideDepth, strideHeight, strideWidth]`.
- * If `strides` is a single number,
- * then `strideDepth == strideHeight == strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1*1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
- * "NDHWC". Specify the data format of the input and output data. With the
- * default format "NDHWC", the data is stored in the order of: [batch,
- * depth, height, width, channels]. Only "NDHWC" is currently supported.
- * @param dilations Deprecated, this field will be gone in v3.0.0.
- * The dilation rates: `[dilationDepth, dilationHeight, dilationWidth]`
- * in which we sample input values across the depth, height and width
- * dimensions in dilated pooling.
- * Defaults to `[1, 1, 1]`. If `dilations` is a single number,
- * then `dilationDepth == dilationHeight == dilationWidth`.
- * If it is greater than 1, then all values of `strides` must be 1.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function maxPool3d_(x, filterSize = [1, 1, 1], strides, pad, dimRoundingMode, dataFormat = 'NDHWC', dilations) {
- if (dilations == null) {
- dilations = [1, 1, 1];
- }
- else {
- deprecationWarn('dilations is deprecated, this field will be gone in ' +
- 'v3.0.0.');
- }
- const $x = convertToTensor(x, 'x', 'maxPool3d');
- let x5D = $x;
- let reshapedTo5D = false;
- if ($x.rank === 4) {
- reshapedTo5D = true;
- x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
- }
- assert(x5D.rank === 5, () => `Error in maxPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
- assert(dataFormat === 'NDHWC', () => `Error in maxPool3d: Only NDHWC is currently supported, ` +
- `but got dataFormat of ${dataFormat}`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool3d: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in maxPool3d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = (backend, save) => {
- if (dilations == null) {
- dilations = [1, 1, 1];
- }
- const convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
- const y = backend.maxPool3d(x5D, convInfo);
- save([x5D, y]);
- return y;
- };
- const inputs = { x: x5D };
- const attrs = { filterSize, strides, pad, dimRoundingMode, dataFormat, dilations };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, MaxPool3D, attrs);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const maxPool3d = op({ maxPool3d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the 2D max pooling of an image with Argmax index.
- * The indices in argmax are flattened, so that a maximum value at position `[b,
- * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
- * include_batch_in_index is False; `((b * height + y) * width + x) * channels
- * +c` if include_batch_in_index is True.
- *
- * The indices returned are always in `[0, height) x [0, width)` before
- * flattening.
- *
- * @param x The input tensor, of rank 4 or rank 3 of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
- * "NDHWC". Specify the data format of the input and output data. With the
- * default format "NDHWC", the data is stored in the order of: [batch,
- * depth, height, width, channels]. Only "NDHWC" is currently supported.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param includeBatchIndex Defaults to False. Whether to include batch
- * dimension in flattened index of argmax.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function maxPoolWithArgmax_(x, filterSize, strides, pad, includeBatchInIndex = false) {
- const $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
- const inputs = { x: $x };
- const attrs = { filterSize, strides, pad, includeBatchInIndex };
- const result = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
- return { result: result[0], indexes: result[1] };
- }
- const maxPoolWithArgmax = op({ maxPoolWithArgmax_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with all elements set to 0.
- *
- * ```js
- * tf.zeros([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The type of an element in the resulting tensor. Can
- * be 'float32', 'int32' or 'bool'. Defaults to 'float'.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function zeros(shape, dtype = 'float32') {
- if (dtype === 'complex64') {
- const real = zeros(shape, 'float32');
- const imag = zeros(shape, 'float32');
- return complex(real, imag);
- }
- const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
- return ENGINE.makeTensor(values, shape, dtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with all elements set to 1.
- *
- * ```js
- * tf.ones([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The type of an element in the resulting tensor. Defaults to
- * 'float'.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function ones$1(shape, dtype = 'float32') {
- if (dtype === 'complex64') {
- const real = ones$1(shape, 'float32');
- const imag = zeros(shape, 'float32');
- return complex(real, imag);
- }
- const values = makeOnesTypedArray(sizeFromShape(shape), dtype);
- return ENGINE.makeTensor(values, shape, dtype);
- }
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the mean of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
- * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
- * If `keepDims` is true, the reduced dimensions are retained with length 1.
- * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
- * a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.mean().print(); // or tf.mean(a)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.mean(axis).print(); // or tf.mean(x, axis)
- * ```
- *
- * @param x The input tensor.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function mean_(x, axis = null, keepDims = false) {
- const $x = convertToTensor(x, 'x', 'mean');
- const axes = parseAxisParam(axis, $x.shape);
- const shapes = computeOutAndReduceShapes($x.shape, axes);
- const reduceShape = shapes[1];
- const reduceSize = sizeFromShape(reduceShape);
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- const forward = () => {
- const reduceSizeScalar = scalar(reduceSize);
- // Cast if needed.
- const xReduce = reduceSizeScalar.dtype === $x.dtype ?
- $x :
- cast($x, reduceSizeScalar.dtype);
- const res = div(xReduce, reduceSizeScalar);
- return sum$1(res, axis, keepDims);
- };
- // Use a custom gradient to bypass 2 gradient backprops since mean is used
- // extremely often.
- const customOp = customGrad((x) => {
- const value = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Mean, attrs);
- const gradFunc = (dy) => {
- const expandedDyShape = x.shape.slice();
- axes.forEach(axis => {
- expandedDyShape[axis] = 1;
- });
- const expandedDy = reshape(dy, expandedDyShape);
- const derX = div(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
- return derX;
- };
- return { value, gradFunc };
- });
- return customOp($x);
- }
- const mean = op({ mean_ });
-
- /**
- * Computes the minimum value from the input.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the array is reduced by 1 for each entry in `axes`.
- * If `keepDims` is true, the reduced dimensions are retained with length 1.
- * If `axes` has no entries, all dimensions are reduced, and an array with a
- * single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.min().print(); // or tf.min(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.min(axis).print(); // or tf.min(x, axis)
- * ```
- *
- * @param x The input Tensor.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function min_(x, axis = null, keepDims = false) {
- const $x = convertToTensor(x, 'x', 'min');
- const forward = (backend, save) => {
- const origAxes = parseAxisParam(axis, $x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, $x.rank);
- let minInput = $x;
- if (permutedAxes != null) {
- minInput = transpose($x, permutedAxes);
- axes = getInnerMostAxes(axes.length, $x.rank);
- }
- const y = backend.min(minInput, axes);
- if (permutedAxes != null) {
- minInput.dispose();
- }
- let res = y;
- if (keepDims) {
- const expandedShape = expandShapeToKeepDim(res.shape, origAxes);
- res = reshape(y, expandedShape);
- y.dispose();
- }
- save([$x, res]);
- return res;
- };
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Min, attrs);
- }
- const min = op({ min_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the min of a and b (`a < b ? a : b`) element-wise.
- * Supports broadcasting.
- *
- * We also expose `minimumStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.minimum(b).print(); // or tf.minimum(a, b)
- * ```
- *
- * ```js
- * // Broadcast minimum a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.minimum(b).print(); // or tf.minimum(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function minimum_(a, b) {
- let $a = convertToTensor(a, 'a', 'minimum');
- let $b = convertToTensor(b, 'b', 'minimum');
- [$a, $b] = makeTypesMatch($a, $b);
- if ($a.dtype === 'bool') {
- $a = cast($a, 'int32');
- $b = cast($b, 'int32');
- }
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = (backend, save) => {
- const res = backend.minimum($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Minimum);
- }
- const minimum = op({ minimum_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Pads a `tf.Tensor` using mirror padding.
- *
- * This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
- *
- * ```js
- * const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
- * x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
- * ```
- * @param x The tensor to pad.
- * @param paddings An array of length `R` (the rank of the tensor), where
- * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
- * specifying how much to pad along each dimension of the tensor.
- * In "reflect" mode, the padded regions do not include the borders,
- * while in "symmetric" mode the padded regions do include the borders.
- * For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
- * then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
- * `[1, 2, 3, 3, 2]` in "symmetric" mode.
- * If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
- * must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
- * then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
- * `x.shape[D]`
- * @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function mirrorPad_(x, paddings, mode) {
- assert(mode === 'reflect' || mode === 'symmetric', () => `Invalid mode. Mode must be either reflect or symmetric. ` +
- `Got ${mode}.`);
- const $x = convertToTensor(x, 'x', 'mirrorPad');
- if ($x.rank === 0) {
- throw new Error('mirrorPad(scalar) is not defined. ' +
- 'Pass non-scalar to mirrorPad');
- }
- assert(paddings.length === $x.rank, () => `Padding doesn't match input. Must be ${$x.rank}. ` +
- `Got ${paddings.length}.`);
- const shapeOffset = mode === 'reflect' ? 1 : 0;
- for (let i = 0; i < $x.rank; i++) {
- assert(paddings[i].length === 2, () => `Invalid number of paddings. Must be length of 2 each.`);
- assert(paddings[i][0] >= 0 && paddings[i][0] <= $x.shape[i] - shapeOffset &&
- paddings[i][1] >= 0 && paddings[i][1] <= $x.shape[i] - shapeOffset, () => `Padding in dimension ${i} cannot be greater than or equal ` +
- `to ${$x.shape[i] - shapeOffset} or less than 0 for input of ` +
- `shape ${$x.shape}`);
- }
- const attrs = { paddings, mode };
- const inputs = { x: $x };
- return ENGINE.runKernel(MirrorPad, inputs, attrs);
- }
- const mirrorPad = op({ mirrorPad_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the mod of a and b element-wise.
- * `floor(x / y) * y + mod(x, y) = x`
- * Supports broadcasting.
- *
- * We also expose `tf.modStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.mod(b).print(); // or tf.mod(a, b)
- * ```
- *
- * ```js
- * // Broadcast a mod b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.mod(b).print(); // or tf.mod(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function mod_(a, b) {
- let $a = convertToTensor(a, 'a', 'mod');
- let $b = convertToTensor(b, 'b', 'mod');
- [$a, $b] = makeTypesMatch($a, $b);
- const forward = (backend, save) => {
- const res = backend.mod($a, $b);
- save([$a, $b]);
- return res;
- };
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Mod);
- }
- const mod = op({ mod_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes square of `x` element-wise: `x ^ 2`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
- *
- * x.square().print(); // or tf.square(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function square_(x) {
- const $x = convertToTensor(x, 'x', 'square');
- const attrs = {};
- const inputsToSave = [$x];
- const outputsToSave = [];
- return ENGINE.runKernelFunc((backend, save) => {
- save([$x]);
- return backend.square($x);
- }, { x: $x }, null /* grad */, 'Square', attrs, inputsToSave, outputsToSave);
- }
- const square = op({ square_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Calculates the mean and variance of `x`. The mean and variance are
- * calculated by aggregating the contents of `x` across `axes`. If `x` is
- * 1-D and `axes = [0]` this is just the mean and variance of a vector.
- *
- * @param x The input tensor.
- * @param axis The dimension(s) along with to compute mean and
- * variance. By default it reduces all dimensions.
- * @param keepDims If true, the moments have the same dimensionality as the
- * input.
- * @return An object with two keys: `mean` and `variance`.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function moments_(x, axis = null, keepDims = false) {
- x = convertToTensor(x, 'x', 'moments');
- const axes = parseAxisParam(axis, x.shape);
- const xMean = mean(x, axes, keepDims);
- let keepDimsShape = xMean.shape;
- if (!keepDims) {
- keepDimsShape = expandShapeToKeepDim(xMean.shape, axes);
- }
- const devSquared = square(sub(cast(x, 'float32'), reshape(xMean, keepDimsShape)));
- const variance = mean(devSquared, axes, keepDims);
- return { mean: xMean, variance };
- }
- const moments = op({ moments_ });
-
- /**
- * Computes the next states and outputs of a stack of LSTMCells.
- *
- * Each cell output is used as input to the next cell.
- *
- * Returns `[cellState, cellOutput]`.
- *
- * Derived from tf.contrib.rn.MultiRNNCell.
- *
- * @param lstmCells Array of LSTMCell functions.
- * @param data The input to the cell.
- * @param c Array of previous cell states.
- * @param h Array of previous cell outputs.
- *
- * @doc {heading: 'Operations', subheading: 'RNN'}
- */
- function multiRNNCell_(lstmCells, data, c, h) {
- const $data = convertToTensor(data, 'data', 'multiRNNCell');
- const $c = convertToTensorArray(c, 'c', 'multiRNNCell');
- const $h = convertToTensorArray(h, 'h', 'multiRNNCell');
- let input = $data;
- const newStates = [];
- for (let i = 0; i < lstmCells.length; i++) {
- const output = lstmCells[i](input, $c[i], $h[i]);
- newStates.push(output[0]);
- newStates.push(output[1]);
- input = output[1];
- }
- const newC = [];
- const newH = [];
- for (let i = 0; i < newStates.length; i += 2) {
- newC.push(newStates[i]);
- newH.push(newStates[i + 1]);
- }
- return [newC, newH];
- }
- const multiRNNCell = op({ multiRNNCell_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
- *
- * ```js
- * const probs = tf.tensor([.75, .25]);
- * tf.multinomial(probs, 3).print();
- * ```
- *
- * @param logits 1D array with unnormalized log-probabilities, or
- * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
- * parameter.
- * @param numSamples Number of samples to draw for each row slice.
- * @param seed The seed number.
- * @param normalized Whether the provided `logits` are normalized true
- * probabilities (sum to 1). Defaults to false.
- * @return 1D array of shape `[numSamples]`, or 2D array of shape
- * `[batchSize, numSamples]`, depending on the rank of the input.
- *
- * @doc {heading: 'Tensors', subheading: 'Random'}
- */
- function multinomial_(logits, numSamples, seed, normalized = false) {
- const $logits = convertToTensor(logits, 'logits', 'multinomial');
- const numOutcomes = $logits.size;
- const origRank = $logits.rank;
- if (numOutcomes < 2) {
- throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ` +
- `${numOutcomes}.`);
- }
- if (origRank > 2) {
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);
- }
- seed = seed || Math.random();
- const logits2D = origRank === 1 ? reshape($logits, [1, -1]) : $logits;
- const res = ENGINE.runKernelFunc(backend => backend.multinomial(logits2D, normalized, numSamples, seed), { logits2D });
- // tslint:disable-next-line:no-unnecessary-type-assertion
- return origRank === 1 ? reshape(res, [res.size]) : res;
- }
- const multinomial = op({ multinomial_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the truth value of (a != b) element-wise. Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([0, 2, 3]);
- *
- * a.notEqual(b).print();
- * ```
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same dtype as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- function notEqual_(a, b) {
- let $a = convertToTensor(a, 'a', 'notEqual');
- let $b = convertToTensor(b, 'b', 'notEqual');
- [$a, $b] = makeTypesMatch($a, $b);
- assertAndGetBroadcastShape($a.shape, $b.shape);
- const forward = (backend) => backend.notEqual($a, $b);
- const inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, NotEqual);
- }
- const notEqual = op({ notEqual_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the real part of a complex (or real) tensor.
- *
- * Given a tensor input, this operation returns a tensor of type float that is
- * the real part of each element in input considered as a complex number.
- *
- * If the input is real, it simply makes a clone.
- *
- * ```js
- * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
- * tf.real(x).print();
- * ```
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function real_(input) {
- const $input = convertToTensor(input, 'input', 'real');
- const forward = (backend) => {
- return backend.real($input);
- };
- const inputs = { input: $input };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Real);
- }
- const real = op({ real_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
- * given tensor.
- *
- * ```js
- * const x = tf.tensor([1, 2]);
- * tf.onesLike(x).print();
- * ```
- * @param x A tensor.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function onesLike_(x) {
- const $x = convertToTensor(x, 'x', 'onesLike');
- const forward = (backend, save) => {
- if ($x.dtype === 'complex64') {
- const r = onesLike(real($x));
- const i = zerosLike(imag($x));
- return complex(r, i);
- }
- return backend.onesLike($x);
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, OnesLike);
- }
- const onesLike = op({ onesLike_ });
-
- /**
- * Computes the outer product of two vectors, `v1` and `v2`.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- * const b = tf.tensor1d([3, 4, 5]);
- *
- * tf.outerProduct(a, b).print();
- * ```
- * @param v1 The first vector in the outer product operation.
- * @param v2 The second vector in the outer product operation.
- *
- * @doc {heading: 'Operations', subheading: 'Matrices'}
- */
- function outerProduct_(v1, v2) {
- const $v1 = convertToTensor(v1, 'v1', 'outerProduct');
- const $v2 = convertToTensor(v2, 'v2', 'outerProduct');
- assert($v1.rank === 1 && $v2.rank === 1, () => `Error in outerProduct: inputs must be rank 1, but got ranks ` +
- `${$v1.rank} and ${$v2.rank}.`);
- const v12D = reshape($v1, [-1, 1]);
- const v22D = reshape($v2, [1, -1]);
- return matMul(v12D, v22D);
- }
- const outerProduct = op({ outerProduct_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Pads a `tf.Tensor` with a given value and paddings.
- *
- * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
- * refer to `tf.mirrorPad`
- *
- * Also available are stricter rank-specific methods with the same signature
- * as this method that assert that `paddings` is of given length.
- * - `tf.pad1d`
- * - `tf.pad2d`
- * - `tf.pad3d`
- * - `tf.pad4d`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * x.pad([[1, 2]]).print();
- * ```
- * @param x The tensor to pad.
- * @param paddings An array of length `R` (the rank of the tensor), where
- * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
- * specifying how much to pad along each dimension of the tensor.
- * @param constantValue The pad value to use. Defaults to 0.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function pad_(x, paddings, constantValue = 0) {
- const $x = convertToTensor(x, 'x', 'pad');
- if ($x.rank === 0) {
- throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
- }
- const forward = (backend, save) => {
- save([$x]);
- return backend.pad($x, paddings, constantValue);
- };
- const attrs = { paddings, constantValue };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, PadV2, attrs);
- }
- const pad = op({ pad_ });
-
- /**
- * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
- */
- function pad1d_(x, paddings, constantValue = 0) {
- assert(paddings.length === 2, () => 'Invalid number of paddings. Must be length of 2.');
- return pad(x, [paddings], constantValue);
- }
- const pad1d = op({ pad1d_ });
-
- /**
- * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
- */
- function pad2d_(x, paddings, constantValue = 0) {
- assert(paddings.length === 2 && paddings[0].length === 2 &&
- paddings[1].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
- return pad(x, paddings, constantValue);
- }
- const pad2d = op({ pad2d_ });
-
- /**
- * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
- */
- function pad3d_(x, paddings, constantValue = 0) {
- assert(paddings.length === 3 && paddings[0].length === 2 &&
- paddings[1].length === 2 && paddings[2].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
- return pad(x, paddings, constantValue);
- }
- const pad3d = op({ pad3d_ });
-
- /**
- * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
- */
- function pad4d_(x, paddings, constantValue = 0) {
- assert(paddings.length === 4 && paddings[0].length === 2 &&
- paddings[1].length === 2 && paddings[2].length === 2 &&
- paddings[3].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
- return pad(x, paddings, constantValue);
- }
- const pad4d = op({ pad4d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
- * a grid of blocks of shape `blockShape`, and interleaves these blocks with
- * the "batch" dimension (0) such that in the output, the spatial
- * dimensions `[1, ..., M]` correspond to the position within the grid,
- * and the batch dimension combines both the position within a spatial block
- * and the original batch position. Prior to division into blocks,
- * the spatial dimensions of the input are optionally zero padded
- * according to `paddings`. See below for a precise description.
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
- * const blockShape = [2, 2];
- * const paddings = [[0, 0], [0, 0]];
- *
- * x.spaceToBatchND(blockShape, paddings).print();
- * ```
- *
- * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
- * remainingShape`, where spatialShape has `M` dimensions.
- * @param blockShape A 1-D array. Must have shape `[M]`, all values must
- * be >= 1.
- * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
- * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
- * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
- * is required that
- * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
- *
- * This operation is equivalent to the following steps:
- *
- * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
- * according to `paddings` to produce `padded` of shape paddedShape.
- *
- * 2. Reshape `padded` to `reshapedPadded` of shape:
- * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
- *
- * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
- * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- *
- * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
- * batch dimension, producing an output tensor of shape:
- * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- function spaceToBatchND_(x, blockShape, paddings) {
- const $x = convertToTensor(x, 'x', 'spaceToBatchND');
- assert($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
- assert(paddings.length === blockShape.length, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);
- assert($x.shape.reduce((a, b, i) => {
- if (i > 0 && i <= blockShape.length) {
- return a &&
- ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
- blockShape[i - 1] ===
- 0);
- }
- return a;
- }, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);
- const forward = backend => backend.spaceToBatchND($x, blockShape, paddings);
- const inputs = { x: $x };
- const attrs = { blockShape, paddings };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, SpaceToBatchND, attrs);
- }
- const spaceToBatchND = op({ spaceToBatchND_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Performs an N-D pooling operation
- *
- * @param input The input tensor, of rank 4 or rank 3 of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param windowShape The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param poolingType The type of pooling, either 'max' or 'avg'.
- * @param pad The type of padding algorithm:
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function pool_(input, windowShape, poolingType, pad, dilations, strides) {
- if (dilations == null) {
- dilations = [1, 1];
- }
- if (strides == null) {
- strides = 1;
- }
- if (pad === 0) {
- pad = 'valid';
- }
- const $x = convertToTensor(input, 'x', 'maxPool');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in pool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
- const dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
- // The following implementation does batchToSpace(pool(spaceToBatch(x)))
- // whenever dilation > 1 since the TF kernels do not support dilation > 1.
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
- let basePadding;
- if (pad === 'same') {
- basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation);
- }
- else {
- basePadding = [[0, 0], [0, 0]];
- }
- const isDilationOne = dilation[0] === 1 && dilation[1] === 1;
- const [adjustedPadding, adjustedCrops] = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding);
- const convertedPad = isDilationOne ? pad : 'valid';
- const convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);
- const forwardOp = poolingType === 'avg' ?
- () => avgPool(convertedX, windowShape, strides, convertedPad) :
- () => maxPool(convertedX, windowShape, strides, convertedPad);
- const y = forwardOp();
- const res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- // Helper function to compute crops and paddings for pool with dilation > 1.
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
- function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
- const padStart = basePadding.map(b => b[0]);
- const origPadEnd = basePadding.map(b => b[1]);
- const fullInputShape = inputShape.concat(padStart, origPadEnd);
- const padEndExtra = blockShape.map((b, i) => (b - fullInputShape[i] % b) % b);
- const padEnd = origPadEnd.map((s, i) => s + padEndExtra[i]);
- const paddings = blockShape.map((_, i) => [padStart[i], padEnd[i]]);
- const crops = blockShape.map((_, i) => [0, padEndExtra[i]]);
- return [paddings, crops];
- }
- // Helper function to compute base paddings for pool with dilation > 1.
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
- function withSpaceToBatchBasePaddings(filterShape, dilation) {
- // Spatial dimensions of the filters and the upsampled filters in which we
- // introduce (rate - 1) zeros between consecutive filter values.
- const dilatedFilterShape = filterShape.map((s, i) => {
- return s + (s - 1) * (dilation[i] - 1);
- });
- const padExtraShape = dilatedFilterShape.map(s => s - 1);
- // When padding is odd, we pad more at end, following the same
- // convention as conv2d.
- const padExtraStart = padExtraShape.map(s => Math.floor(s / 2));
- const padExtraEnd = padExtraShape.map((s, i) => s - padExtraStart[i]);
- return padExtraShape.map((_, i) => {
- return [padExtraStart[i], padExtraEnd[i]];
- });
- }
- const pool = op({ pool_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
- *
- * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
- * corresponding elements in x and y. The result's dtype will be the upcasted
- * type of the `base` and `exp` dtypes.
- *
- * ```js
- * const a = tf.tensor([[2, 3], [4, 5]])
- * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
- *
- * a.pow(b).print(); // or tf.pow(a, b)
- * ```
- *
- * ```js
- * const a = tf.tensor([[1, 2], [3, 4]])
- * const b = tf.tensor(2).toInt();
- *
- * a.pow(b).print(); // or tf.pow(a, b)
- * ```
- * We also expose `powStrict` which has the same signature as this op and
- * asserts that `base` and `exp` are the same shape (does not broadcast).
- *
- * @param base The base `tf.Tensor` to pow element-wise.
- * @param exp The exponent `tf.Tensor` to pow element-wise.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
- function pow_(base, exp) {
- let $base = convertToTensor(base, 'base', 'pow');
- let $exp = convertToTensor(exp, 'exp', 'pow');
- [$base, $exp] = makeTypesMatch($base, $exp);
- const inputs = { a: $base, b: $exp };
- const forward = (backend, save) => {
- const y = backend.pow($base, $exp);
- save([$base, $exp, y]);
- return y;
- };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Pow);
- }
- const pow = op({ pow_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes leaky rectified linear element-wise with parametric alphas.
- *
- * `x < 0 ? alpha * x : f(x) = x`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- * const alpha = tf.scalar(0.1);
- *
- * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
- * ```
- * @param x The input tensor.
- * @param alpha Scaling factor for negative values.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function prelu_(x, alpha) {
- const $x = convertToTensor(x, 'x', 'prelu');
- const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
- const forward = (backend, save) => {
- const res = backend.prelu($x, $alpha);
- save([$x, $alpha]);
- return res;
- };
- const inputs = { x: $x, alpha: $alpha };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Prelu);
- }
- const prelu = op({ prelu_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the product of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If `axes` has no entries, all dimensions are reduced, and a
- * `tf.Tensor` with a single element is returned.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3]);
- *
- * x.prod().print(); // or tf.prod(x)
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.prod(axis).print(); // or tf.prod(x, axis)
- * ```
- *
- * @param x The input tensor to compute the product over. If the dtype is `bool`
- * it will be converted to `int32` and the output dtype will be `int32`.
- * @param axis The dimension(s) to reduce. By default it reduces
- * all dimensions.
- * @param keepDims If true, retains reduced dimensions with size 1.
- *
- * @doc {heading: 'Operations', subheading: 'Reduction'}
- */
- function prod_(x, axis = null, keepDims = false) {
- let $x = convertToTensor(x, 'x', 'prod');
- const forward = (backend) => {
- if ($x.dtype === 'bool') {
- $x = cast($x, 'int32');
- }
- const axes = parseAxisParam(axis, $x.shape);
- const permutation = getAxesPermutation(axes, $x.rank);
- let reductionAxes = axes;
- let permutedX = $x;
- if (permutation != null) {
- permutedX = transpose($x, permutation);
- reductionAxes = getInnerMostAxes(reductionAxes.length, $x.rank);
- }
- let value = backend.prod(permutedX, reductionAxes);
- if (keepDims) {
- const newShape = expandShapeToKeepDim(value.shape, axes);
- value = reshape(value, newShape);
- }
- return value;
- };
- const inputs = { x: $x };
- const attrs = { axis, keepDims };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Prod, attrs);
- }
- const prod = op({ prod_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values sampled from a random number generator
- * function defined by the user.
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param randFunction A random number generator function which is called
- * for each element in the output tensor.
- * @param dtype The data type of the output tensor. Defaults to 'float32'.
- */
- function rand_(shape, randFunction, dtype) {
- const size = sizeFromShape(shape);
- let values = null;
- if (dtype == null || dtype === 'float32') {
- values = new Float32Array(size);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(size);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(size);
- }
- else {
- throw new Error(`Unknown data type ${dtype}`);
- }
- for (let i = 0; i < size; i++) {
- values[i] = randFunction();
- }
- return ENGINE.makeTensor(values, shape, dtype);
- }
- const rand = op({ rand_ });
-
- var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
-
- function unwrapExports (x) {
- return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
- }
-
- function createCommonjsModule(fn, module) {
- return module = { exports: {} }, fn(module, module.exports), module.exports;
- }
-
- function getCjsExportFromNamespace (n) {
- return n && n['default'] || n;
- }
-
- function commonjsRequire () {
- throw new Error('Dynamic requires are not currently supported by @rollup/plugin-commonjs');
- }
-
- var alea = createCommonjsModule(function (module) {
- // A port of an algorithm by Johannes Baagøe , 2010
- // http://baagoe.com/en/RandomMusings/javascript/
- // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
- // Original work is under MIT license -
-
- // Copyright (C) 2010 by Johannes Baagøe
- //
- // Permission is hereby granted, free of charge, to any person obtaining a copy
- // of this software and associated documentation files (the "Software"), to deal
- // in the Software without restriction, including without limitation the rights
- // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- // copies of the Software, and to permit persons to whom the Software is
- // furnished to do so, subject to the following conditions:
- //
- // The above copyright notice and this permission notice shall be included in
- // all copies or substantial portions of the Software.
- //
- // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- // THE SOFTWARE.
-
-
-
- (function(global, module, define) {
-
- function Alea(seed) {
- var me = this, mash = Mash();
-
- me.next = function() {
- var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
- me.s0 = me.s1;
- me.s1 = me.s2;
- return me.s2 = t - (me.c = t | 0);
- };
-
- // Apply the seeding algorithm from Baagoe.
- me.c = 1;
- me.s0 = mash(' ');
- me.s1 = mash(' ');
- me.s2 = mash(' ');
- me.s0 -= mash(seed);
- if (me.s0 < 0) { me.s0 += 1; }
- me.s1 -= mash(seed);
- if (me.s1 < 0) { me.s1 += 1; }
- me.s2 -= mash(seed);
- if (me.s2 < 0) { me.s2 += 1; }
- mash = null;
- }
-
- function copy(f, t) {
- t.c = f.c;
- t.s0 = f.s0;
- t.s1 = f.s1;
- t.s2 = f.s2;
- return t;
- }
-
- function impl(seed, opts) {
- var xg = new Alea(seed),
- state = opts && opts.state,
- prng = xg.next;
- prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
- prng.double = function() {
- return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
- };
- prng.quick = prng;
- if (state) {
- if (typeof(state) == 'object') copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- function Mash() {
- var n = 0xefc8249d;
-
- var mash = function(data) {
- data = data.toString();
- for (var i = 0; i < data.length; i++) {
- n += data.charCodeAt(i);
- var h = 0.02519603282416938 * n;
- n = h >>> 0;
- h -= n;
- h *= n;
- n = h >>> 0;
- h -= n;
- n += h * 0x100000000; // 2^32
- }
- return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
- };
-
- return mash;
- }
-
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.alea = impl;
- }
-
- })(
- commonjsGlobal,
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var xor128 = createCommonjsModule(function (module) {
- // A Javascript implementaion of the "xor128" prng algorithm by
- // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
-
- (function(global, module, define) {
-
- function XorGen(seed) {
- var me = this, strseed = '';
-
- me.x = 0;
- me.y = 0;
- me.z = 0;
- me.w = 0;
-
- // Set up generator function.
- me.next = function() {
- var t = me.x ^ (me.x << 11);
- me.x = me.y;
- me.y = me.z;
- me.z = me.w;
- return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
- };
-
- if (seed === (seed | 0)) {
- // Integer seed.
- me.x = seed;
- } else {
- // String seed.
- strseed += seed;
- }
-
- // Mix in string seed, then discard an initial batch of 64 values.
- for (var k = 0; k < strseed.length + 64; k++) {
- me.x ^= strseed.charCodeAt(k) | 0;
- me.next();
- }
- }
-
- function copy(f, t) {
- t.x = f.x;
- t.y = f.y;
- t.z = f.z;
- t.w = f.w;
- return t;
- }
-
- function impl(seed, opts) {
- var xg = new XorGen(seed),
- state = opts && opts.state,
- prng = function() { return (xg.next() >>> 0) / 0x100000000; };
- prng.double = function() {
- do {
- var top = xg.next() >>> 11,
- bot = (xg.next() >>> 0) / 0x100000000,
- result = (top + bot) / (1 << 21);
- } while (result === 0);
- return result;
- };
- prng.int32 = xg.next;
- prng.quick = prng;
- if (state) {
- if (typeof(state) == 'object') copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.xor128 = impl;
- }
-
- })(
- commonjsGlobal,
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var xorwow = createCommonjsModule(function (module) {
- // A Javascript implementaion of the "xorwow" prng algorithm by
- // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
-
- (function(global, module, define) {
-
- function XorGen(seed) {
- var me = this, strseed = '';
-
- // Set up generator function.
- me.next = function() {
- var t = (me.x ^ (me.x >>> 2));
- me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
- return (me.d = (me.d + 362437 | 0)) +
- (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
- };
-
- me.x = 0;
- me.y = 0;
- me.z = 0;
- me.w = 0;
- me.v = 0;
-
- if (seed === (seed | 0)) {
- // Integer seed.
- me.x = seed;
- } else {
- // String seed.
- strseed += seed;
- }
-
- // Mix in string seed, then discard an initial batch of 64 values.
- for (var k = 0; k < strseed.length + 64; k++) {
- me.x ^= strseed.charCodeAt(k) | 0;
- if (k == strseed.length) {
- me.d = me.x << 10 ^ me.x >>> 4;
- }
- me.next();
- }
- }
-
- function copy(f, t) {
- t.x = f.x;
- t.y = f.y;
- t.z = f.z;
- t.w = f.w;
- t.v = f.v;
- t.d = f.d;
- return t;
- }
-
- function impl(seed, opts) {
- var xg = new XorGen(seed),
- state = opts && opts.state,
- prng = function() { return (xg.next() >>> 0) / 0x100000000; };
- prng.double = function() {
- do {
- var top = xg.next() >>> 11,
- bot = (xg.next() >>> 0) / 0x100000000,
- result = (top + bot) / (1 << 21);
- } while (result === 0);
- return result;
- };
- prng.int32 = xg.next;
- prng.quick = prng;
- if (state) {
- if (typeof(state) == 'object') copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.xorwow = impl;
- }
-
- })(
- commonjsGlobal,
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var xorshift7 = createCommonjsModule(function (module) {
- // A Javascript implementaion of the "xorshift7" algorithm by
- // François Panneton and Pierre L'ecuyer:
- // "On the Xorgshift Random Number Generators"
- // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf
-
- (function(global, module, define) {
-
- function XorGen(seed) {
- var me = this;
-
- // Set up generator function.
- me.next = function() {
- // Update xor generator.
- var X = me.x, i = me.i, t, v, w;
- t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
- t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
- t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
- t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
- t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
- X[i] = v;
- me.i = (i + 1) & 7;
- return v;
- };
-
- function init(me, seed) {
- var j, w, X = [];
-
- if (seed === (seed | 0)) {
- // Seed state array using a 32-bit integer.
- w = X[0] = seed;
- } else {
- // Seed state using a string.
- seed = '' + seed;
- for (j = 0; j < seed.length; ++j) {
- X[j & 7] = (X[j & 7] << 15) ^
- (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
- }
- }
- // Enforce an array length of 8, not all zeroes.
- while (X.length < 8) X.push(0);
- for (j = 0; j < 8 && X[j] === 0; ++j);
- if (j == 8) w = X[7] = -1; else w = X[j];
-
- me.x = X;
- me.i = 0;
-
- // Discard an initial 256 values.
- for (j = 256; j > 0; --j) {
- me.next();
- }
- }
-
- init(me, seed);
- }
-
- function copy(f, t) {
- t.x = f.x.slice();
- t.i = f.i;
- return t;
- }
-
- function impl(seed, opts) {
- if (seed == null) seed = +(new Date);
- var xg = new XorGen(seed),
- state = opts && opts.state,
- prng = function() { return (xg.next() >>> 0) / 0x100000000; };
- prng.double = function() {
- do {
- var top = xg.next() >>> 11,
- bot = (xg.next() >>> 0) / 0x100000000,
- result = (top + bot) / (1 << 21);
- } while (result === 0);
- return result;
- };
- prng.int32 = xg.next;
- prng.quick = prng;
- if (state) {
- if (state.x) copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.xorshift7 = impl;
- }
-
- })(
- commonjsGlobal,
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var xor4096 = createCommonjsModule(function (module) {
- // A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm.
- //
- // This fast non-cryptographic random number generator is designed for
- // use in Monte-Carlo algorithms. It combines a long-period xorshift
- // generator with a Weyl generator, and it passes all common batteries
- // of stasticial tests for randomness while consuming only a few nanoseconds
- // for each prng generated. For background on the generator, see Brent's
- // paper: "Some long-period random number generators using shifts and xors."
- // http://arxiv.org/pdf/1004.3115v1.pdf
- //
- // Usage:
- //
- // var xor4096 = require('xor4096');
- // random = xor4096(1); // Seed with int32 or string.
- // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
- // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.
- //
- // For nonzero numeric keys, this impelementation provides a sequence
- // identical to that by Brent's xorgens 3 implementaion in C. This
- // implementation also provides for initalizing the generator with
- // string seeds, or for saving and restoring the state of the generator.
- //
- // On Chrome, this prng benchmarks about 2.1 times slower than
- // Javascript's built-in Math.random().
-
- (function(global, module, define) {
-
- function XorGen(seed) {
- var me = this;
-
- // Set up generator function.
- me.next = function() {
- var w = me.w,
- X = me.X, i = me.i, t, v;
- // Update Weyl generator.
- me.w = w = (w + 0x61c88647) | 0;
- // Update xor generator.
- v = X[(i + 34) & 127];
- t = X[i = ((i + 1) & 127)];
- v ^= v << 13;
- t ^= t << 17;
- v ^= v >>> 15;
- t ^= t >>> 12;
- // Update Xor generator array state.
- v = X[i] = v ^ t;
- me.i = i;
- // Result is the combination.
- return (v + (w ^ (w >>> 16))) | 0;
- };
-
- function init(me, seed) {
- var t, v, i, j, w, X = [], limit = 128;
- if (seed === (seed | 0)) {
- // Numeric seeds initialize v, which is used to generates X.
- v = seed;
- seed = null;
- } else {
- // String seeds are mixed into v and X one character at a time.
- seed = seed + '\0';
- v = 0;
- limit = Math.max(limit, seed.length);
- }
- // Initialize circular array and weyl value.
- for (i = 0, j = -32; j < limit; ++j) {
- // Put the unicode characters into the array, and shuffle them.
- if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
- // After 32 shuffles, take v as the starting w value.
- if (j === 0) w = v;
- v ^= v << 10;
- v ^= v >>> 15;
- v ^= v << 4;
- v ^= v >>> 13;
- if (j >= 0) {
- w = (w + 0x61c88647) | 0; // Weyl.
- t = (X[j & 127] ^= (v + w)); // Combine xor and weyl to init array.
- i = (0 == t) ? i + 1 : 0; // Count zeroes.
- }
- }
- // We have detected all zeroes; make the key nonzero.
- if (i >= 128) {
- X[(seed && seed.length || 0) & 127] = -1;
- }
- // Run the generator 512 times to further mix the state before using it.
- // Factoring this as a function slows the main generator, so it is just
- // unrolled here. The weyl generator is not advanced while warming up.
- i = 127;
- for (j = 4 * 128; j > 0; --j) {
- v = X[(i + 34) & 127];
- t = X[i = ((i + 1) & 127)];
- v ^= v << 13;
- t ^= t << 17;
- v ^= v >>> 15;
- t ^= t >>> 12;
- X[i] = v ^ t;
- }
- // Storing state as object members is faster than using closure variables.
- me.w = w;
- me.X = X;
- me.i = i;
- }
-
- init(me, seed);
- }
-
- function copy(f, t) {
- t.i = f.i;
- t.w = f.w;
- t.X = f.X.slice();
- return t;
- };
-
- function impl(seed, opts) {
- if (seed == null) seed = +(new Date);
- var xg = new XorGen(seed),
- state = opts && opts.state,
- prng = function() { return (xg.next() >>> 0) / 0x100000000; };
- prng.double = function() {
- do {
- var top = xg.next() >>> 11,
- bot = (xg.next() >>> 0) / 0x100000000,
- result = (top + bot) / (1 << 21);
- } while (result === 0);
- return result;
- };
- prng.int32 = xg.next;
- prng.quick = prng;
- if (state) {
- if (state.X) copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.xor4096 = impl;
- }
-
- })(
- commonjsGlobal, // window object or global
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var tychei = createCommonjsModule(function (module) {
- // A Javascript implementaion of the "Tyche-i" prng algorithm by
- // Samuel Neves and Filipe Araujo.
- // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
-
- (function(global, module, define) {
-
- function XorGen(seed) {
- var me = this, strseed = '';
-
- // Set up generator function.
- me.next = function() {
- var b = me.b, c = me.c, d = me.d, a = me.a;
- b = (b << 25) ^ (b >>> 7) ^ c;
- c = (c - d) | 0;
- d = (d << 24) ^ (d >>> 8) ^ a;
- a = (a - b) | 0;
- me.b = b = (b << 20) ^ (b >>> 12) ^ c;
- me.c = c = (c - d) | 0;
- me.d = (d << 16) ^ (c >>> 16) ^ a;
- return me.a = (a - b) | 0;
- };
-
- /* The following is non-inverted tyche, which has better internal
- * bit diffusion, but which is about 25% slower than tyche-i in JS.
- me.next = function() {
- var a = me.a, b = me.b, c = me.c, d = me.d;
- a = (me.a + me.b | 0) >>> 0;
- d = me.d ^ a; d = d << 16 ^ d >>> 16;
- c = me.c + d | 0;
- b = me.b ^ c; b = b << 12 ^ d >>> 20;
- me.a = a = a + b | 0;
- d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
- me.c = c = c + d | 0;
- b = b ^ c;
- return me.b = (b << 7 ^ b >>> 25);
- }
- */
-
- me.a = 0;
- me.b = 0;
- me.c = 2654435769 | 0;
- me.d = 1367130551;
-
- if (seed === Math.floor(seed)) {
- // Integer seed.
- me.a = (seed / 0x100000000) | 0;
- me.b = seed | 0;
- } else {
- // String seed.
- strseed += seed;
- }
-
- // Mix in string seed, then discard an initial batch of 64 values.
- for (var k = 0; k < strseed.length + 20; k++) {
- me.b ^= strseed.charCodeAt(k) | 0;
- me.next();
- }
- }
-
- function copy(f, t) {
- t.a = f.a;
- t.b = f.b;
- t.c = f.c;
- t.d = f.d;
- return t;
- };
-
- function impl(seed, opts) {
- var xg = new XorGen(seed),
- state = opts && opts.state,
- prng = function() { return (xg.next() >>> 0) / 0x100000000; };
- prng.double = function() {
- do {
- var top = xg.next() >>> 11,
- bot = (xg.next() >>> 0) / 0x100000000,
- result = (top + bot) / (1 << 21);
- } while (result === 0);
- return result;
- };
- prng.int32 = xg.next;
- prng.quick = prng;
- if (state) {
- if (typeof(state) == 'object') copy(state, xg);
- prng.state = function() { return copy(xg, {}); };
- }
- return prng;
- }
-
- if (module && module.exports) {
- module.exports = impl;
- } else if (define && define.amd) {
- define(function() { return impl; });
- } else {
- this.tychei = impl;
- }
-
- })(
- commonjsGlobal,
- ('object') == 'object' && module, // present in node.js
- (typeof undefined) == 'function' && undefined // present with an AMD loader
- );
- });
-
- var seedrandom = createCommonjsModule(function (module) {
- /*
- Copyright 2014 David Bau.
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
- */
-
- (function (pool, math) {
- //
- // The following constants are related to IEEE 754 limits.
- //
- var global = this,
- width = 256, // each RC4 output is 0 <= x < 256
- chunks = 6, // at least six RC4 outputs for each double
- digits = 52, // there are 52 significant digits in a double
- rngname = 'random', // rngname: name for Math.random and Math.seedrandom
- startdenom = math.pow(width, chunks),
- significance = math.pow(2, digits),
- overflow = significance * 2,
- mask = width - 1,
- nodecrypto; // node.js crypto module, initialized at the bottom.
-
- //
- // seedrandom()
- // This is the seedrandom function described above.
- //
- function seedrandom(seed, options, callback) {
- var key = [];
- options = (options == true) ? { entropy: true } : (options || {});
-
- // Flatten the seed string or build one from local entropy if needed.
- var shortseed = mixkey(flatten(
- options.entropy ? [seed, tostring(pool)] :
- (seed == null) ? autoseed() : seed, 3), key);
-
- // Use the seed to initialize an ARC4 generator.
- var arc4 = new ARC4(key);
-
- // This function returns a random double in [0, 1) that contains
- // randomness in every bit of the mantissa of the IEEE 754 value.
- var prng = function() {
- var n = arc4.g(chunks), // Start with a numerator n < 2 ^ 48
- d = startdenom, // and denominator d = 2 ^ 48.
- x = 0; // and no 'extra last byte'.
- while (n < significance) { // Fill up all significant digits by
- n = (n + x) * width; // shifting numerator and
- d *= width; // denominator and generating a
- x = arc4.g(1); // new least-significant-byte.
- }
- while (n >= overflow) { // To avoid rounding up, before adding
- n /= 2; // last byte, shift everything
- d /= 2; // right using integer math until
- x >>>= 1; // we have exactly the desired bits.
- }
- return (n + x) / d; // Form the number within [0, 1).
- };
-
- prng.int32 = function() { return arc4.g(4) | 0; };
- prng.quick = function() { return arc4.g(4) / 0x100000000; };
- prng.double = prng;
-
- // Mix the randomness into accumulated entropy.
- mixkey(tostring(arc4.S), pool);
-
- // Calling convention: what to return as a function of prng, seed, is_math.
- return (options.pass || callback ||
- function(prng, seed, is_math_call, state) {
- if (state) {
- // Load the arc4 state from the given state if it has an S array.
- if (state.S) { copy(state, arc4); }
- // Only provide the .state method if requested via options.state.
- prng.state = function() { return copy(arc4, {}); };
- }
-
- // If called as a method of Math (Math.seedrandom()), mutate
- // Math.random because that is how seedrandom.js has worked since v1.0.
- if (is_math_call) { math[rngname] = prng; return seed; }
-
- // Otherwise, it is a newer calling convention, so return the
- // prng directly.
- else return prng;
- })(
- prng,
- shortseed,
- 'global' in options ? options.global : (this == math),
- options.state);
- }
- math['seed' + rngname] = seedrandom;
-
- //
- // ARC4
- //
- // An ARC4 implementation. The constructor takes a key in the form of
- // an array of at most (width) integers that should be 0 <= x < (width).
- //
- // The g(count) method returns a pseudorandom integer that concatenates
- // the next (count) outputs from ARC4. Its return value is a number x
- // that is in the range 0 <= x < (width ^ count).
- //
- function ARC4(key) {
- var t, keylen = key.length,
- me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
-
- // The empty key [] is treated as [0].
- if (!keylen) { key = [keylen++]; }
-
- // Set up S using the standard key scheduling algorithm.
- while (i < width) {
- s[i] = i++;
- }
- for (i = 0; i < width; i++) {
- s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
- s[j] = t;
- }
-
- // The "g" method returns the next (count) outputs as one number.
- (me.g = function(count) {
- // Using instance members instead of closure state nearly doubles speed.
- var t, r = 0,
- i = me.i, j = me.j, s = me.S;
- while (count--) {
- t = s[i = mask & (i + 1)];
- r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
- }
- me.i = i; me.j = j;
- return r;
- // For robust unpredictability, the function call below automatically
- // discards an initial batch of values. This is called RC4-drop[256].
- // See http://google.com/search?q=rsa+fluhrer+response&btnI
- })(width);
- }
-
- //
- // copy()
- // Copies internal state of ARC4 to or from a plain object.
- //
- function copy(f, t) {
- t.i = f.i;
- t.j = f.j;
- t.S = f.S.slice();
- return t;
- };
-
- //
- // flatten()
- // Converts an object tree to nested arrays of strings.
- //
- function flatten(obj, depth) {
- var result = [], typ = (typeof obj), prop;
- if (depth && typ == 'object') {
- for (prop in obj) {
- try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
- }
- }
- return (result.length ? result : typ == 'string' ? obj : obj + '\0');
- }
-
- //
- // mixkey()
- // Mixes a string seed into a key that is an array of integers, and
- // returns a shortened string seed that is equivalent to the result key.
- //
- function mixkey(seed, key) {
- var stringseed = seed + '', smear, j = 0;
- while (j < stringseed.length) {
- key[mask & j] =
- mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
- }
- return tostring(key);
- }
-
- //
- // autoseed()
- // Returns an object for autoseeding, using window.crypto and Node crypto
- // module if available.
- //
- function autoseed() {
- try {
- var out;
- if (nodecrypto && (out = nodecrypto.randomBytes)) {
- // The use of 'out' to remember randomBytes makes tight minified code.
- out = out(width);
- } else {
- out = new Uint8Array(width);
- (global.crypto || global.msCrypto).getRandomValues(out);
- }
- return tostring(out);
- } catch (e) {
- var browser = global.navigator,
- plugins = browser && browser.plugins;
- return [+new Date, global, plugins, global.screen, tostring(pool)];
- }
- }
-
- //
- // tostring()
- // Converts an array of charcodes to a string
- //
- function tostring(a) {
- return String.fromCharCode.apply(0, a);
- }
-
- //
- // When seedrandom.js is loaded, we immediately mix a few bits
- // from the built-in RNG into the entropy pool. Because we do
- // not want to interfere with deterministic PRNG state later,
- // seedrandom will not call math.random on its own again after
- // initialization.
- //
- mixkey(math.random(), pool);
-
- //
- // Nodejs and AMD support: export the implementation as a module using
- // either convention.
- //
- if (('object') == 'object' && module.exports) {
- module.exports = seedrandom;
- // When in node.js, try using crypto package for autoseeding.
- try {
- nodecrypto = require('crypto');
- } catch (ex) {}
- } else if ((typeof undefined) == 'function' && undefined.amd) {
- undefined(function() { return seedrandom; });
- }
-
- // End anonymous scope, and pass initial values.
- })(
- [], // pool: entropy pool starts empty
- Math // math: package containing random, pow, and seedrandom
- );
- });
-
- // A library of seedable RNGs implemented in Javascript.
- //
- // Usage:
- //
- // var seedrandom = require('seedrandom');
- // var random = seedrandom(1); // or any seed.
- // var x = random(); // 0 <= x < 1. Every bit is random.
- // var x = random.quick(); // 0 <= x < 1. 32 bits of randomness.
-
- // alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
- // Period: ~2^116
- // Reported to pass all BigCrush tests.
-
-
- // xor128, a pure xor-shift generator by George Marsaglia.
- // Period: 2^128-1.
- // Reported to fail: MatrixRank and LinearComp.
-
-
- // xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
- // Period: 2^192-2^32
- // Reported to fail: CollisionOver, SimpPoker, and LinearComp.
-
-
- // xorshift7, by François Panneton and Pierre L'ecuyer, takes
- // a different approach: it adds robustness by allowing more shifts
- // than Marsaglia's original three. It is a 7-shift generator
- // with 256 bits, that passes BigCrush with no systmatic failures.
- // Period 2^256-1.
- // No systematic BigCrush failures reported.
-
-
- // xor4096, by Richard Brent, is a 4096-bit xor-shift with a
- // very long period that also adds a Weyl generator. It also passes
- // BigCrush with no systematic failures. Its long period may
- // be useful if you have many generators and need to avoid
- // collisions.
- // Period: 2^4128-2^32.
- // No systematic BigCrush failures reported.
-
-
- // Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
- // number generator derived from ChaCha, a modern stream cipher.
- // https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
- // Period: ~2^127
- // No systematic BigCrush failures reported.
-
-
- // The original ARC4-based prng included in this library.
- // Period: ~2^1600
-
-
- seedrandom.alea = alea;
- seedrandom.xor128 = xor128;
- seedrandom.xorwow = xorwow;
- seedrandom.xorshift7 = xorshift7;
- seedrandom.xor4096 = xor4096;
- seedrandom.tychei = tychei;
-
- var seedrandom$1 = seedrandom;
- var seedrandom_1 = seedrandom$1.alea;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // https://en.wikipedia.org/wiki/Marsaglia_polar_method
- class MPRandGauss {
- constructor(mean, stdDeviation, dtype, truncated, seed) {
- this.mean = mean;
- this.stdDev = stdDeviation;
- this.dtype = dtype;
- this.nextVal = NaN;
- this.truncated = truncated;
- if (this.truncated) {
- this.upper = this.mean + this.stdDev * 2;
- this.lower = this.mean - this.stdDev * 2;
- }
- const seedValue = seed ? seed : Math.random();
- this.random = seedrandom_1(seedValue.toString());
- }
- /** Returns next sample from a Gaussian distribution. */
- nextValue() {
- if (!isNaN(this.nextVal)) {
- const value = this.nextVal;
- this.nextVal = NaN;
- return value;
- }
- let resultX, resultY;
- let isValid = false;
- while (!isValid) {
- let v1, v2, s;
- do {
- v1 = 2 * this.random() - 1;
- v2 = 2 * this.random() - 1;
- s = v1 * v1 + v2 * v2;
- } while (s >= 1 || s === 0);
- const mul = Math.sqrt(-2.0 * Math.log(s) / s);
- resultX = this.mean + this.stdDev * v1 * mul;
- resultY = this.mean + this.stdDev * v2 * mul;
- if (!this.truncated || this.isValidTruncated(resultX)) {
- isValid = true;
- }
- }
- if (!this.truncated || this.isValidTruncated(resultY)) {
- this.nextVal = this.convertValue(resultY);
- }
- return this.convertValue(resultX);
- }
- /** Handles proper rounding for non-floating-point numbers. */
- convertValue(value) {
- if (this.dtype == null || this.dtype === 'float32') {
- return value;
- }
- return Math.round(value);
- }
- /** Returns true if less than 2-standard-deviations from the mean. */
- isValidTruncated(value) {
- return value <= this.upper && value >= this.lower;
- }
- }
- // Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
- // Gamma Variables."
- class RandGamma {
- constructor(alpha, beta, dtype, seed) {
- this.alpha = alpha;
- this.beta = 1 / beta; // convert rate to scale parameter
- this.dtype = dtype;
- const seedValue = seed ? seed : Math.random();
- this.randu = seedrandom_1(seedValue.toString());
- this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
- if (alpha < 1) {
- this.d = alpha + (2 / 3);
- }
- else {
- this.d = alpha - (1 / 3);
- }
- this.c = 1 / Math.sqrt(9 * this.d);
- }
- /** Returns next sample from a gamma distribution. */
- nextValue() {
- let x2, v0, v1, x, u, v;
- while (true) {
- do {
- x = this.randn.nextValue();
- v = 1 + (this.c * x);
- } while (v <= 0);
- v *= v * v;
- x2 = x * x;
- v0 = 1 - (0.331 * x2 * x2);
- v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));
- u = this.randu();
- if (u < v0 || Math.log(u) < v1) {
- break;
- }
- }
- v = (1 / this.beta) * this.d * v;
- if (this.alpha < 1) {
- v *= Math.pow(this.randu(), 1 / this.alpha);
- }
- return this.convertValue(v);
- }
- /** Handles proper rounding for non-floating-point numbers. */
- convertValue(value) {
- if (this.dtype === 'float32') {
- return value;
- }
- return Math.round(value);
- }
- }
- class UniformRandom {
- constructor(min = 0, max = 1, dtype, seed) {
- /** Handles proper rounding for non floating point numbers. */
- this.canReturnFloat = () => (this.dtype == null || this.dtype === 'float32');
- this.min = min;
- this.range = max - min;
- this.dtype = dtype;
- if (seed == null) {
- seed = Math.random();
- }
- if (typeof seed === 'number') {
- seed = seed.toString();
- }
- if (!this.canReturnFloat() && this.range <= 1) {
- throw new Error(`The difference between ${min} - ${max} <= 1 and dtype is not float`);
- }
- this.random = seedrandom_1(seed);
- }
- convertValue(value) {
- if (this.canReturnFloat()) {
- return value;
- }
- return Math.round(value);
- }
- nextValue() {
- return this.convertValue(this.min + this.range * this.random());
- }
- }
- function jarqueBeraNormalityTest(values) {
- // https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test
- const n = values.length;
- const s = skewness(values);
- const k = kurtosis(values);
- const jb = n / 6 * (Math.pow(s, 2) + 0.25 * Math.pow(k - 3, 2));
- // JB test requires 2-degress of freedom from Chi-Square @ 0.95:
- // http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm
- const CHI_SQUARE_2DEG = 5.991;
- if (jb > CHI_SQUARE_2DEG) {
- throw new Error(`Invalid p-value for JB: ${jb}`);
- }
- }
- function expectArrayInMeanStdRange(actual, expectedMean, expectedStdDev, epsilon) {
- if (epsilon == null) {
- epsilon = testEpsilon();
- }
- const actualMean = mean$1(actual);
- expectNumbersClose(actualMean, expectedMean, epsilon);
- expectNumbersClose(standardDeviation(actual, actualMean), expectedStdDev, epsilon);
- }
- function mean$1(values) {
- let sum = 0;
- for (let i = 0; i < values.length; i++) {
- sum += values[i];
- }
- return sum / values.length;
- }
- function standardDeviation(values, mean) {
- let squareDiffSum = 0;
- for (let i = 0; i < values.length; i++) {
- const diff = values[i] - mean;
- squareDiffSum += diff * diff;
- }
- return Math.sqrt(squareDiffSum / values.length);
- }
- function kurtosis(values) {
- // https://en.wikipedia.org/wiki/Kurtosis
- const valuesMean = mean$1(values);
- const n = values.length;
- let sum2 = 0;
- let sum4 = 0;
- for (let i = 0; i < n; i++) {
- const v = values[i] - valuesMean;
- sum2 += Math.pow(v, 2);
- sum4 += Math.pow(v, 4);
- }
- return (1 / n) * sum4 / Math.pow((1 / n) * sum2, 2);
- }
- function skewness(values) {
- // https://en.wikipedia.org/wiki/Skewness
- const valuesMean = mean$1(values);
- const n = values.length;
- let sum2 = 0;
- let sum3 = 0;
- for (let i = 0; i < n; i++) {
- const v = values[i] - valuesMean;
- sum2 += Math.pow(v, 2);
- sum3 += Math.pow(v, 3);
- }
- return (1 / n) * sum3 / Math.pow((1 / (n - 1)) * sum2, 3 / 2);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values sampled from a gamma distribution.
- *
- * ```js
- * tf.randomGamma([2, 2], 1).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param alpha The shape parameter of the gamma distribution.
- * @param beta The inverse scale parameter of the gamma distribution. Defaults
- * to 1.
- * @param dtype The data type of the output. Defaults to float32.
- * @param seed The seed for the random number generator.
- *
- * @doc {heading: 'Tensors', subheading: 'Random'}
- */
- function randomGamma_(shape, alpha, beta = 1, dtype = 'float32', seed) {
- if (beta == null) {
- beta = 1;
- }
- if (dtype == null) {
- dtype = 'float32';
- }
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new Error(`Unsupported data type ${dtype}`);
- }
- const rgamma = new RandGamma(alpha, beta, dtype, seed);
- const res = buffer(shape, dtype);
- for (let i = 0; i < res.values.length; i++) {
- res.values[i] = rgamma.nextValue();
- }
- return res.toTensor();
- }
- const randomGamma = op({ randomGamma_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values sampled from a normal distribution.
- *
- * ```js
- * tf.randomNormal([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param mean The mean of the normal distribution.
- * @param stdDev The standard deviation of the normal distribution.
- * @param dtype The data type of the output.
- * @param seed The seed for the random number generator.
- *
- * @doc {heading: 'Tensors', subheading: 'Random'}
- */
- function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
- if (dtype != null && dtype === 'bool') {
- throw new Error(`Unsupported data type ${dtype}`);
- }
- const randGauss = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
- const res = buffer(shape, dtype);
- for (let i = 0; i < res.values.length; i++) {
- res.values[i] = randGauss.nextValue();
- }
- return res.toTensor();
- }
- const randomNormal = op({ randomNormal_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values sampled from a uniform distribution.
- *
- * The generated values follow a uniform distribution in the range [minval,
- * maxval). The lower bound minval is included in the range, while the upper
- * bound maxval is excluded.
- *
- * ```js
- * tf.randomUniform([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param minval The lower bound on the range of random values to generate.
- * Defaults to 0.
- * @param maxval The upper bound on the range of random values to generate.
- * Defaults to 1.
- * @param dtype The data type of the output tensor. Defaults to 'float32'.
- *
- * @doc {heading: 'Tensors', subheading: 'Random'}
- */
- function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) {
- const res = buffer(shape, dtype);
- const random = new UniformRandom(minval, maxval, null, seed);
- for (let i = 0; i < res.values.length; i++) {
- res.values[i] = random.nextValue();
- }
- return res.toTensor();
- }
- const randomUniform = op({ randomUniform_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor1d` as it makes the code more readable.
- *
- * ```js
- * tf.tensor1d([1, 2, 3]).print();
- * ```
- *
- * @param values The values of the tensor. Can be array of numbers,
- * or a `TypedArray`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function tensor1d(values, dtype) {
- assertNonNull(values);
- const inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 1) {
- throw new Error('tensor1d() requires values to be a flat/TypedArray');
- }
- const shape = null;
- return makeTensor(values, shape, inferredShape, dtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
- *
- * The tensor is a is half-open interval meaning it includes start, but
- * excludes stop. Decrementing ranges and negative step values are also
- * supported.sv
- *
- *
- * ```js
- * tf.range(0, 9, 2).print();
- * ```
- *
- * @param start An integer start value
- * @param stop An integer stop value
- * @param step An integer increment (will default to 1 or -1)
- * @param dtype The data type of the output tensor. Defaults to 'float32'.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function range(start, stop, step = 1, dtype = 'float32') {
- if (step === 0) {
- throw new Error('Cannot have a step of zero');
- }
- const forward = () => {
- const sameStartStop = start === stop;
- const increasingRangeNegativeStep = start < stop && step < 0;
- const decreasingRangePositiveStep = stop < start && step > 1;
- if (sameStartStop || increasingRangeNegativeStep ||
- decreasingRangePositiveStep) {
- return zeros([0], dtype);
- }
- const numElements = Math.abs(Math.ceil((stop - start) / step));
- const values = makeZerosTypedArray(numElements, dtype);
- if (stop < start && step === 1) {
- // Auto adjust the step's sign if it hasn't been set
- // (or was set to 1)
- step = -1;
- }
- values[0] = start;
- for (let i = 1; i < values.length; i++) {
- values[i] = values[i - 1] + step;
- }
- return tensor1d(values, dtype);
- };
- const attrs = { start, stop, step, dtype };
- return ENGINE.runKernelFunc(forward, {} /* inputs */, null /* grad */, Range, attrs);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes reciprocal of x element-wise: `1 / x`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, 2]);
- *
- * x.reciprocal().print(); // or tf.reciprocal(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function reciprocal_(x) {
- const $x = convertToTensor(x, 'x', 'reciprocal');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.reciprocal($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Reciprocal);
- }
- const reciprocal = op({ reciprocal_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes rectified linear element-wise: `max(x, 0)`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.relu().print(); // or tf.relu(x)
- * ```
- * @param x The input tensor. If the dtype is `bool`, the output dtype will be
- * `int32'.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function relu_(x) {
- const $x = convertToTensor(x, 'x', 'relu');
- const forward = (backend, save) => {
- save([$x]);
- if ($x.dtype === 'bool') {
- return cast($x, 'int32');
- }
- return backend.relu($x);
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Relu);
- }
- const relu = op({ relu_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 8]);
- *
- * x.relu6().print(); // or tf.relu6(x)
- * ```
- * @param x The input tensor. If the dtype is `bool`, the output dtype will be
- * `int32'.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function relu6_(x) {
- const $x = convertToTensor(x, 'x', 'relu6');
- const forward = (backend, save) => {
- save([$x]);
- if ($x.dtype === 'bool') {
- return cast($x, 'int32');
- }
- return backend.relu6($x);
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Relu6);
- }
- const relu6 = op({ relu6_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reverses a `tf.Tensor` along a specified axis.
- *
- * Also available are stricter rank-specific methods that assert that `x` is
- * of the given rank:
- * - `tf.reverse1d`
- * - `tf.reverse2d`
- * - `tf.reverse3d`
- * - `tf.reverse4d`
- *
- * Except `tf.reverse1d` (which does not have axis param), all methods have
- * same signature as this method.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- *
- * x.reverse().print();
- * ```
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * const axis = 1;
- * x.reverse(axis).print();
- * ```
- * @param x The input tensor to be reversed.
- * @param axis The set of dimensions to reverse. Must be in the
- * range [-rank(x), rank(x)). Defaults to all axes.
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function reverse_(x, axis) {
- const $x = convertToTensor(x, 'x', 'reverse');
- const forward = (backend) => {
- const axes = parseAxisParam(axis, $x.shape);
- if ($x.rank === 0) {
- return clone($x);
- }
- const res = backend.reverse($x, axes);
- return reshape(res, $x.shape);
- };
- const inputs = { x: $x };
- const attrs = { dims: axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Reverse, attrs);
- }
- const reverse = op({ reverse_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reverses a `tf.Tensor1D`.
- *
- * @param x The input tensor.
- */
- function reverse1d_(x) {
- const $x = convertToTensor(x, 'x', 'reverse');
- assert($x.rank === 1, () => `Error in reverse1D: x must be rank 1 but got rank ${$x.rank}.`);
- return reverse($x, 0);
- }
- const reverse1d = op({ reverse1d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reverses a `tf.Tensor2D` along a specified axis.
- *
- * @param x The input tensor.
- * @param axis The set of dimensions to reverse. Must be in the
- * range [-rank(x), rank(x)). Defaults to all axes.
- */
- function reverse2d_(x, axis) {
- const $x = convertToTensor(x, 'x', 'reverse');
- assert($x.rank === 2, () => `Error in reverse2D: x must be rank 2 but got rank ${$x.rank}.`);
- return reverse($x, axis);
- }
- const reverse2d = op({ reverse2d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reverses a `tf.Tensor3D` along a specified axis.
- *
- * @param x The input tensor.
- * @param axis The set of dimensions to reverse. Must be in the
- * range [-rank(x), rank(x)). Defaults to all axes.
- */
- function reverse3d_(x, axis) {
- const $x = convertToTensor(x, 'x', 'reverse');
- assert($x.rank === 3, () => `Error in reverse3D: x must be rank 3 but got rank ${$x.rank}.`);
- return reverse($x, axis);
- }
- const reverse3d = op({ reverse3d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reverses a `tf.Tensor4D` along a specified axis.
- *
- * @param x The input tensor.
- * @param axis The set of dimensions to reverse. Must be in the
- * range [-rank(x), rank(x)). Defaults to all axes.
- */
- function reverse4d_(x, axis) {
- const $x = convertToTensor(x, 'x', 'reverse');
- assert($x.rank === 4, () => `Error in reverse4D: x must be rank 4 but got rank ${$x.rank}.`);
- return reverse($x, axis);
- }
- const reverse4d = op({ reverse4d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes round of input `tf.Tensor` element-wise: `round(x)`.
- * It implements banker's rounding.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.round().print(); // or tf.round(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function round_(x) {
- const $x = convertToTensor(x, 'x', 'round');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend) => backend.round($x), inputs, null /* grad */, Round);
- }
- const round = op({ round_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes reciprocal of square root of the input `tf.Tensor` element-wise:
- * `y = 1 / sqrt(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 4, -1]);
- *
- * x.rsqrt().print(); // or tf.rsqrt(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function rsqrt_(x) {
- const $x = convertToTensor(x, 'x', 'rsqrt');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.rsqrt($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Rsqrt);
- }
- const rsqrt = op({ rsqrt_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes scaled exponential linear element-wise.
- *
- * `x < 0 ? scale * alpha * (exp(x) - 1) : x`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.selu().print(); // or tf.selu(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function selu_(x) {
- const $x = convertToTensor(x, 'x', 'selu');
- const forward = (backend, save) => {
- const res = backend.selu($x);
- save([$x]);
- return res;
- };
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Selu);
- }
- const selu = op({ selu_ });
-
- /**
- * 2-D convolution with separable filters.
- *
- * Performs a depthwise convolution that acts separately on channels followed
- * by a pointwise convolution that mixes channels. Note that this is
- * separability between dimensions [1, 2] and 3, not spatial separability
- * between dimensions 1 and 2.
- *
- * See
- * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
- * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
- * for more details.
- *
- * @param x The input tensor, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
- * assumed.
- * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
- * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
- * the filter used in the first step.
- * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
- * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
- * the filter used in the second step.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`. If strides is a single number, then `strideHeight ==
- * strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels]. Only "NHWC" is currently supported.
- *
- * @doc {heading: 'Operations', subheading: 'Convolution'}
- */
- function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation = [1, 1], dataFormat = 'NHWC') {
- const $x = convertToTensor(x, 'x', 'separableConv2d');
- const $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
- const $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- if (dataFormat === 'NCHW') {
- throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' +
- 'NHWC is supported');
- }
- assert(x4D.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got ` +
- `rank ${x4D.rank}.`);
- assert($depthwiseFilter.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but ` +
- `got rank ${$depthwiseFilter.rank}.`);
- assert($pointwiseFilter.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but ` +
- `got rank ${$depthwiseFilter.rank}.`);
- assert($pointwiseFilter.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter ` +
- ` must be 1, but got ${$pointwiseFilter.shape[0]}.`);
- assert($pointwiseFilter.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise ` +
- `filter must be 1, but got ${$pointwiseFilter.shape[1]}.`);
- const inChannels = $depthwiseFilter.shape[2];
- const channelMultiplier = $depthwiseFilter.shape[3];
- assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, () => `Error in separableConv2d: the third dimension of pointwise filter ` +
- `must be ${inChannels * channelMultiplier}, ` +
- `but got ${$pointwiseFilter.shape[2]}.`);
- const depthwise = depthwiseConv2d(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
- const pointwiseStride = 1;
- const res = conv2d(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const separableConv2d = op({ separableConv2d_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the difference between two lists of numbers.
- *
- * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
- * that represents all values that are in `x` but not in `y`. The returned
- * Tensor `out` is sorted in the same order that the numbers appear in `x`
- * (duplicates are preserved). This operation also returns a Tensor indices that
- * represents the position of each out element in `x`. In other words:
- *
- * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
- *
- * ```js
- * const x = [1, 2, 3, 4, 5, 6];
- * const y = [1, 3, 5];
- *
- * const [out, indices] = await tf.setdiff1dAsync(x, y);
- * out.print(); // [2, 4, 6]
- * indices.print(); // [1, 3, 5]
- * ```
- *
- * @param x 1-D Tensor. Values to keep.
- * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
- * output.
- * @returns Promise of Tensor tuple [out, indices].
- * out: Tensor with the same type as x.
- * indices: A Tensor of type int32.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
- async function setdiff1dAsync_(x, y) {
- const $x = convertToTensor(x, 'x', 'setdiff1d');
- const $y = convertToTensor(y, 'y', 'setdiff1d');
- assert($x.dtype === $y.dtype, () => `x and y should have the same dtype, but got x (${$x.dtype}) and y (${$y.dtype}).`);
- assert($x.rank === 1, () => `x should be 1D tensor, but got x (${$x.shape}).`);
- assert($y.rank === 1, () => `y should be 1D tensor, but got y (${$y.shape}).`);
- const xVals = await $x.data();
- const yVals = await $y.data();
- const ySet = new Set(yVals);
- let outputSize = 0;
- for (let i = 0; i < xVals.length; i++) {
- if (!ySet.has(xVals[i])) {
- outputSize++;
- }
- }
- const buffer = new TensorBuffer([outputSize], $x.dtype);
- const indices = new TensorBuffer([outputSize], 'int32');
- for (let i = 0, p = 0; i < xVals.length; i++) {
- if (!ySet.has(xVals[i])) {
- buffer.values[p] = xVals[i];
- indices.values[p] = i;
- p++;
- }
- }
- return [buffer.toTensor(), indices.toTensor()];
- }
- const setdiff1dAsync = setdiff1dAsync_;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns an element-wise indication of the sign of a number.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
- *
- * x.sign().print(); // or tf.sign(x)
- * ```
- * @param x The input Tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function sign_(x) {
- const $x = convertToTensor(x, 'x', 'sign');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc(backend => backend.sign($x), inputs, null /* grad */, Sign);
- }
- const sign = op({ sign_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes sin of the input Tensor element-wise: `sin(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.sin().print(); // or tf.sin(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function sin_(x) {
- const $x = convertToTensor(x, 'x', 'sin');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.sin($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Sin);
- }
- const sin = op({ sin_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.sinh().print(); // or tf.sinh(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
- function sinh_(x) {
- const $x = convertToTensor(x, 'x', 'sinh');
- const inputs = { x: $x };
- return ENGINE.runKernelFunc((backend, save) => {
- const res = backend.sinh($x);
- save([$x]);
- return res;
- }, inputs, null /* grad */, Sinh);
- }
- const sinh = op({ sinh_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a 1D slice from 1D array starting at coordinates `begin` and is
- * of length `size`. See `slice` for details.
- */
- function slice1d_(x, begin, size) {
- const $x = convertToTensor(x, 'x', 'slice1d');
- assert($x.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
- return slice($x, [begin], [size]);
- }
- const slice1d = op({ slice1d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
- * is of size `size`. See `slice` for details.
- */
- function slice2d_(x, begin, size) {
- const $x = convertToTensor(x, 'x', 'slice2d');
- assert($x.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
- return slice($x, begin, size);
- }
- const slice2d = op({ slice2d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
- * is of size `size`. See `slice` for details.
- */
- function slice3d_(x, begin, size) {
- const $x = convertToTensor(x, 'x', 'slice3d');
- assert($x.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
- return slice($x, begin, size);
- }
- const slice3d = op({ slice3d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
- * is of size `size`. See `slice` for details.
- */
- function slice4d_(x, begin, size) {
- const $x = convertToTensor(x, 'x', 'slice4d');
- assert($x.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
- return slice($x, begin, size);
- }
- const slice4d = op({ slice4d_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the softmax normalized vector given the logits.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- *
- * a.softmax().print(); // or tf.softmax(a)
- * ```
- *
- * ```js
- * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
- *
- * a.softmax().print(); // or tf.softmax(a)
- * ```
- *
- * @param logits The logits array.
- * @param dim The dimension softmax would be performed on. Defaults to `-1`
- * which indicates the last dimension.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function softmax_(logits, dim = -1) {
- const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
- if (dim === -1) {
- dim = $logits.rank - 1;
- }
- if (dim !== $logits.rank - 1) {
- throw Error('Softmax along a non-last dimension is not yet supported. ' +
- `Logits was rank ${$logits.rank} and dim was ${dim}`);
- }
- const inputs = { logits: $logits };
- const attrs = { dim };
- return ENGINE.runKernelFunc((backend, save) => {
- const y = backend.softmax($logits, dim);
- save([y]);
- return y;
- }, inputs, null /* grad */, Softmax, attrs);
- }
- const softmax = op({ softmax_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Fast Fourier transform.
- *
- * Computes the 1-dimensional discrete Fourier transform over the inner-most
- * dimension of input.
- *
- * ```js
- * const real = tf.tensor1d([1, 2, 3]);
- * const imag = tf.tensor1d([1, 2, 3]);
- * const x = tf.complex(real, imag);
- *
- * x.fft().print(); // tf.spectral.fft(x).print();
- * ```
- * @param input The complex input to compute an fft over.
- *
- * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
- */
- function fft_(input) {
- assert(input.dtype === 'complex64', () => `The dtype for tf.spectral.fft() must be complex64 ` +
- `but got ${input.dtype}.`);
- const inputs = { input };
- return ENGINE.runKernelFunc(backend => {
- // Collapse all outer dimensions to a single batch dimension.
- const innerDimensionSize = input.shape[input.shape.length - 1];
- const batch = input.size / innerDimensionSize;
- const input2D = input.as2D(batch, innerDimensionSize);
- const result = backend.fft(input2D);
- return result.reshape(input.shape);
- }, inputs, null /* gradient */, FFT);
- }
- const fft = op({ fft_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Inverse fast Fourier transform.
- *
- * Computes the inverse 1-dimensional discrete Fourier transform over the
- * inner-most dimension of input.
- *
- * ```js
- * const real = tf.tensor1d([1, 2, 3]);
- * const imag = tf.tensor1d([1, 2, 3]);
- * const x = tf.complex(real, imag);
- *
- * x.ifft().print(); // tf.spectral.ifft(x).print();
- * ```
- * @param input The complex input to compute an ifft over.
- *
- * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
- */
- function ifft_(input) {
- assert(input.dtype === 'complex64', () => `The dtype for tf.spectral.ifft() must be complex64 ` +
- `but got ${input.dtype}.`);
- const inputs = { input };
- return ENGINE.runKernelFunc(backend => {
- // Collapse all outer dimensions to a single batch dimension.
- const innerDimensionSize = input.shape[input.shape.length - 1];
- const batch = input.size / innerDimensionSize;
- const input2D = reshape(input, [batch, innerDimensionSize]);
- const result = backend.ifft(input2D);
- return reshape(result, input.shape);
- }, inputs, null /* gradient */, IFFT);
- }
- const ifft = op({ ifft_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Inversed real value input fast Fourier transform.
- *
- * Computes the 1-dimensional inversed discrete Fourier transform over the
- * inner-most dimension of the real input.
- *
- * ```js
- * const real = tf.tensor1d([1, 2, 3]);
- * const imag = tf.tensor1d([0, 0, 0]);
- * const x = tf.complex(real, imag);
- *
- * x.irfft().print();
- * ```
- * @param input The real value input to compute an irfft over.
- *
- * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
- */
- function irfft_(input) {
- const innerDimensionSize = input.shape[input.shape.length - 1];
- const batch = input.size / innerDimensionSize;
- let ret;
- if (innerDimensionSize <= 2) {
- const complexInput = reshape(input, [batch, innerDimensionSize]);
- ret = ifft(complexInput);
- }
- else {
- // The length of unique components of the DFT of a real-valued signal
- // is 2 * (input_len - 1)
- const outputShape = [batch, 2 * (innerDimensionSize - 1)];
- const realInput = reshape(real(input), [batch, innerDimensionSize]);
- const imagInput = reshape(imag(input), [batch, innerDimensionSize]);
- const realConjugate = reverse(slice(realInput, [0, 1], [batch, innerDimensionSize - 2]), 1);
- const imagConjugate = mul(reverse(slice(imagInput, [0, 1], [batch, innerDimensionSize - 2]), 1), scalar(-1));
- const r = concat([realInput, realConjugate], 1);
- const i = concat([imagInput, imagConjugate], 1);
- const complexInput = reshape(complex(r, i), [outputShape[0], outputShape[1]]);
- ret = ifft(complexInput);
- }
- ret = real(ret);
- // reshape the result if the input is 3D tensor.
- if (input.rank === 3 && input.shape[0] !== 0) {
- const temp = ret;
- const batch = input.shape[0];
- ret = reshape(ret, [batch, ret.shape[0] / batch, ret.shape[1]]);
- temp.dispose();
- }
- return ret;
- }
- const irfft = op({ irfft_ });
-
- /**
- * Prepare the split size array. When the input is a number, the axis is evenly
- * divided among the split size. When the input contains the negative value, the
- * rest of the axis is allocated toward that.
- */
- function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
- let splitSizes = [];
- if (typeof (numOrSizeSplits) === 'number') {
- assert(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
- splitSizes =
- new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
- }
- else {
- const numOfNegs = numOrSizeSplits.reduce((count, value) => {
- if (value === -1) {
- count += 1;
- }
- return count;
- }, 0);
- assert(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
- const negIndex = numOrSizeSplits.indexOf(-1);
- // Allow the number of split array to be -1, which indicates the rest
- // of dimension is allocated to that split.
- if (negIndex !== -1) {
- const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
- numOrSizeSplits[negIndex] = x.shape[axis] - total;
- }
- assert(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.');
- splitSizes = numOrSizeSplits;
- }
- return splitSizes;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Splits a `tf.Tensor` into sub tensors.
- *
- * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
- * into `numOrSizeSplits` smaller tensors.
- * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
- *
- * If `numOrSizeSplits` is a number array, splits `x` into
- * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
- * same size as `x` except along dimension `axis` where the size is
- * `numOrSizeSplits[i]`.
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
- * const [a, b] = tf.split(x, 2, 1);
- * a.print();
- * b.print();
- *
- * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
- * c.print();
- * d.print();
- * e.print();
- * ```
- *
- * @param x The input tensor to split.
- * @param numOrSizeSplits Either an integer indicating the number of
- * splits along the axis or an array of integers containing the sizes of
- * each output tensor along the axis. If a number then it must evenly divide
- * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
- * Can contain one -1 indicating that dimension is to be inferred.
- * @param axis The dimension along which to split. Defaults to 0 (the first
- * dim).
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function split_(x, numOrSizeSplits, axis = 0) {
- const $x = convertToTensor(x, 'x', 'split');
- const forward = (backend, _) => {
- const $axis = parseAxisParam(axis, $x.shape)[0];
- const splitSizes = prepareSplitSize($x, numOrSizeSplits, $axis);
- return backend.split($x, splitSizes, $axis);
- };
- const inputs = { x: $x };
- const attr = { numOrSizeSplits, axis };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, SplitV, attr);
- }
- const split = op({ split_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Real value input fast Fourier transform.
- *
- * Computes the 1-dimensional discrete Fourier transform over the
- * inner-most dimension of the real input.
- *
- * ```js
- * const real = tf.tensor1d([1, 2, 3]);
- *
- * real.rfft().print();
- * ```
- * @param input The real value input to compute an rfft over.
- *
- * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
- */
function rfft_(input, fftLength) {
    // rfft() is only defined for real-valued input; complex input should
    // use fft() directly.
    assert(input.dtype === 'float32', () => `The dtype for rfft() must be real value but got ${input.dtype}`);
    let innerDimensionSize = input.shape[input.shape.length - 1];
    // All outer dimensions are treated as a flat batch of 1-D signals.
    const batch = input.size / innerDimensionSize;
    let adjustedInput;
    if (fftLength != null && fftLength < innerDimensionSize) {
        // Need to crop: fftLength is shorter than the signal, so truncate the
        // inner dimension to fftLength.
        const begin = input.shape.map(v => 0);
        const size = input.shape.map(v => v);
        size[input.shape.length - 1] = fftLength;
        adjustedInput = slice(input, begin, size);
        innerDimensionSize = fftLength;
    }
    else if (fftLength != null && fftLength > innerDimensionSize) {
        // Need to pad with zeros: fftLength is longer than the signal, so
        // append zeros along the inner dimension up to fftLength.
        const zerosShape = input.shape.map(v => v);
        zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;
        adjustedInput = concat([input, zeros(zerosShape)], input.shape.length - 1);
        innerDimensionSize = fftLength;
    }
    else {
        // fftLength omitted or equal to the inner dimension: use input as-is.
        adjustedInput = input;
    }
    // Complement the input with zero imaginary numbers so the full complex
    // fft() kernel can be reused.
    const zerosInput = zerosLike(adjustedInput);
    const complexInput = reshape(complex(adjustedInput, zerosInput), [batch, innerDimensionSize]);
    const ret = fft(complexInput);
    // Exclude complex conjugations. These conjugations are put symmetrically:
    // only the first floor(n/2) + 1 frequency bins are kept; the remainder
    // mirrors them for a real-valued signal.
    const half = Math.floor(innerDimensionSize / 2) + 1;
    const realValues = real(ret);
    const imagValues = imag(ret);
    const realComplexConjugate = split(realValues, [half, innerDimensionSize - half], realValues.shape.length - 1);
    const imagComplexConjugate = split(imagValues, [half, innerDimensionSize - half], imagValues.shape.length - 1);
    // Restore the original outer dims; the inner dim shrinks to `half`.
    const outputShape = adjustedInput.shape.slice();
    outputShape[adjustedInput.shape.length - 1] = half;
    return reshape(complex(realComplexConjugate[0], imagComplexConjugate[0]), outputShape);
}
const rfft = op({ rfft_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 4, -1]);
- *
- * x.sqrt().print(); // or tf.sqrt(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
function sqrt_(x) {
    const $x = convertToTensor(x, 'x', 'sqrt');
    const forward = (backend, save) => {
        const res = backend.sqrt($x);
        // Keep the input around for the gradient computation.
        save([$x]);
        return res;
    };
    return ENGINE.runKernelFunc(forward, { x: $x }, null /* grad */, Sqrt);
}
const sqrt = op({ sqrt_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns (a - b) * (a - b) element-wise.
- * Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
- * ```
- *
- * ```js
- * // Broadcast squared difference a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- *
- * @doc {heading: 'Operations', subheading: 'Arithmetic'}
- */
function squaredDifference_(a, b) {
    let $a = convertToTensor(a, 'a', 'squaredDifference');
    let $b = convertToTensor(b, 'b', 'squaredDifference');
    // Upcast both operands to a common dtype before dispatching.
    [$a, $b] = makeTypesMatch($a, $b);
    // Validate broadcast compatibility eagerly so bad shapes fail here.
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernelFunc((backend, save) => {
        const res = backend.squaredDifference($a, $b);
        // Both inputs are needed for the gradient.
        save([$a, $b]);
        return res;
    }, { a: $a, b: $b }, null /* grad */, SquaredDifference, {} /* attrs */);
}
const squaredDifference = op({ squaredDifference_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
- *
- * ```js
- * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
- * x.squeeze().print();
- * ```
- *
- * @param x The input tensor to be squeezed.
- * @param axis An optional list of numbers. If specified, only
- * squeezes the dimensions listed. The dimension index starts at 0. It
- * is an error to squeeze a dimension that is not 1.
- *
- * @doc {heading: 'Tensors', subheading: 'Transformations'}
- */
function squeeze_(x, axis) {
    const $x = convertToTensor(x, 'x', 'squeeze');
    // squeezeShape computes the shape with the size-1 dims (optionally only
    // those listed in `axis`) removed; squeeze is then just a reshape.
    const { newShape } = squeezeShape($x.shape, axis);
    return reshape($x, newShape);
}
const squeeze = op({ squeeze_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * const c = tf.tensor1d([5, 6]);
- * tf.stack([a, b, c]).print();
- * ```
- *
- * @param tensors A list of tensor objects with the same shape and dtype.
- * @param axis The axis to stack along. Defaults to 0 (the first dim).
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
function stack_(tensors, axis = 0) {
    const $tensors = convertToTensorArray(tensors, 'tensors', 'stack');
    assert($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
    // A single tensor just gains a new unit dimension.
    if ($tensors.length === 1) {
        return expandDims($tensors[0], axis);
    }
    const [first] = $tensors;
    assert(axis <= first.rank, () => 'Axis must be <= rank of the tensor');
    // Every input must agree with the first on shape and dtype.
    for (const t of $tensors) {
        assertShapesMatch(first.shape, t.shape, 'All tensors passed to stack must have matching shapes');
        assert(first.dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
    }
    const expandedTensors = $tensors.map(t => expandDims(t, axis));
    // Stack exists in the TensorFlow C++ API
    // (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/stack) but not
    // in
    // https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/core/ops/ops.pbtxt.
    // Therefore we are treating it like a high-level op rather than
    // creating a dedicated stack kernel.
    return concat(expandedTensors, axis);
}
const stack = op({ stack_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`
- *
- * ```js
- * const x = tf.tensor1d([0, 2, -1, -3]);
- *
- * x.step(.5).print(); // or tf.step(x, .5)
- * ```
- * @param x The input tensor.
- * @param alpha The gradient when input is negative.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
function step_(x, alpha = 0.0) {
    const $x = convertToTensor(x, 'x', 'step');
    // alpha is forwarded both to the kernel closure and as a kernel attr.
    const kernelFunc = backend => backend.step($x, alpha);
    return ENGINE.runKernelFunc(kernelFunc, { x: $x }, null /* grad */, Step, { alpha });
}
const step = op({ step_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts a strided slice of a tensor.
- *
- * Roughly speaking, this op extracts a slice of size (end-begin)/stride from
- * the given input tensor (x). Starting at the location specified by begin the
- * slice continues by adding stride to the index until all dimensions are not
- * less than end. Note that a stride can be negative, which causes a reverse
- * slice.
- *
- * ```js
- * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
- * [3, 2, 3]);
- * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]
- * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],
- * // [4, 4, 4]]]
- * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
- * // [3, 3, 3]]]
- * ```
- *
- * @param x The tensor to stride slice.
- * @param begin The coordinates to start the slice from.
- * @param end The coordinates to end the slice at.
- * @param strides The stride of the slice along each dimension (the index is
- *     advanced by this amount at each step; may be negative).
- * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored
- * and the fullest possible range in that dimension is used instead.
- * @param endMask: If the ith bit of endMask is set, end[i] is ignored
- * and the fullest possible range in that dimension is used instead.
- * @param shrinkAxisMask: a bitmask where bit i implies that
- * the ith specification should shrink the dimensionality. begin and end must
- * imply a slice of size 1 in the dimension.
- *
- * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
- */
function stridedSlice_(x, begin, end, strides, beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0) {
    let $x = convertToTensor(x, 'x', 'stridedSlice');
    // NOTE(review): this closure mutates the outer `begin`, `end`, `strides`
    // and rebinds `$x`; the statement order below is load-bearing.
    const forward = (backend) => {
        if (strides == null) {
            // NOTE(review): `new Array(begin.length)` is a sparse array of
            // holes, not 1s — presumably getNormalizedAxes fills the defaults;
            // confirm against its definition.
            strides = new Array(begin.length);
        }
        const ellipsisAxes = maskToAxes(ellipsisMask);
        if (ellipsisAxes.length > 1) {
            throw new Error('Multiple ellipses in slice is not allowed.');
        }
        if (ellipsisMask !== 0 && newAxisMask !== 0) {
            throw new Error('Using both ellipsisMask and newAxisMask is not yet supported.');
        }
        if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
            throw new Error('Using both ellipsisMask and shrinkAxisMask is not yet supported.');
        }
        // Dimensions of x not covered by the begin/end specs (filled in by the
        // ellipsis during normalization below).
        const numInterpolatedAxes = $x.rank - begin.length;
        // Expand the dims of x based on the newAxisMask: each flagged axis
        // becomes a new size-1 dimension sliced in full ([0, 1)).
        const expandAxes = maskToAxes(newAxisMask);
        const newShape = $x.shape.slice();
        expandAxes.forEach(axis => {
            begin[axis] = 0;
            end[axis] = 1;
            newShape.splice(axis, 0, 1);
        });
        $x = reshape($x, newShape);
        // Resolve masks, negative indices and the ellipsis into concrete
        // per-dimension begin/end/stride values.
        const { begin: normalizedBegin, end: normalizedEnd, strides: normalizedStrides } = getNormalizedAxes($x.shape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask);
        begin = normalizedBegin;
        end = normalizedEnd;
        strides = normalizedStrides;
        const shrinkAxes = maskToAxes(shrinkAxisMask);
        // Adjust the ends based on the shrink mask: a shrunk axis is a
        // single-element slice [begin, begin + 1) with stride 1.
        shrinkAxes.forEach(axis => {
            end[axis] = begin[axis] + 1;
            strides[axis] = 1;
        });
        // Figure out the output shape.
        const size = computeOutShape(begin, end, strides);
        // Remove the axes based on shrinkMask.
        const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);
        // With all-1 strides this degenerates to a plain slice + reshape,
        // avoiding the strided kernel entirely.
        const nonStrided = strides.every(v => v === 1);
        if (nonStrided) {
            return reshape(slice($x, begin, size), outShape);
        }
        const res = backend.stridedSlice($x, begin, end, strides);
        return reshape(res, outShape);
    };
    const inputs = { x: $x };
    const attrs = {
        begin,
        end,
        strides,
        beginMask,
        endMask,
        ellipsisMask,
        newAxisMask,
        shrinkAxisMask
    };
    return ENGINE.runKernelFunc(forward, inputs, null /* grad */, StridedSlice, attrs);
}
const stridedSlice = op({ stridedSlice_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.tan().print(); // or tf.tan(x)
- * ```
- * @param x The input tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Basic math'}
- */
function tan_(x) {
    const $x = convertToTensor(x, 'x', 'tan');
    const forward = (backend, save) => {
        const res = backend.tan($x);
        // Keep the input around for the gradient computation.
        save([$x]);
        return res;
    };
    return ENGINE.runKernelFunc(forward, { x: $x }, null /* grad */, Tan);
}
const tan = op({ tan_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor2d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor2d([[1, 2], [3, 4]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. If not provided, it is inferred from
- * `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function tensor2d(values, shape, dtype) {
    assertNonNull(values);
    // An explicit shape, when given, must be exactly rank 2.
    if (shape != null && shape.length !== 2) {
        throw new Error('tensor2d() requires shape to have two numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    // Values must be either a nested rank-2 array or a flat list.
    if (inferredRank !== 2 && inferredRank !== 1) {
        throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
    }
    // A flat list carries no rank information, so the shape is mandatory.
    if (inferredRank === 1 && shape == null) {
        throw new Error('tensor2d() requires shape to be provided when `values` are a flat/TypedArray');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor4d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function tensor4d(values, shape, dtype) {
    assertNonNull(values);
    // An explicit shape, when given, must be exactly rank 4.
    if (shape != null && shape.length !== 4) {
        throw new Error('tensor4d() requires shape to have four numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    // Values must be either a nested rank-4 array or a flat list.
    if (inferredRank !== 4 && inferredRank !== 1) {
        throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
    }
    // A flat list carries no rank information, so the shape is mandatory.
    if (inferredRank === 1 && shape == null) {
        throw new Error('tensor4d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor5d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function tensor5d(values, shape, dtype) {
    assertNonNull(values);
    // An explicit shape, when given, must be exactly rank 5.
    if (shape != null && shape.length !== 5) {
        throw new Error('tensor5d() requires shape to have five numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    // Values must be either a nested rank-5 array or a flat list.
    if (inferredRank !== 5 && inferredRank !== 1) {
        throw new Error('tensor5d() requires values to be number[][][][][] or flat/TypedArray');
    }
    // A flat list carries no rank information, so the shape is mandatory.
    if (inferredRank === 1 && shape == null) {
        throw new Error('tensor5d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor6d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function tensor6d(values, shape, dtype) {
    assertNonNull(values);
    // An explicit shape, when given, must be exactly rank 6.
    if (shape != null && shape.length !== 6) {
        throw new Error('tensor6d() requires shape to have six numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    // Values must be either a nested rank-6 array or a flat list.
    if (inferredRank !== 6 && inferredRank !== 1) {
        throw new Error('tensor6d() requires values to be number[][][][][][] or flat/TypedArray');
    }
    // A flat list carries no rank information, so the shape is mandatory.
    if (inferredRank === 1 && shape == null) {
        throw new Error('tensor6d() requires shape to be provided when `values` are a flat array');
    }
    // Fall back to the inferred shape when no explicit shape was given.
    if (shape == null) {
        shape = inferredShape;
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Finds the values and indices of the `k` largest entries along the last
- * dimension.
- *
- * If the input is a vector (rank=1), finds the k largest entries in the vector
- * and outputs their values and indices as vectors. Thus values[j] is the j-th
- * largest entry in input, and its index is indices[j].
- * For higher rank inputs, computes the top k entries along the last dimension.
- *
- * If two elements are equal, the lower-index element appears first.
- *
- * ```js
- * const a = tf.tensor2d([[1, 5], [4, 3]]);
- * const {values, indices} = tf.topk(a);
- * values.print();
- * indices.print();
- * ```
- * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
- * @param k Number of top elements to look for along the last dimension.
- * @param sorted If true, the resulting `k` elements will be sorted by the
- * values in descending order.
- *
- * @doc {heading: 'Operations', subheading: 'Evaluation'}
- */
function topk_(x, k = 1, sorted = true) {
    const $x = convertToTensor(x, 'x', 'topk');
    // Scalars have no dimension to select from.
    if ($x.rank === 0) {
        throw new Error('topk() expects the input to be of rank 1 or higher');
    }
    // k may not exceed the size of the last dimension.
    const lastDim = $x.shape[$x.shape.length - 1];
    if (k > lastDim) {
        throw new Error(`'k' passed to topk() must be <= the last dimension (${lastDim}) but got ${k}`);
    }
    const [values, indices] = ENGINE.runKernelFunc(
        backend => backend.topk($x, k, sorted), { x: $x }, null /* grad */, TopK, { k, sorted });
    return { values, indices };
}
const topk = op({ topk_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with values sampled from a truncated normal
- * distribution.
- *
- * ```js
- * tf.truncatedNormal([2, 2]).print();
- * ```
- *
- * The generated values follow a normal distribution with specified mean and
- * standard deviation, except that values whose magnitude is more than 2
- * standard deviations from the mean are dropped and re-picked.
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param mean The mean of the normal distribution.
- * @param stdDev The standard deviation of the normal distribution.
- * @param dtype The data type of the output tensor.
- * @param seed The seed for the random number generator.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
    // 'bool' cannot represent samples from a normal distribution.
    // Bug fix: the message was written as `$ { dtype }` (space between `$`
    // and `{`), which is NOT template interpolation — the error printed the
    // literal text "$ { dtype }" instead of the offending dtype.
    if (dtype != null && dtype === 'bool') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    // Gaussian generator in truncated mode: samples more than 2 standard
    // deviations from the mean are dropped and re-picked.
    const randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
    // Fill a CPU-side buffer sample by sample, then materialize a tensor.
    const res = buffer(shape, dtype);
    for (let i = 0; i < res.values.length; i++) {
        res.values[i] = randGauss.nextValue();
    }
    return res.toTensor();
}
// Exported op wrapper around truncatedNormal_.
const truncatedNormal = op({ truncatedNormal_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Finds unique elements along an axis of a tensor.
- *
- * It returns a tensor `values` containing all of the unique elements along the
- * `axis` of the given tensor `x` in the same order that they occur along the
- * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
- * `indices` the same size as the number of the elements in `x` along the `axis`
- * dimension. It contains the index in the unique output `values`.
- *
- * ```js
- * // A 1-D tensor
- * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
- * const {values, indices} = tf.unique(a);
- * values.print(); // [1, 2, 4, 7, 8]
- * indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
- * ```
- *
- * ```js
- * // A 2-D tensor with axis=0
- * //
- * // 'a' is: [[1, 0, 0],
- * // [1, 0, 0],
- * // [2, 0, 0]]
- * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
- * const {values, indices} = tf.unique(a, 0)
- * values.print(); // [[1, 0, 0],
- * // [2, 0, 0]]
- * indices.print(); // [0, 0, 1]
- * ```
- *
- * ```js
- * // A 2-D tensor with axis=1
- * //
- * // 'a' is: [[1, 0, 0],
- * // [1, 0, 0],
- * // [2, 0, 0]]
- * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
- * const {values, indices} = tf.unique(a, 1)
- * values.print(); // [[1, 0],
- * // [1, 0],
- * // [2, 0]]
- * indices.print(); // [0, 1, 1]
- * ```
- * @param x A tensor (int32, string, bool).
- * @param axis The axis of the tensor to find the unique elements.
- * @returns [uniqueElements, indices] (see above for details)
- *
- * @doc {heading: 'Operations', subheading: 'Evaluation'}
- */
function unique_(x, axis = 0) {
    // x can be of any dtype, thus null as the last argument.
    const $x = convertToTensor(x, 'x', 'unique', null);
    // Scalars have no axis along which uniqueness is meaningful.
    assert($x.rank > 0, () => 'The input tensor must be at least 1D');
    const [values, indices] = ENGINE.runKernel(Unique, { x: $x }, { axis });
    return { values, indices };
}
const unique = op({ unique_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the sum along segments of a `tf.Tensor`.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
- * const numSegments = 3;
- *
- * x.unsortedSegmentSum(segmentIds, numSegments).print()
- * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
- * ```
- * @param x The `tf.Tensor` that will be summed along its segments.
- * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s
- * dimension along the `axis`. Maps each element of `x` to a segment.
- * @param numSegments The number of distinct `segmentIds`.
- *
- * @doc {heading: 'Operations', subheading: 'Segment'}
- */
- function unsortedSegmentSum_(x, segmentIds, numSegments) {
- const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
- const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
- assert(isInt(numSegments), () => 'numSegments must be of dtype int');
- const inputs = { x: $x, segmentIds: $segmentIds };
- const attrs = { numSegments };
- const forward = (backend, save) => {
- const res = backend.unsortedSegmentSum($x, $segmentIds, numSegments);
- save([$segmentIds]);
- return res;
- };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, UnsortedSegmentSum, attrs);
- }
- const unsortedSegmentSum = op({ unsortedSegmentSum_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
- *
- * ```js
- * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * tf.unstack(a).forEach(tensor => tensor.print());
- * ```
- *
- * @param x A tensor object.
- * @param axis The axis to unstack along. Defaults to 0 (the first dim).
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- function unstack_(x, axis = 0) {
- const $x = convertToTensor(x, 'x', 'unstack');
- assert(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
- if (axis < 0) {
- axis += $x.shape.length;
- }
- const inputs = { value: $x };
- const attrs = { axis };
- const forward = (backend) => backend.unstack($x, axis);
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Unpack, attrs);
- }
- const unstack = op({ unstack_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a new variable with the provided initial value.
- * ```js
- * const x = tf.variable(tf.tensor([1, 2, 3]));
- * x.assign(tf.tensor([4, 5, 6]));
- *
- * x.print();
- * ```
- *
- * @param initialValue Initial value for the tensor.
- * @param trainable If true, optimizers are allowed to update it.
- * @param name Name of the variable. Defaults to a unique id.
- * @param dtype If set, initialValue will be converted to the given type.
- *
- * @doc {heading: 'Tensors', subheading: 'Creation'}
- */
- function variable(initialValue, trainable = true, name, dtype) {
- return ENGINE.makeVariable(initialValue, trainable, name, dtype);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function whereImpl(condShape, condVals) {
- const indices = [];
- for (let i = 0; i < condVals.length; i++) {
- if (condVals[i]) {
- indices.push(i);
- }
- }
- const inBuffer = buffer(condShape, 'int32');
- const out = buffer([indices.length, condShape.length], 'int32');
- for (let i = 0; i < indices.length; i++) {
- const loc = inBuffer.indexToLoc(indices[i]);
- const offset = i * condShape.length;
- out.values.set(loc, offset);
- }
- return out.toTensor();
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the coordinates of true elements of condition.
- *
- * The coordinates are returned in a 2-D tensor where the first dimension (rows)
- * represents the number of true elements, and the second dimension (columns)
- * represents the coordinates of the true elements. Keep in mind, the shape of
- * the output tensor can vary depending on how many true values there are in
- * input. Indices are output in row-major order. The resulting tensor has the
- * shape `[numTrueElems, condition.rank]`.
- *
- * This is analogous to calling the python `tf.where(cond)` without an x or y.
- *
- * ```js
- * const cond = tf.tensor1d([false, false, true], 'bool');
- * const result = await tf.whereAsync(cond);
- * result.print();
- * ```
- *
- * @doc {heading: 'Operations', subheading: 'Logical'}
- */
- async function whereAsync_(condition) {
- const $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
- const vals = await $condition.data();
- const res = whereImpl($condition.shape, vals);
- if (condition !== $condition) {
- $condition.dispose();
- }
- return res;
- }
- const whereAsync = whereAsync_;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Apply boolean mask to tensor.
- *
- * ```js
- * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
- * const mask = tf.tensor1d([1, 0, 1], 'bool');
- * const result = await tf.booleanMaskAsync(tensor, mask);
- * result.print();
- * ```
- *
- * @param tensor N-D tensor.
- * @param mask K-D boolean tensor, K <= N and K must be known statically.
- * @param axis A 0-D int Tensor representing the axis in tensor to mask from.
- * By default, axis is 0 which will mask from the first dimension.
- * Otherwise K + axis <= N.
- *
- * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
- */
- async function booleanMaskAsync_(tensor, mask, axis) {
- const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
- const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
- const axisFrom = axis == null ? 0 : axis;
- const maskDim = $mask.rank;
- const tensorShape = $tensor.shape;
- assert(maskDim > 0, () => 'mask cannot be scalar');
- assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, `mask's shape must match the first K dimensions of tensor's shape,`);
- let leadingSize = 1;
- for (let i = axisFrom; i < axisFrom + maskDim; i++) {
- leadingSize *= tensorShape[i];
- }
- const targetTensorShape = tensorShape.slice(0, axisFrom)
- .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
- const reshapedTensor = reshape($tensor, targetTensorShape);
- const reshapedMask = reshape($mask, [-1]);
- const positivePositions = await whereAsync(reshapedMask);
- const indices = squeeze(positivePositions, [1]);
- const res = gather(reshapedTensor, indices, axisFrom);
- // Ensure no memory leak.
- if (tensor !== $tensor) {
- $tensor.dispose();
- }
- if (mask !== $mask) {
- $mask.dispose();
- }
- indices.dispose();
- reshapedTensor.dispose();
- reshapedMask.dispose();
- positivePositions.dispose();
- return res;
- }
- const booleanMaskAsync = booleanMaskAsync_;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated
- * Strict version of `tf.notEqual` that forces `a` and `b` to be of the same
- * shape.
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same shape and dtype as
- * `a`.
- */
- function notEqualStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'notEqualStrict');
- const $b = convertToTensor(b, 'b', 'notEqualStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in notEqualStrict: ');
- return notEqual($a, $b);
- }
- /**
- * @deprecated
- * Strict version of `tf.less` that forces `a` and `b` to be of the same
- * shape.
- *
- * @param a The first input tensor.
- * @param b The second input tensor. Must have the same shape and dtype as
- * `a`.
- */
- function lessStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'lessStrict');
- const $b = convertToTensor(b, 'b', 'lessStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in lessStrict: ');
- return less($a, $b);
- }
- function equalStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'equalStrict');
- const $b = convertToTensor(b, 'b', 'equalStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in equalStrict: ');
- return equal($a, $b);
- }
- function lessEqualStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'lessEqualStrict');
- const $b = convertToTensor(b, 'b', 'lessEqualStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in lessEqualStrict: ');
- return lessEqual($a, $b);
- }
- function greaterStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'greaterStrict');
- const $b = convertToTensor(b, 'b', 'greaterStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in greaterStrict: ');
- return greater($a, $b);
- }
- function greaterEqualStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'greaterEqualStrict');
- const $b = convertToTensor(b, 'b', 'greaterEqualStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in greaterEqualStrict: ');
- return greaterEqual($a, $b);
- }
- const equalStrict = op({ equalStrict_ });
- const greaterEqualStrict = op({ greaterEqualStrict_ });
- const greaterStrict = op({ greaterStrict_ });
- const lessEqualStrict = op({ lessEqualStrict_ });
- const lessStrict = op({ lessStrict_ });
- const notEqualStrict = op({ notEqualStrict_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated
- * Adds two `tf.Tensor`s element-wise, A + B.
- *
- * Inputs must be the same shape. For broadcasting support, use add() instead.
- *
- * @param a The first Tensor to add element-wise.
- * @param b The second Tensor to add element-wise.
- */
- function addStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'addStrict');
- const $b = convertToTensor(b, 'b', 'addStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: ');
- return add$1($a, $b);
- }
- /**
- * @deprecated
- * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must
- * be the same shape.
- *
- * For broadcasting support, use `tf.sub` instead.
- *
- * @param a The first Tensor to subtract element-wise.
- * @param b The second Tensor to subtract element-wise.
- */
- function subStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'subStrict');
- const $b = convertToTensor(b, 'b', 'subStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: ');
- return sub($a, $b);
- }
- /**
- * @deprecated
- * Computes the power of one `tf.Tensor` to another. Inputs must
- * be the same shape.
- *
- * For broadcasting support, use `tf.pow` instead.
- *
- * @param base The base tensor to pow element-wise.
- * @param exp The exponent tensor to pow element-wise.
- */
- function powStrict_(base, exp) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: ');
- return pow(base, exp);
- }
- /**
- * @deprecated
- * Multiplies two `tf.Tensor`s element-wise, A * B.
- *
- * Inputs must be the same shape. For broadcasting support, use `tf.mul`.
- *
- * @param a The first tensor to multiply.
- * @param b The first tensor to multiply. Must have the same
- * dtype as `a`.
- */
- function mulStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'mul');
- const $b = convertToTensor(b, 'b', 'mul');
- assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
- return mul($a, $b);
- }
- /**
- * @deprecated
- * Divides two `tf.Tensor`s element-wise, A / B. Inputs must
- * be the same shape.
- *
- * @param a The first tensor as the numerator for element-wise division.
- * @param b The second tensor as the denominator for element-wise division.
- */
- function divStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'div');
- const $b = convertToTensor(b, 'b', 'div');
- assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: ');
- return div($a, $b);
- }
- /**
- * @deprecated
- * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use mod().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function modStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'modStrict');
- const $b = convertToTensor(b, 'b', 'modStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: ');
- return mod($a, $b);
- }
- /**
- * @deprecated
- * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use minimum().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function minimumStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'minimumStrict');
- const $b = convertToTensor(b, 'b', 'minimumStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: ');
- return minimum($a, $b);
- }
- /**
- * @deprecated
- * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use maximum().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function maximumStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'maximumStrict');
- const $b = convertToTensor(b, 'b', 'maximumStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: ');
- return maximum($a, $b);
- }
- /**
- * @deprecated
- * Returns (a - b) * (a - b) element-wise.
- *
- * Inputs must be the same shape. For broadcasting support, use
- * `tf.squaredDifference` instead.
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- */
- function squaredDifferenceStrict_(a, b) {
- deprecationWarn('strict variants of ops have been deprecated ' +
- 'and will be removed in future');
- const $a = convertToTensor(a, 'a', 'squaredDifferenceStrict');
- const $b = convertToTensor(b, 'b', 'squaredDifferenceStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in squaredDifferenceStrict: ');
- return squaredDifference($a, $b);
- }
- const addStrict = op({ addStrict_ });
- const divStrict = op({ divStrict_ });
- const maximumStrict = op({ maximumStrict_ });
- const minimumStrict = op({ minimumStrict_ });
- const modStrict = op({ modStrict_ });
- const mulStrict = op({ mulStrict_ });
- const powStrict = op({ powStrict_ });
- const squaredDifferenceStrict = op({ squaredDifferenceStrict_ });
- const subStrict = op({ subStrict_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the norm of scalar, vectors, and matrices.
- * This function can compute several different vector norms (the 1-norm, the
- * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
- * and matrix norms (Frobenius, 1-norm, and inf-norm).
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- *
- * x.norm().print(); // or tf.norm(x)
- * ```
- *
- * @param x The input array.
- * @param ord Optional. Order of the norm. Supported norm types are
- * following:
- *
- * | ord | norm for matrices | norm for vectors
- * |------------|---------------------------|---------------------
- * |'euclidean' |Frobenius norm |2-norm
- * |'fro' |Frobenius norm |
- * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
- * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
- * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
- * |2 | |sum(abs(x)^2)^1/2*
- *
- * @param axis Optional. If axis is null (the default), the input is
- * considered a vector and a single vector norm is computed over the entire
- * set of values in the Tensor, i.e. norm(x, ord) is equivalent
- * to norm(x.reshape([-1]), ord). If axis is a integer, the input
- * is considered a batch of vectors, and axis determines the axis in x
- * over which to compute vector norms. If axis is a 2-tuple of integer it is
- * considered a batch of matrices and axis determines the axes in NDArray
- * over which to compute a matrix norm.
- * @param keepDims Optional. If true, the norm have the same dimensionality
- * as the input.
- *
- * @doc {heading: 'Operations', subheading: 'Matrices'}
- */
- function norm_(x, ord = 'euclidean', axis = null, keepDims = false) {
- x = convertToTensor(x, 'x', 'norm');
- const norm = normImpl(x, ord, axis);
- let keepDimsShape = norm.shape;
- if (keepDims) {
- const axes = parseAxisParam(axis, x.shape);
- keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
- }
- return reshape(norm, keepDimsShape);
- }
- function normImpl(x, p, axis = null) {
- if (x.rank === 0) {
- return abs(x);
- }
- // consider vector when no axis is specified
- if (x.rank !== 1 && axis === null) {
- return normImpl(reshape(x, [-1]), p, axis);
- }
- // vector
- if (x.rank === 1 || typeof axis === 'number' ||
- Array.isArray(axis) && axis.length === 1) {
- if (p === 1) {
- return sum$1(abs(x), axis);
- }
- if (p === Infinity) {
- return max(abs(x), axis);
- }
- if (p === -Infinity) {
- return min(abs(x), axis);
- }
- if (p === 'euclidean' || p === 2) {
- // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
- return sqrt(sum$1(pow(abs(x), scalar(2, 'int32')), axis));
- }
- throw new Error(`Error in norm: invalid ord value: ${p}`);
- }
- // matrix (assumption axis[0] < axis[1])
- if (Array.isArray(axis) && axis.length === 2) {
- if (p === 1) {
- return max(sum$1(abs(x), axis[0]), axis[1] - 1);
- }
- if (p === Infinity) {
- return max(sum$1(abs(x), axis[1]), axis[0]);
- }
- if (p === -Infinity) {
- return min(sum$1(abs(x), axis[1]), axis[0]);
- }
- if (p === 'fro' || p === 'euclidean') {
- // norm(x) = sqrt(sum(pow(x, 2)))
- return sqrt(sum$1(square(x), axis));
- }
- throw new Error(`Error in norm: invalid ord value: ${p}`);
- }
- throw new Error(`Error in norm: invalid axis: ${axis}`);
- }
- const norm = op({ norm_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Compute the moving average of a variable.
- *
- * Without zeroDebias, the moving average operation is defined by:
- * `v += delta`
- * where
- * `delta = (1 - decay) * (x - v)`
- *
- * With zeroDebias (default), the `delta` term is scaled to debias the
- * effect of the (assumed) zero-initialization of `v`.
- * `delta /= (1 - decay ^ step)`
- *
- * For more details on the zero-debiasing algorithm, see:
- * https://arxiv.org/abs/1412.6980
- *
- * Note that this function is completely stateless and does not keep track of
- * step count. The step count needs to be maintained by the caller and passed
- * in as `step`.
- *
- * @param v The current moving average value.
- * @param x New input value, must have the same shape and dtype as `v`.
- * @param decay The decay factor. Typical values are 0.95 and 0.99.
- * @param step Step count.
- * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
- * @returns The new moving average value.
- *
- * @doc {heading: 'Operations', subheading: 'Moving Average'}
- */
- function movingAverage_(v, x, decay, step, zeroDebias = true) {
- const $v = convertToTensor(v, 'v', 'movingAverage');
- const $x = convertToTensor(x, 'x', 'movingAverage');
- const $decay = convertToTensor(decay, 'decay', 'movingAverage');
- assertTypesMatch($v, $x);
- assert(arraysEqual($v.shape, $x.shape), () => 'Shape mismatch in v and x');
- const one = scalar(1);
- const oneMinusDecay = sub(one, $decay);
- let update = mul(sub($x, $v), oneMinusDecay);
- if (zeroDebias) {
- assert(step != null, () => 'When using zeroDebias: true, step is required.');
- const $step = convertToTensor(step, 'step', 'movingAverage');
- update = div(update, sub(one, pow($decay, $step)));
- }
- return add$1($v, update);
- }
- const movingAverage = op({ movingAverage_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a new tensor by applying sparse updates to individual
- * values or slices within a zero tensor of the given shape tensor according to
- * indices. This operator is the inverse of the `tf.gatherND` operator which
- * extracts values or slices from a given tensor.
- *
- * ```js
- * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
- * const updates = tf.tensor1d([9, 10, 11, 12]);
- * const shape = [8];
- * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
- * ```
- *
- * @param indices The tensor contains the indices into the output tensor.
- * @param updates The tensor contains the value for the indices.
- * @param shape: The shape of the output tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
- */
- function scatterND_(indices, updates, shape) {
- const $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');
- const $updates = convertToTensor(updates, 'updates', 'scatterND');
- validateInput($updates, $indices, shape);
- const forward = (backend) => {
- return backend.scatterND($indices, $updates, shape);
- };
- const inputs = { indices: $indices, updates: $updates };
- const attrs = { shape };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, ScatterNd, attrs);
- }
- const scatterND = op({ scatterND_ });
-
- /**
- * Validate sparseToDense inputs.
- *
- * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
- * sparseIndices[i] contains the complete index where sparseValues[i] will be
- * placed.
- * @param sparseValues A 0-D or 1-D Tensor. Values
- * corresponding to each row of sparseIndices, or a scalar value to be used for
- * all sparse indices.
- * @param outputShape number[]. Shape of the dense output tensor.
- * @param validateIndices boolean. indice validation is not supported, error
- * will be thrown if it is set.
- */
- function validateInput$1(sparseIndices, sparseValues, outputShape, defaultValues) {
- if (sparseIndices.dtype !== 'int32') {
- throw new Error('tf.sparseToDense() expects the indices to be int32 type,' +
- ` but the dtype was ${sparseIndices.dtype}.`);
- }
- if (sparseIndices.rank > 2) {
- throw new Error('sparseIndices should be a scalar, vector, or matrix,' +
- ` but got shape ${sparseIndices.shape}.`);
- }
- const numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
- const numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
- if (outputShape.length !== numDims) {
- throw new Error('outputShape has incorrect number of elements:,' +
- ` ${outputShape.length}, should be: ${numDims}.`);
- }
- const numValues = sparseValues.size;
- if (!(sparseValues.rank === 0 ||
- sparseValues.rank === 1 && numValues === numElems)) {
- throw new Error('sparseValues has incorrect shape ' +
- `${sparseValues.shape}, should be [] or [${numElems}]`);
- }
- if (sparseValues.dtype !== defaultValues.dtype) {
- throw new Error('sparseValues.dtype must match defaultValues.dtype');
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts a sparse representation into a dense tensor.
- *
- * Builds an array dense with shape outputShape such that:
- *
- * // If sparseIndices is scalar
- * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
- *
- * // If sparseIndices is a vector, then for each i
- * dense[sparseIndices[i]] = sparseValues[i]
- *
- * // If sparseIndices is an n by d matrix, then for each i in [0, n)
- * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
- * All other values in dense are set to defaultValue. If sparseValues is a
- * scalar, all sparse indices are set to this single value.
- *
- * If indices are repeated the final value is summed over all values for those
- * indices.
- *
- * ```js
- * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
- * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
- * const shape = [8];
- * tf.sparseToDense(indices, values, shape).print();
- * ```
- *
- * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
- * sparseIndices[i] contains the complete index where sparseValues[i] will be
- * placed.
- * @param sparseValues A 0-D or 1-D Tensor. Values
- * corresponding to each row of sparseIndices, or a scalar value to be used for
- * all sparse indices.
- * @param outputShape Shape of the dense output tensor. the type is inferred.
- * @param defaultValue Scalar. Value to set for indices not specified in
- * sparseIndices. Defaults to zero.
- *
- * @doc {heading: 'Operations', subheading: 'Normalization'}
- */
- function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue = 0) {
- const $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
- const $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense');
- const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);
- validateInput$1($sparseIndices, $sparseValues, outputShape, $defaultValue);
- const inputs = {
- sparseIndices: $sparseIndices,
- sparseValues: $sparseValues,
- defaultValue: $defaultValue
- };
- const attrs = { outputShape };
- return ENGINE.runKernelFunc(backend => backend.sparseToDense($sparseIndices, $sparseValues, outputShape, $defaultValue), inputs, null /* grad */, SparseToDense, attrs);
- }
- const sparseToDense = op({ sparseToDense_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gather slices from input tensor into a Tensor with shape specified by
- * `indices`.
- *
- * `indices` is an K-dimensional integer tensor, best thought of as a
- * (K-1)-dimensional tensor of indices into input, where each element defines a
- * slice of input:
- * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
- *
- * Whereas in `tf.gather`, `indices` defines slices into the first dimension of
- * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
- * of input, where N = indices.shape[-1].
- *
- * The last dimension of indices can be at most the rank of input:
- * indices.shape[-1] <= input.rank
- *
- * The last dimension of `indices` corresponds to elements
- * (if indices.shape[-1] == input.rank) or slices
- * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
- * input.
- * The output tensor has shape
- * indices.shape[:-1] + input.shape[indices.shape[-1]:]
- *
- * Note that on CPU, if an out of bound index is found, an error is returned. On
- * GPU, if an out of bound index is found, a 0 is stored in the corresponding
- * output value.
- *
- * ```js
- * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
- * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
- * tf.gatherND(input, indices).print() // [10, 11]
- * ```
- *
- * @param x The tensor from which to gather values.
- * @param indices Index tensor, must be of type int32.
- *
- * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
- */
- function gatherND_(x, indices) {
- const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
- const $x = convertToTensor(x, 'x', 'gatherND');
- const forward = (backend) => {
- return backend.gatherND($x, $indices);
- };
- const inputs = { params: $x, indices: $indices };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, GatherNd);
- }
- const gatherND = op({ gatherND_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Normalize noise shape based on provided tensor and noise shape.
- *
- * @param x Tensor.
- * @param noiseShape The shape for the randomly generated keep/drop flags, as
- * an array of numbers. Optional.
- * @returns Normalized noise shape.
- */
- function getNoiseShape(x, noiseShape) {
- if (noiseShape == null) {
- return x.shape.slice();
- }
- if (arraysEqual(x.shape, noiseShape)) {
- return noiseShape;
- }
- if (x.shape.length === noiseShape.length) {
- const newDimension = [];
- for (let i = 0; i < x.shape.length; i++) {
- if (noiseShape[i] == null && x.shape[i] != null) {
- newDimension.push(x.shape[i]);
- }
- else {
- newDimension.push(noiseShape[i]);
- }
- }
- return newDimension;
- }
- return noiseShape;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes dropout.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 2, 1]);
- * const rate = 0.75;
- * const output = tf.dropout(x, rate);
- * output.print();
- * ```
- *
- * @param x A floating point Tensor or TensorLike.
- * @param rate A float in the range [0, 1). The probability that each element
- * of x is discarded.
- * @param noiseShape An array of numbers of type int32, representing the
- * shape for randomly generated keep/drop flags. If the noiseShape has null
- * value, it will be automatically replaced with the x's relative dimension
- * size. Optional.
- * @param seed Used to create random seeds. Optional.
- * @returns A Tensor of the same shape of x.
- *
- * @doc {heading: 'Operations', subheading: 'Dropout'}
- */
- function dropout_(x, rate, noiseShape, seed) {
- const $x = convertToTensor(x, 'x', 'dropout');
- assert($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` +
- `scaled, but got a ${$x.dtype} tensor instead.`);
- assert(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);
- if (rate === 0) {
- return x instanceof Tensor ? $x.clone() : $x;
- }
- const $noiseShape = getNoiseShape($x, noiseShape);
- const keepProb = 1 - rate;
- const multiplier = div(floor(add$1(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb);
- return mul($x, multiplier);
- }
- const dropout = op({ dropout_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function enclosingPowerOfTwo(value) {
- // Return 2**N for integer N such that 2**N >= value.
- return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0))));
- }
- function cosineWindow(windowLength, a, b) {
- const even = 1 - windowLength % 2;
- const newValues = new Float32Array(windowLength);
- for (let i = 0; i < windowLength; ++i) {
- const cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1);
- newValues[i] = a - b * Math.cos(cosArg);
- }
- return tensor1d(newValues, 'float32');
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns whether the targets are in the top K predictions.
- *
- * ```js
- * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
- * const targets = tf.tensor1d([2, 0]);
- * const precision = await tf.inTopKAsync(predictions, targets);
- * precision.print();
- * ```
- * @param predictions 2-D or higher `tf.Tensor` with last dimension being
- * at least `k`.
- * @param targets 1-D or higher `tf.Tensor`.
- * @param k Optional Number of top elements to look at for computing precision,
- * default to 1.
- *
- * @doc {heading: 'Operations', subheading: 'Evaluation'}
- */
- async function inTopKAsync_(predictions, targets, k = 1) {
- const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
- const $targets = convertToTensor(targets, 'targets', 'inTopK');
- assert($predictions.rank > 1, () => 'inTopK() expects the predictions to be of rank 2 or higher, ' +
- `but got ${$predictions.rank}`);
- assert($predictions.rank - 1 === $targets.rank, () => `predictions rank should be 1 larger than ` +
- `targets rank, but got predictions rank ` +
- `${$predictions.rank} and targets rank ${$targets.rank}`);
- assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, `predictions's shape should be align with the targets' shape, ` +
- 'except the last dimension.');
- const lastDim = $predictions.shape[$predictions.shape.length - 1];
- assert(k > 0 && k <= lastDim, () => `'k' passed to inTopK() must be > 0 && <= the predictions last ` +
- `dimension (${lastDim}), but got ${k}`);
- const predictionsVals = await $predictions.data();
- const targetsVals = await $targets.data();
- // Reshape predictionsVals into a 2d tensor [batch, lastDim]
- // and look up topK along lastDim.
- const [batch, size] = [predictionsVals.length / lastDim, lastDim];
- const precision = getTypedArrayFromDType('bool', batch);
- for (let b = 0; b < batch; b++) {
- const offset = b * size;
- const vals = predictionsVals.subarray(offset, offset + size);
- const valAndInd = [];
- for (let i = 0; i < vals.length; i++) {
- valAndInd.push({ value: vals[i], index: i });
- }
- valAndInd.sort((a, b) => b.value - a.value);
- precision[b] = 0;
- for (let i = 0; i < k; i++) {
- if (valAndInd[i].index === targetsVals[b]) {
- precision[b] = 1;
- break;
- }
- }
- }
- if (predictions !== $predictions) {
- $predictions.dispose();
- }
- if (targets !== $targets) {
- $targets.dispose();
- }
- // Output precision has the same shape as targets.
- return tensor(precision, $targets.shape, 'bool');
- }
- const inTopKAsync = inTopKAsync_;
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the derivative of the filter of a 2D convolution.
- *
- * @param x The input tensor, of rank 4 or rank 3 of shape
- * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
- * @param dy The dy image, of rank 4 or rank 3, of shape
- * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
- * @param filterShape The shape of the filter, length 4,
- * [filterHeight, filterWidth, inDepth, outDepth].
- * @param strides The strides of the convolution: [strideHeight,
- * strideWidth].
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels].
- * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
- * rounding mode used when computing output dimensions if pad is a
- * number. If none is provided, it will not round and error if the output
- * is of fractional size.
- */
- function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
- let x4D = x;
- if (x.rank === 3) {
- x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
- }
- let dy4D = dy;
- if (dy4D.rank === 3) {
- dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
- `${x4D.shape}.`);
- assert(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
- `${dy4D.shape}.`);
- assert(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
- `${filterShape}.`);
- const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
- const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
- assert(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input ${inDepth}) must ` +
- `match input depth in filter (${filterShape[2]}.`);
- assert(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
- `match output depth for filter (${filterShape[3]}).`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in conv2dDerFilter: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = backend => {
- const dilations = 1;
- const $dataFormat = convertConv2DDataFormat(dataFormat);
- const convInfo = computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
- return backend.conv2dDerFilter(x4D, dy4D, convInfo);
- };
- const inputs = { x: x4D, dy: dy4D };
- const attrs = { strides, pad, dataFormat, dimRoundingMode };
- return ENGINE.runKernelFunc(forward, inputs, null, Conv2DBackpropFilter, attrs);
- }
- const conv2DBackpropFilter = op({ conv2DBackpropFilter_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Returns gradient for fused activation.
- function getFusedDyActivation(dy, y, activation) {
- if (activation == null || activation === 'linear') {
- return dy;
- }
- if (activation === 'relu') {
- return mul(dy, step(y));
- }
- throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
- }
- // Returns gradient for fused bias.
- function getFusedBiasGradient(bias, dyActivation) {
- let res = dyActivation;
- const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, bias.shape);
- }
- function applyActivation(x, activation, preluActivationWeights) {
- if (activation === 'linear') {
- return x;
- }
- else if (activation === 'relu') {
- return relu(x);
- }
- else if (activation === 'elu') {
- return elu(x);
- }
- else if (activation === 'relu6') {
- return relu6(x);
- }
- else if (activation === 'prelu') {
- return prelu(x, preluActivationWeights);
- }
- throw new Error(`Unknown fused activation ${activation}.`);
- }
- // Whether we should call fused ops.
- const shouldFuse = (gradientDepth, activation) => {
- const gradientMode = gradientDepth > 0;
- return !gradientMode || activation === 'linear';
- };
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes a 2D convolution over the input x, optionally fused with adding a
- * bias and applying an activation.
- *
- * ```js
- * const inputDepth = 2;
- * const inShape = [2, 2, 2, inputDepth];
- * const outputDepth = 2;
- * const fSize = 1;
- * const pad = 0;
- * const strides = 1;
- *
- * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- * 16], inShape);
- * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
- * outputDepth]);
- *
- * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
- * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
- * ```
- *
- * @param obj An object with the following properties:
- * @param x The input tensor, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
- * assumed.
- * @param filter The filter, rank 4, of shape
- * `[filterHeight, filterWidth, inDepth, outDepth]`.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid` output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels]. Only "NHWC" is currently supported.
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- * @param bias Tensor to be added to the result.
- * @param activation Name of activation kernel (defaults to `linear`) to be
- * applied
- * after biasAdd.
- * @param preluActivationWeights Tensor of prelu weights to be applied as part
- * of a `prelu` activation, typically the same shape as `x`.
- */
- function fusedConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights }) {
- activation = activation || 'linear';
- if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
- let result = conv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- if (bias != null) {
- result = add$1(result, bias);
- }
- return applyActivation(result, activation, preluActivationWeights);
- }
- const $x = convertToTensor(x, 'x', 'conv2d');
- const $filter = convertToTensor(filter, 'filter', 'conv2d');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in fused conv2d: input must be rank 4, but got rank ` +
- `${x4D.rank}.`);
- assert($filter.rank === 4, () => `Error in fused conv2d: filter must be rank 4, but got rank ` +
- `${$filter.rank}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in fused conv2d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- assert(x4D.shape[3] === $filter.shape[2], () => `Error in conv2d: depth of input (${x4D.shape[3]}) must match ` +
- `input depth for filter ${$filter.shape[2]}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- assert(dataFormat === 'NHWC', () => `Error in conv2d: got dataFormat of ${dataFormat} but only NHWC is currently supported.`);
- const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
- let $bias;
- if (bias != null) {
- $bias = convertToTensor(bias, 'bias', 'fused conv2d');
- [$bias] = makeTypesMatch($bias, $x);
- assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
- }
- let $preluActivationWeights;
- if (preluActivationWeights != null) {
- $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
- }
- const grad = (dy, saved) => {
- const [$filter, x4D, y, $bias] = saved;
- const dyActivation = getFusedDyActivation(dy, y, activation);
- assert(tupleValuesAreOne(dilations), () => 'Error in gradient of fused conv2D: ' +
- `dilation rates greater than 1 ` +
- `are not yet supported in gradients. Got dilations '${dilations}'`);
- const xDer = conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);
- const filterDer = conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);
- const der = [xDer, filterDer];
- if ($bias != null) {
- const biasDer = getFusedBiasGradient($bias, dyActivation);
- der.push(biasDer);
- }
- return der;
- };
- const forward = (backend) => {
- const res = backend.fusedConv2d({
- input: x4D,
- filter: $filter,
- convInfo,
- bias: $bias,
- activation,
- preluActivationWeights: $preluActivationWeights
- });
- return res;
- };
- const inputs = {
- x: x4D,
- filter: $filter,
- bias: $bias,
- preluActivationWeights: $preluActivationWeights
- };
- const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode, activation };
- // Depending on the the params passed in we will have different number of
- // inputs and thus a a different number of elements in the gradient.
- if (bias == null) {
- const customOp = customGrad((x4D, filter, save) => {
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, FusedConv2D, attrs);
- save([filter, x4D, res]);
- if (reshapedTo4D) {
- res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return { value: res, gradFunc: grad };
- });
- return customOp(x4D, $filter);
- }
- else {
- const customOpWithBias = customGrad((x4D, filter, bias, save) => {
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, FusedConv2D, attrs);
- save([filter, x4D, res, bias]);
- if (reshapedTo4D) {
- res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return { value: res, gradFunc: grad };
- });
- return customOpWithBias(x4D, $filter, $bias);
- }
- }
- const conv2d$1 = op({ fusedConv2d_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, convInfo) {
- let x4D = x;
- if (x.rank === 3) {
- x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
- }
- let dy4D = dy;
- if (dy4D.rank === 3) {
- dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
- }
- const forward = backend => backend.depthwiseConv2DDerFilter(x4D, dy4D, convInfo);
- const inputs = { x: x4D, dy: dy4D };
- return ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNativeBackpropFilter);
- }
- const depthwiseConv2dNativeBackpropFilter = op({ depthwiseConv2dNativeBackpropFilter_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, convInfo) {
- let dy4D = dy;
- let reshapedTo4D = false;
- if (dy.rank === 3) {
- reshapedTo4D = true;
- dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
- }
- const forward = backend => backend.depthwiseConv2DDerInput(dy4D, filter, convInfo);
- const inputs = { dy: dy4D };
- const res = ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNativeBackpropInput);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const depthwiseConv2dNativeBackpropInput = op({ depthwiseConv2dNativeBackpropInput_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes depthwise 2D convolution, optionally fused with adding a
- * bias and applying an activation.
- *
- * Given a 4D `input` array and a `filter` array of shape
- * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
- * `inChannels` convolutional filters of depth 1, this op applies a
- * different filter to each input channel (expanding from 1 channel to
- * `channelMultiplier` channels for each), then concatenates the results
- * together. The output has `inChannels * channelMultiplier` channels.
- *
- * See
- * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
- * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
- * for more details.
- *
- * @param obj An object with the following properties:
- * @param x The input tensor, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
- * assumed.
- * @param filter The filter tensor, rank 4, of shape
- * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
- * @param strides The strides of the convolution: `[strideHeight,
- * strideWidth]`. If strides is a single number, then `strideHeight ==
- * strideWidth`.
- * @param pad The type of padding algorithm.
- * - `same` and stride 1: output will be of same size as input,
- * regardless of filter size.
- * - `valid`: output will be smaller than input if filter is larger
- * than 1x1.
- * - For more info, see this guide:
- * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
- * https://www.tensorflow.org/api_guides/python/nn#Convolution)
- * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
- * in which we sample input values across the height and width dimensions
- * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
- * number, then `dilationHeight == dilationWidth`. If it is greater than
- * 1, then all values of `strides` must be 1.
- * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
- * "NHWC". Specify the data format of the input and output data. With the
- * default format "NHWC", the data is stored in the order of: [batch,
- * height, width, channels]. Only "NHWC" is currently supported.
- * @param dimRoundingMode The rounding mode used when computing output
- * dimensions if pad is a number. If none is provided, it will not round
- * and error if the output is of fractional size.
- * @param bias Tensor to be added to the result.
- * @param activation Name of activation kernel (defaults to `linear`).
- * @param preluActivationWeights Tensor of prelu weights to be applied as part
- * of a `prelu` activation, typically the same shape as `x`.
- */
- function fusedDepthwiseConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights }) {
- if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
- let result = depthwiseConv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- if (bias != null) {
- result = add$1(result, bias);
- }
- return applyActivation(result, activation, preluActivationWeights);
- }
- const $x = convertToTensor(x, 'x', 'depthwiseConv2d');
- const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
- let x4D = $x;
- let reshapedTo4D = false;
- if ($x.rank === 3) {
- reshapedTo4D = true;
- x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
- }
- assert(x4D.rank === 4, () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +
- `rank ${x4D.rank}.`);
- assert($filter.rank === 4, () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +
- `but got rank ${$filter.rank}.`);
- assert(x4D.shape[3] === $filter.shape[2], () => `Error in fused depthwiseConv2d: number of input channels ` +
- `(${x4D.shape[3]}) must match the inChannels dimension in ` +
- `filter ${$filter.shape[2]}.`);
- if (dilations == null) {
- dilations = [1, 1];
- }
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in fused depthwiseConv2d: Either strides or dilations must ' +
- `be 1. Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in fused depthwiseConv2d: pad must be an integer when ` +
- `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
- let $bias;
- if (bias != null) {
- $bias = convertToTensor(bias, 'bias', 'fused conv2d');
- [$bias] = makeTypesMatch($bias, $x);
- assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
- }
- let $preluActivationWeights;
- if (preluActivationWeights != null) {
- $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
- }
- const grad = (dy, saved) => {
- assert(tupleValuesAreOne(dilations), () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
- `greater than 1 are not yet supported. Got dilations ` +
- `'${dilations}'`);
- const [$filter, x4D, y, bias] = saved;
- const dyActivation = getFusedDyActivation(dy, y, activation);
- const xDer = depthwiseConv2dNativeBackpropInput(x4D.shape, dyActivation, $filter, convInfo);
- const filterDer = depthwiseConv2dNativeBackpropFilter(x4D, dyActivation, $filter.shape, convInfo);
- if (bias != null) {
- const biasDer = getFusedBiasGradient($bias, dyActivation);
- return [xDer, filterDer, biasDer];
- }
- return [xDer, filterDer];
- };
- const forward = (backend) => {
- const res = backend.fusedDepthwiseConv2D({
- input: x4D,
- filter: $filter,
- convInfo,
- bias: $bias,
- activation,
- preluActivationWeights: $preluActivationWeights
- });
- return res;
- };
- const inputs = {
- x: x4D,
- filter: $filter,
- bias: $bias,
- preluActivationWeights: $preluActivationWeights
- };
- const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode, activation };
- // Depending on the the params passed in we will have different number of
- // inputs and thus a a different number of elements in the gradient.
- if (bias == null) {
- const customOp = customGrad((x4D, filter, save) => {
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, FusedDepthwiseConv2D, attrs);
- save([filter, x4D, res]);
- if (reshapedTo4D) {
- res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return { value: res, gradFunc: grad };
- });
- return customOp(x4D, $filter);
- }
- else {
- const customOpWithBias = customGrad((x4D, filter, bias, save) => {
- let res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, FusedDepthwiseConv2D, attrs);
- save([filter, x4D, res, bias]);
- if (reshapedTo4D) {
- res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return { value: res, gradFunc: grad };
- });
- return customOpWithBias(x4D, $filter, $bias);
- }
- }
- const depthwiseConv2d$1 = op({ fusedDepthwiseConv2d_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the dot product of two matrices with optional activation and bias.
- *
- * ```js
- * const a = tf.tensor2d([-1, -2], [1, 2]);
- * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const bias = tf.tensor2d([1, 2], [1, 2]);
- *
- * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
- * ```
- *
- * @param obj An object with the following properties:
- * - `a` First matrix in dot product operation.
- * - `b` Second matrix in dot product operation.
- * - `transposeA` If true, `a` is transposed before multiplication.
- * - `transposeB` If true, `b` is transposed before multiplication.
- * - `bias` Matrix to be added to the result.
- * - `activation` Name of activation kernel (defaults to `linear`).
- * - `preluActivationWeights` Tensor of prelu weights.
- */
- function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights }) {
- if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
- let result = matMul(a, b, transposeA, transposeB);
- if (bias != null) {
- result = add$1(result, bias);
- }
- return applyActivation(result, activation, preluActivationWeights);
- }
- let $a = convertToTensor(a, 'a', 'fused matMul');
- let $b = convertToTensor(b, 'b', 'fused matMul');
- [$a, $b] = makeTypesMatch($a, $b);
- const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
- const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
- const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
- const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
- const outerDimsA = $a.shape.slice(0, -2);
- const outerDimsB = $b.shape.slice(0, -2);
- const batchDimA = sizeFromShape(outerDimsA);
- const batchDimB = sizeFromShape(outerDimsB);
- assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, () => `Error in fused matMul: inputs must have the same rank of at least ` +
- `2, got ranks ${$a.rank} and ${$b.rank}.`);
- assert(arraysEqual(outerDimsA, outerDimsB), () => `Error in fused matMul: outer dimensions (${outerDimsA}) and (` +
- `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` +
- `${$b.shape} must match.`);
- assert(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
- `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
- `${$b.shape} and transposeA=${transposeA}` +
- ` and transposeB=${transposeB} must match.`);
- const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
- const a3D = transposeA ?
- reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
- reshape($a, [batchDimA, outerShapeA, innerShapeA]);
- const b3D = transposeB ?
- reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
- reshape($b, [batchDimB, innerShapeB, outerShapeB]);
- let $bias;
- if (bias != null) {
- $bias = convertToTensor(bias, 'bias', 'fused matMul');
- [$bias] = makeTypesMatch($bias, $a);
- assertAndGetBroadcastShape(outShape, $bias.shape);
- }
- let $preluActivationWeights;
- if (preluActivationWeights != null) {
- $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
- }
- const grad = (dy, saved) => {
- const [a3D, b3D, y, $bias] = saved;
- // we reshape dy because the result of the forward is not
- // necessarily going to be a 3d tensor due to a reshape done at the end of
- // the customOp.
- const dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);
- let aDer;
- let bDer;
- if (!transposeA && !transposeB) {
- aDer = matMul(dyActivation, b3D, false, true);
- bDer = matMul(a3D, dyActivation, true, false);
- }
- else if (!transposeA && transposeB) {
- aDer = matMul(dyActivation, b3D, false, false);
- bDer = matMul(dyActivation, a3D, true, false);
- }
- else if (transposeA && !transposeB) {
- aDer = matMul(b3D, dyActivation, false, true);
- bDer = matMul(a3D, dyActivation, false, false);
- }
- else {
- aDer = matMul(b3D, dyActivation, true, true);
- bDer = matMul(dyActivation, a3D, true, true);
- }
- if (bias != null) {
- const biasDer = getFusedBiasGradient($bias, dyActivation);
- return [aDer, bDer, biasDer];
- }
- else {
- return [aDer, bDer];
- }
- };
- const forward = (backend) => {
- const y = backend.fusedBatchMatMul({
- a: a3D,
- b: b3D,
- transposeA,
- transposeB,
- bias: $bias,
- activation,
- preluActivationWeights: $preluActivationWeights
- });
- return y;
- };
- const inputs = {
- a: a3D,
- b: b3D,
- bias: $bias,
- preluActivationWeights: $preluActivationWeights
- };
- const attrs = { transposeA, transposeB, activation };
- // Depending on the the params passed in we will have different number of
- // inputs and thus a a different number of elements in the gradient.
- if (bias == null) {
- const customOp = customGrad((a3D, b3D, save) => {
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, _FusedMatMul, attrs);
- save([a3D, b3D, res]);
- return { value: reshape(res, outShape), gradFunc: grad };
- });
- return customOp(a3D, b3D);
- }
- else {
- const customOpWithBias = customGrad((a3D, b3D, $bias, save) => {
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, _FusedMatMul, attrs);
- save([a3D, b3D, res, $bias]);
- return { value: reshape(res, outShape), gradFunc: grad };
- });
- return customOpWithBias(a3D, b3D, $bias);
- }
- }
- const matMul$1 = op({ fusedMatMul_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- var fused_ops = /*#__PURE__*/Object.freeze({
- __proto__: null,
- conv2d: conv2d$1,
- depthwiseConv2d: depthwiseConv2d$1,
- matMul: matMul$1
- });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Generate a hamming window.
- *
- * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
- *
- * ```js
- * tf.signal.hammingWindow(10).print();
- * ```
- * @param The length of window
- *
- * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
- */
- function hammingWindow_(windowLength) {
- return cosineWindow(windowLength, 0.54, 0.46);
- }
- const hammingWindow = op({ hammingWindow_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Generate a Hann window.
- *
- * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
- *
- * ```js
- * tf.signal.hannWindow(10).print();
- * ```
- * @param The length of window
- *
- * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
- */
- function hannWindow_(windowLength) {
- return cosineWindow(windowLength, 0.5, 0.5);
- }
- const hannWindow = op({ hannWindow_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Expands input into frames of frameLength.
- * Slides a window size with frameStep.
- *
- * ```js
- * tf.signal.frame([1, 2, 3], 2, 1).print();
- * ```
- * @param signal The input tensor to be expanded
- * @param frameLength Length of each frame
- * @param frameStep The frame hop size in samples.
- * @param padEnd Whether to pad the end of signal with padValue.
- * @param padValue An number to use where the input signal does
- * not exist when padEnd is True.
- *
- * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
- */
- function frame_(signal, frameLength, frameStep, padEnd = false, padValue = 0) {
- let start = 0;
- const output = [];
- while (start + frameLength <= signal.size) {
- output.push(slice(signal, start, frameLength));
- start += frameStep;
- }
- if (padEnd) {
- while (start < signal.size) {
- const padLen = (start + frameLength) - signal.size;
- const pad = concat([
- slice(signal, start, frameLength - padLen), fill([padLen], padValue)
- ]);
- output.push(pad);
- start += frameStep;
- }
- }
- if (output.length === 0) {
- return tensor2d([], [0, frameLength]);
- }
- return reshape(concat(output), [output.length, frameLength]);
- }
- const frame = op({ frame_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the Short-time Fourier Transform of signals
- * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
- *
- * ```js
- * const input = tf.tensor1d([1, 1, 1, 1, 1])
- * tf.signal.stft(input, 3, 1).print();
- * ```
- * @param signal 1-dimensional real value tensor.
- * @param frameLength The window length of samples.
- * @param frameStep The number of samples to step.
- * @param fftLength The size of the FFT to apply.
- * @param windowFn A callable that takes a window length and returns 1-d tensor.
- *
- * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
- */
- function stft_(signal, frameLength, frameStep, fftLength, windowFn = hannWindow) {
- if (fftLength == null) {
- fftLength = enclosingPowerOfTwo(frameLength);
- }
- const framedSignal = frame(signal, frameLength, frameStep);
- const windowedSignal = mul(framedSignal, windowFn(frameLength));
- const output = [];
- for (let i = 0; i < framedSignal.shape[0]; i++) {
- output.push(rfft(slice(windowedSignal, [i, 0], [1, frameLength]), fftLength));
- }
- return concat(output);
- }
- const stft = op({ stft_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Extracts crops from the input image tensor and resizes them using bilinear
- * sampling or nearest neighbor sampling (possibly with aspect ratio change)
- * to a common output size specified by crop_size.
- *
- * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,
- * where imageHeight and imageWidth must be positive, specifying the
- * batch of images from which to take crops
- * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
- * coordinates of the box in the boxInd[i]'th image in the batch
- * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
- * `[0, batch)` that specifies the image that the `i`-th box refers to.
- * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]`
- * specifying the size to which all crops are resized to.
- * @param method Optional string from `'bilinear' | 'nearest'`,
- * defaults to bilinear, which specifies the sampling method for resizing
- * @param extrapolationValue A threshold for deciding when to remove boxes based
- * on score. Defaults to 0.
- * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function cropAndResize_(image, boxes, boxInd, cropSize, method, extrapolationValue) {
- const $image = convertToTensor(image, 'image', 'cropAndResize');
- const $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
- const $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
- method = method || 'bilinear';
- extrapolationValue = extrapolationValue || 0;
- const numBoxes = $boxes.shape[0];
- assert($image.rank === 4, () => 'Error in cropAndResize: image must be rank 4,' +
- `but got rank ${$image.rank}.`);
- assert($boxes.rank === 2 && $boxes.shape[1] === 4, () => `Error in cropAndResize: boxes must be have size [${numBoxes},4] ` +
- `but had shape ${$boxes.shape}.`);
- assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, () => `Error in cropAndResize: boxInd must be have size [${numBoxes}] ` +
- `but had shape ${$boxes.shape}.`);
- assert(cropSize.length === 2, () => `Error in cropAndResize: cropSize must be of length 2, but got ` +
- `length ${cropSize.length}.`);
- assert(cropSize[0] >= 1 && cropSize[1] >= 1, () => `cropSize must be atleast [1,1], but was ${cropSize}`);
- assert(method === 'bilinear' || method === 'nearest', () => `method must be bilinear or nearest, but was ${method}`);
- const forward = (backend) => backend.cropAndResize($image, $boxes, $boxInd, cropSize, method, extrapolationValue);
- const inputs = { image: $image, boxes: $boxes, boxInd: $boxInd };
- const attrs = { method, extrapolationValue, cropSize };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, CropAndResize, attrs);
- return res;
- }
- const cropAndResize = op({ cropAndResize_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Flips the image left to right. Currently available in the CPU, WebGL, and
- * WASM backends.
- *
- * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
- */
- /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
- function flipLeftRight_(image) {
- const $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
- assert($image.rank === 4, () => 'Error in flipLeftRight: image must be rank 4,' +
- `but got rank ${$image.rank}.`);
- const inputs = { image: $image };
- const res = ENGINE.runKernel(FlipLeftRight, inputs, {});
- return res;
- }
- const flipLeftRight = op({ flipLeftRight_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Rotates the input image tensor counter-clockwise with an optional offset
- * center of rotation. Currently available in the CPU, WebGL, and WASM backends.
- *
- * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
- * @param radians The amount of rotation.
- * @param fillValue The value to fill in the empty space leftover
- * after rotation. Can be either a single grayscale value (0-255), or an
- * array of three numbers `[red, green, blue]` specifying the red, green,
- * and blue channels. Defaults to `0` (black).
- * @param center The center of rotation. Can be either a single value (0-1), or
- * an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
- * the image around its center).
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function rotateWithOffset_(image, radians, fillValue = 0, center = 0.5) {
- const $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
- assert($image.rank === 4, () => 'Error in rotateWithOffset: image must be rank 4,' +
- `but got rank ${$image.rank}.`);
- const inputs = { image: $image };
- const attrs = { radians, fillValue, center };
- const res = ENGINE.runKernel(RotateWithOffset, inputs, attrs);
- return res;
- }
- const rotateWithOffset = op({ rotateWithOffset_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
- if (iouThreshold == null) {
- iouThreshold = 0.5;
- }
- if (scoreThreshold == null) {
- scoreThreshold = Number.NEGATIVE_INFINITY;
- }
- if (softNmsSigma == null) {
- softNmsSigma = 0.0;
- }
- const numBoxes = boxes.shape[0];
- maxOutputSize = Math.min(maxOutputSize, numBoxes);
- assert(0 <= iouThreshold && iouThreshold <= 1, () => `iouThreshold must be in [0, 1], but was '${iouThreshold}'`);
- assert(boxes.rank === 2, () => `boxes must be a 2D tensor, but was of rank '${boxes.rank}'`);
- assert(boxes.shape[1] === 4, () => `boxes must have 4 columns, but 2nd dimension was ${boxes.shape[1]}`);
- assert(scores.rank === 1, () => 'scores must be a 1D tensor');
- assert(scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` +
- `but was ${scores.shape[0]}`);
- assert(0 <= softNmsSigma && softNmsSigma <= 1, () => `softNmsSigma must be in [0, 1], but was '${softNmsSigma}'`);
- return { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
- const inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
- maxOutputSize = inputs.maxOutputSize;
- iouThreshold = inputs.iouThreshold;
- scoreThreshold = inputs.scoreThreshold;
- const attrs = { maxOutputSize, iouThreshold, scoreThreshold };
- return ENGINE.runKernelFunc(b => b.nonMaxSuppression($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold), { boxes: $boxes, scores: $scores }, null /* grad */, NonMaxSuppressionV3, attrs);
- }
- const nonMaxSuppression = op({ nonMaxSuppression_ });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Inserts a value into a sorted array. This method allows duplicate, meaning it
- * allows inserting duplicate value, in which case, the element will be inserted
- * at the lowest index of the value.
- * @param arr The array to modify.
- * @param element The element to insert.
- * @param comparator Optional. If no comparator is specified, elements are
- * compared using array_util.defaultComparator, which is suitable for Strings
- * and Numbers in ascending arrays. If the array contains multiple instances of
- * the target value, the left-most instance will be returned. To provide a
- * comparator, it should take 2 arguments to compare and return a negative,
- * zero, or a positive number.
- */
- function binaryInsert(arr, element, comparator) {
- const index = binarySearch(arr, element, comparator);
- const insertionPoint = index < 0 ? -(index + 1) : index;
- arr.splice(insertionPoint, 0, element);
- }
- /**
- * Searches the array for the target using binary search, returns the index
- * of the found element, or position to insert if element not found. If no
- * comparator is specified, elements are compared using array_
- * util.defaultComparator, which is suitable for Strings and Numbers in
- * ascending arrays. If the array contains multiple instances of the target
- * value, the left-most instance will be returned.
- * @param arr The array to be searched in.
- * @param target The target to be searched for.
- * @param comparator Should take 2 arguments to compare and return a negative,
- * zero, or a positive number.
- * @return Lowest index of the target value if found, otherwise the insertion
- * point where the target should be inserted, in the form of
- * (-insertionPoint - 1).
- */
- function binarySearch(arr, target, comparator) {
- return binarySearch_(arr, target, comparator || defaultComparator);
- }
- /**
- * Compares its two arguments for order.
- * @param a The first element to be compared.
- * @param b The second element to be compared.
- * @return A negative number, zero, or a positive number as the first
- * argument is less than, equal to, or greater than the second.
- */
- function defaultComparator(a, b) {
- return a > b ? 1 : a < b ? -1 : 0;
- }
- function binarySearch_(arr, target, comparator) {
- let left = 0;
- let right = arr.length;
- let middle = 0;
- let found = false;
- while (left < right) {
- middle = left + ((right - left) >>> 1);
- const compareResult = comparator(target, arr[middle]);
- if (compareResult > 0) {
- left = middle + 1;
- }
- else {
- right = middle;
- // If compareResult is 0, the value is found. We record it is found,
- // and then keep looking because there may be duplicate.
- found = !compareResult;
- }
- }
- return found ? left : -left - 1;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function nonMaxSuppressionV3Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
- return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */)
- .selectedIndices;
- }
- function nonMaxSuppressionV4Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
- return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize /* padToMaxOutputSize */, true
- /* returnValidOutputs */ );
- }
- function nonMaxSuppressionV5Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
- return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
- }
- function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) {
- // The list is sorted in ascending order, so that we can always pop the
- // candidate with the largest score in O(1) time.
- const candidates = [];
- for (let i = 0; i < scores.length; i++) {
- if (scores[i] > scoreThreshold) {
- candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
- }
- }
- candidates.sort(ascendingComparator);
- // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
- // before.
- const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
- const selectedIndices = [];
- const selectedScores = [];
- while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
- const candidate = candidates.pop();
- const { score: originalScore, boxIndex, suppressBeginIndex } = candidate;
- if (originalScore < scoreThreshold) {
- break;
- }
- // Overlapping boxes are likely to have similar scores, therefore we
- // iterate through the previously selected boxes backwards in order to
- // see if candidate's score should be suppressed. We use
- // suppressBeginIndex to track and ensure a candidate can be suppressed
- // by a selected box no more than once. Also, if the overlap exceeds
- // iouThreshold, we simply ignore the candidate.
- let ignoreCandidate = false;
- for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
- const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
- if (iou >= iouThreshold) {
- ignoreCandidate = true;
- break;
- }
- candidate.score =
- candidate.score * suppressWeight(iouThreshold, scale, iou);
- if (candidate.score <= scoreThreshold) {
- break;
- }
- }
- // At this point, if `candidate.score` has not dropped below
- // `scoreThreshold`, then we know that we went through all of the
- // previous selections and can safely update `suppressBeginIndex` to the
- // end of the selected array. Then we can re-insert the candidate with
- // the updated score and suppressBeginIndex back in the candidate list.
- // If on the other hand, `candidate.score` has dropped below the score
- // threshold, we will not add it back to the candidates list.
- candidate.suppressBeginIndex = selectedIndices.length;
- if (!ignoreCandidate) {
- // Candidate has passed all the tests, and is not suppressed, so
- // select the candidate.
- if (candidate.score === originalScore) {
- selectedIndices.push(boxIndex);
- selectedScores.push(candidate.score);
- }
- else if (candidate.score > scoreThreshold) {
- // Candidate's score is suppressed but is still high enough to be
- // considered, so add back to the candidates list.
- binaryInsert(candidates, candidate, ascendingComparator);
- }
- }
- }
- // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
- const validOutputs = selectedIndices.length;
- const elemsToPad = maxOutputSize - validOutputs;
- if (padToMaxOutputSize && elemsToPad > 0) {
- selectedIndices.push(...new Array(elemsToPad).fill(0));
- selectedScores.push(...new Array(elemsToPad).fill(0.0));
- }
- const result = { selectedIndices: tensor1d(selectedIndices, 'int32') };
- if (returnScoresTensor) {
- result['selectedScores'] = tensor1d(selectedScores, 'float32');
- }
- if (returnValidOutputs) {
- result['validOutputs'] = scalar(validOutputs, 'int32');
- }
- return result;
- }
- function intersectionOverUnion(boxes, i, j) {
- const iCoord = boxes.subarray(i * 4, i * 4 + 4);
- const jCoord = boxes.subarray(j * 4, j * 4 + 4);
- const yminI = Math.min(iCoord[0], iCoord[2]);
- const xminI = Math.min(iCoord[1], iCoord[3]);
- const ymaxI = Math.max(iCoord[0], iCoord[2]);
- const xmaxI = Math.max(iCoord[1], iCoord[3]);
- const yminJ = Math.min(jCoord[0], jCoord[2]);
- const xminJ = Math.min(jCoord[1], jCoord[3]);
- const ymaxJ = Math.max(jCoord[0], jCoord[2]);
- const xmaxJ = Math.max(jCoord[1], jCoord[3]);
- const areaI = (ymaxI - yminI) * (xmaxI - xminI);
- const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
- if (areaI <= 0 || areaJ <= 0) {
- return 0.0;
- }
- const intersectionYmin = Math.max(yminI, yminJ);
- const intersectionXmin = Math.max(xminI, xminJ);
- const intersectionYmax = Math.min(ymaxI, ymaxJ);
- const intersectionXmax = Math.min(xmaxI, xmaxJ);
- const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
- Math.max(intersectionXmax - intersectionXmin, 0.0);
- return intersectionArea / (areaI + areaJ - intersectionArea);
- }
- // A Gaussian penalty function, this method always returns values in [0, 1].
- // The weight is a function of similarity, the more overlap two boxes are, the
- // smaller the weight is, meaning highly overlapping boxe will be significantly
- // penalized. On the other hand, a non-overlapping box will not be penalized.
- function suppressWeight(iouThreshold, scale, iou) {
- const weight = Math.exp(scale * iou * iou);
- return iou <= iouThreshold ? weight : 0.0;
- }
- function ascendingComparator(c1, c2) {
- // For objects with same scores, we make the object with the larger index go
- // first. In an array that pops from the end, this means that the object with
- // the smaller index will be popped first. This ensures the same output as
- // the TensorFlow python version.
- return (c1.score - c2.score) ||
- ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Performs non maximum suppression of bounding boxes based on
- * iou (intersection over union).
- *
- * This is the async version of `nonMaxSuppression`
- *
- * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
- * the bounding box.
- * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
- * @param maxOutputSize The maximum number of boxes to be selected.
- * @param iouThreshold A float representing the threshold for deciding whether
- * boxes overlap too much with respect to IOU. Must be between [0, 1].
- * Defaults to 0.5 (50% box overlap).
- * @param scoreThreshold A threshold for deciding when to remove boxes based
- * on score. Defaults to -inf, which means any score is accepted.
- * @return A 1D tensor with the selected box indices.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- async function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
- const inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
- maxOutputSize = inputs.maxOutputSize;
- iouThreshold = inputs.iouThreshold;
- scoreThreshold = inputs.scoreThreshold;
- const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);
- const boxesVals = boxesAndScores[0];
- const scoresVals = boxesAndScores[1];
- // We call a cpu based impl directly with the typedarray data here rather
- // than a kernel because all kernels are synchronous (and thus cannot await
- // .data()).
- const res = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
- if ($boxes !== boxes) {
- $boxes.dispose();
- }
- if ($scores !== scores) {
- $scores.dispose();
- }
- return res;
- }
- const nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Performs non maximum suppression of bounding boxes based on
- * iou (intersection over union).
- *
- * This op also supports a Soft-NMS mode (c.f.
- * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
- * of other overlapping boxes, therefore favoring different regions of the image
- * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
- * parameter to be larger than 0.
- *
- * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
- * the bounding box.
- * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
- * @param maxOutputSize The maximum number of boxes to be selected.
- * @param iouThreshold A float representing the threshold for deciding whether
- * boxes overlap too much with respect to IOU. Must be between [0, 1].
- * Defaults to 0.5 (50% box overlap).
- * @param scoreThreshold A threshold for deciding when to remove boxes based
- * on score. Defaults to -inf, which means any score is accepted.
- * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
- * When sigma is 0, it falls back to nonMaxSuppression.
- * @return A map with the following properties:
- * - selectedIndices: A 1D tensor with the selected box indices.
- * - selectedScores: A 1D tensor with the corresponding scores for each
- * selected box.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
- const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
- maxOutputSize = params.maxOutputSize;
- iouThreshold = params.iouThreshold;
- scoreThreshold = params.scoreThreshold;
- softNmsSigma = params.softNmsSigma;
- const inputs = { boxes: $boxes, scores: $scores };
- const attrs = { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma };
- const result = ENGINE.runKernel(NonMaxSuppressionV5, inputs, attrs);
- return { selectedIndices: result[0], selectedScores: result[1] };
- }
- const nonMaxSuppressionWithScore = op({ nonMaxSuppressionWithScore_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Asynchronously performs non maximum suppression of bounding boxes based on
- * iou (intersection over union).
- *
- * This op also supports a Soft-NMS mode (c.f.
- * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
- * of other overlapping boxes, therefore favoring different regions of the image
- * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
- * parameter to be larger than 0.
- *
- * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
- * the bounding box.
- * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
- * @param maxOutputSize The maximum number of boxes to be selected.
- * @param iouThreshold A float representing the threshold for deciding whether
- * boxes overlap too much with respect to IOU. Must be between [0, 1].
- * Defaults to 0.5 (50% box overlap).
- * @param scoreThreshold A threshold for deciding when to remove boxes based
- * on score. Defaults to -inf, which means any score is accepted.
- * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
- * When sigma is 0, it falls back to nonMaxSuppression.
- * @return A map with the following properties:
- * - selectedIndices: A 1D tensor with the selected box indices.
- * - selectedScores: A 1D tensor with the corresponding scores for each
- * selected box.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- async function nonMaxSuppressionWithScoreAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
- const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
- maxOutputSize = params.maxOutputSize;
- iouThreshold = params.iouThreshold;
- scoreThreshold = params.scoreThreshold;
- softNmsSigma = params.softNmsSigma;
- const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);
- const boxesVals = boxesAndScores[0];
- const scoresVals = boxesAndScores[1];
- // We call a cpu based impl directly with the typedarray data here rather
- // than a kernel because all kernels are synchronous (and thus cannot await
- // .data()).
- const res = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
- if ($boxes !== boxes) {
- $boxes.dispose();
- }
- if ($scores !== scores) {
- $scores.dispose();
- }
- return res;
- }
- const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Asynchronously performs non maximum suppression of bounding boxes based on
- * iou (intersection over union), with an option to pad results.
- *
- * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
- * the bounding box.
- * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
- * @param maxOutputSize The maximum number of boxes to be selected.
- * @param iouThreshold A float representing the threshold for deciding whether
- * boxes overlap too much with respect to IOU. Must be between [0, 1].
- * Defaults to 0.5 (50% box overlap).
- * @param scoreThreshold A threshold for deciding when to remove boxes based
- * on score. Defaults to -inf, which means any score is accepted.
- * @param padToMaxOutputSize Defalts to false. If true, size of output
- * `selectedIndices` is padded to maxOutputSize.
- * @return A map with the following properties:
- * - selectedIndices: A 1D tensor with the selected box indices.
- * - validOutputs: A scalar denoting how many elements in `selectedIndices`
- * are valid. Valid elements occur first, then padding.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
- const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
- const $maxOutputSize = params.maxOutputSize;
- const $iouThreshold = params.iouThreshold;
- const $scoreThreshold = params.scoreThreshold;
- const inputs = { boxes: $boxes, scores: $scores };
- const attrs = {
- maxOutputSize: $maxOutputSize,
- iouThreshold: $iouThreshold,
- scoreThreshold: $scoreThreshold,
- padToMaxOutputSize
- };
- const result = ENGINE.runKernel(NonMaxSuppressionV4, inputs, attrs);
- return { selectedIndices: result[0], validOutputs: result[1] };
- }
- const nonMaxSuppressionPadded = op({ nonMaxSuppressionPadded_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Asynchronously performs non maximum suppression of bounding boxes based on
- * iou (intersection over union), with an option to pad results.
- *
- * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
- * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
- * the bounding box.
- * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
- * @param maxOutputSize The maximum number of boxes to be selected.
- * @param iouThreshold A float representing the threshold for deciding whether
- * boxes overlap too much with respect to IOU. Must be between [0, 1].
- * Defaults to 0.5 (50% box overlap).
- * @param scoreThreshold A threshold for deciding when to remove boxes based
- * on score. Defaults to -inf, which means any score is accepted.
- * @param padToMaxOutputSize Defalts to false. If true, size of output
- * `selectedIndices` is padded to maxOutputSize.
- * @return A map with the following properties:
- * - selectedIndices: A 1D tensor with the selected box indices.
- * - validOutputs: A scalar denoting how many elements in `selectedIndices`
- * are valid. Valid elements occur first, then padding.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- async function nonMaxSuppressionPaddedAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
- const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
- const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
- const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
- const $maxOutputSize = params.maxOutputSize;
- const $iouThreshold = params.iouThreshold;
- const $scoreThreshold = params.scoreThreshold;
- const [boxesVals, scoresVals] = await Promise.all([$boxes.data(), $scores.data()]);
- // We call a cpu based impl directly with the typedarray data here rather
- // than a kernel because all kernels are synchronous (and thus cannot await
- // .data()).
- const res = nonMaxSuppressionV4Impl(boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold, padToMaxOutputSize);
- if ($boxes !== boxes) {
- $boxes.dispose();
- }
- if ($scores !== scores) {
- $scores.dispose();
- }
- return res;
- }
- const nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Bilinear resize a batch of 3D images to a new shape.
- *
- * @param images The images, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param size The new shape `[newHeight, newWidth]` to resize the
- * images to. Each channel is resized individually.
- * @param alignCorners Defaults to False. If true, rescale
- * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
- * corners of images and resized images. If false, rescale by
- * `new_height / height`. Treat similarly the width dimension.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function resizeBilinear_(images, size, alignCorners = false) {
- const $images = convertToTensor(images, 'images', 'resizeBilinear');
- assert($images.rank === 3 || $images.rank === 4, () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` +
- `rank ${$images.rank}.`);
- assert(size.length === 2, () => `Error in resizeBilinear: new shape must 2D, but got shape ` +
- `${size}.`);
- let batchImages = $images;
- let reshapedTo4D = false;
- if ($images.rank === 3) {
- reshapedTo4D = true;
- batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
- }
- const [newHeight, newWidth] = size;
- const forward = (backend, save) => {
- save([batchImages]);
- return backend.resizeBilinear(batchImages, newHeight, newWidth, alignCorners);
- };
- const inputs = { images: batchImages };
- const attrs = { alignCorners, size };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* gradient */, ResizeBilinear, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const resizeBilinear = op({ resizeBilinear_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * NearestNeighbor resize a batch of 3D images to a new shape.
- *
- * @param images The images, of rank 4 or rank 3, of shape
- * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
- * @param size The new shape `[newHeight, newWidth]` to resize the
- * images to. Each channel is resized individually.
- * @param alignCorners Defaults to False. If true, rescale
- * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
- * corners of images and resized images. If false, rescale by
- * `new_height / height`. Treat similarly the width dimension.
- *
- * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
- */
- function resizeNearestNeighbor_(images, size, alignCorners = false) {
- const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
- assert($images.rank === 3 || $images.rank === 4, () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` +
- `rank ${$images.rank}.`);
- assert(size.length === 2, () => `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` +
- `${size}.`);
- assert($images.dtype === 'float32' || $images.dtype === 'int32', () => '`images` must have `int32` or `float32` as dtype');
- let batchImages = $images;
- let reshapedTo4D = false;
- if ($images.rank === 3) {
- reshapedTo4D = true;
- batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
- }
- const [newHeight, newWidth] = size;
- const inputs = { images: batchImages };
- const attrs = { alignCorners, size };
- const forward = (backend, save) => {
- save([batchImages]);
- return backend.resizeNearestNeighbor(batchImages, newHeight, newWidth, alignCorners);
- };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* gradient */, ResizeNearestNeighbor, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const resizeNearestNeighbor = op({ resizeNearestNeighbor_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Copy a tensor setting everything outside a central band in each innermost
- * matrix to zero.
- *
- * The band part is computed as follows: Assume input has `k` dimensions
- * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
- * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
- * The indicator function
- * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
- * `&& (num_upper < 0 || (n-m) <= num_upper)`
- *
- * ```js
- * const x = tf.tensor2d([[ 0, 1, 2, 3],
- * [-1, 0, 1, 2],
- * [-2, -1, 0, 1],
- * [-3, -2, -1, 0]]);
- * let y = tf.linalg.bandPart(x, 1, -1);
- * y.print(); // [[ 0, 1, 2, 3],
- * // [-1, 0, 1, 2],
- * // [ 0, -1, 0, 1],
- * // [ 0, 0 , -1, 0]]
- * let z = tf.linalg.bandPart(x, 2, 1);
- * z.print(); // [[ 0, 1, 0, 0],
- * // [-1, 0, 1, 0],
- * // [-2, -1, 0, 1],
- * // [ 0, -2, -1, 0]]
- * ```
- *
- * @param x Rank `k` tensor
- * @param numLower Number of subdiagonals to keep.
- * If negative, keep entire lower triangle.
- * @param numUpper Number of subdiagonals to keep.
- * If negative, keep entire upper triangle.
- * @returns Rank `k` tensor of the same shape as input.
- * The extracted banded tensor.
- *
- * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
- */
- function bandPart_(a, numLower, numUpper) {
- assert(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
- assert(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
- const $a = convertToTensor(a, 'a', 'bandPart');
- assert($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
- const shape = $a.shape;
- const [M, N] = $a.shape.slice(-2);
- if (!(numLower <= M)) {
- throw new Error(`bandPart(): numLower (${numLower})` +
- ` must not be greater than the number of rows (${M}).`);
- }
- if (!(numUpper <= N)) {
- throw new Error(`bandPart(): numUpper (${numUpper})` +
- ` must not be greater than the number of columns (${N}).`);
- }
- if (numLower < 0) {
- numLower = M;
- }
- if (numUpper < 0) {
- numUpper = N;
- }
- const i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
- const j = range(0, N, 1, 'int32');
- const ij = sub(i, j);
- const inBand = logicalAnd(lessEqual(ij, scalar(+numLower, 'int32')), greaterEqual(ij, scalar(-numUpper, 'int32')));
- const zero = zeros([M, N], $a.dtype);
- return reshape(stack(unstack(reshape($a, [-1, M, N]))
- .map(mat => where(inBand, mat, zero))), shape);
- }
- const bandPart = op({ bandPart_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gram-Schmidt orthogonalization.
- *
- * ```js
- * const x = tf.tensor2d([[1, 2], [3, 4]]);
- * let y = tf.linalg.gramSchmidt(x);
- * y.print();
- * console.log('Othogonalized:');
- * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
- * console.log('First row direction maintained:');
- * const data = await y.array();
- * console.log(data[0][1] / data[0][0]); // should be nearly 2.
- * ```
- *
- * @param xs The vectors to be orthogonalized, in one of the two following
- * formats:
- * - An Array of `tf.Tensor1D`.
- * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
- * of `xs`.
- * In each case, all the vectors must have the same length and the length
- * must be greater than or equal to the number of vectors.
- * @returns The orthogonalized and normalized vectors or matrix.
- * Orthogonalization means that the vectors or the rows of the matrix
- * are orthogonal (zero inner products). Normalization means that each
- * vector or each row of the matrix has an L2 norm that equals `1`.
- *
- * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
- */
- function gramSchmidt_(xs) {
- let inputIsTensor2D;
- if (Array.isArray(xs)) {
- inputIsTensor2D = false;
- assert(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
- 'empty');
- const dim = xs[0].shape[0];
- for (let i = 1; i < xs.length; ++i) {
- assert(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
- `(${xs[i].shape[0]} vs. ${dim})`);
- }
- }
- else {
- inputIsTensor2D = true;
- xs = split(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));
- }
- assert(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` +
- `number of dimensions (${xs[0].shape[0]}).`);
- const ys = [];
- const xs1d = xs;
- for (let i = 0; i < xs.length; ++i) {
- ys.push(ENGINE.tidy(() => {
- let x = xs1d[i];
- if (i > 0) {
- for (let j = 0; j < i; ++j) {
- const proj = mul(sum$1(mul(ys[j], x)), ys[j]);
- x = sub(x, proj);
- }
- }
- return div(x, norm(x, 'euclidean'));
- }));
- }
- if (inputIsTensor2D) {
- return stack(ys, 0);
- }
- else {
- return ys;
- }
- }
- const gramSchmidt = op({ gramSchmidt_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Compute QR decomposition of m-by-n matrix using Householder transformation.
- *
- * Implementation based on
- * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
- * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
- *
- * ```js
- * const a = tf.tensor2d([[1, 2], [3, 4]]);
- * let [q, r] = tf.linalg.qr(a);
- * console.log('Q');
- * q.print();
- * console.log('R');
- * r.print();
- * console.log('Orthogonalized');
- * q.dot(q.transpose()).print() // should be nearly the identity matrix.
- * console.log('Reconstructed');
- * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
- * ```
- *
- * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
- * it has the shape `[..., M, N]`.
- * @param fullMatrices An optional boolean parameter. Defaults to `false`.
- * If `true`, compute full-sized `Q`. If `false` (the default),
- * compute only the leading N columns of `Q` and `R`.
- * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
- * i.e., its columns all have unit norm and are mutually orthogonal.
- * If `M >= N`,
- * If `fullMatrices` is `false` (default),
- * - `Q` has a shape of `[..., M, N]`,
- * - `R` has a shape of `[..., N, N]`.
- * If `fullMatrices` is `true` (default),
- * - `Q` has a shape of `[..., M, M]`,
- * - `R` has a shape of `[..., M, N]`.
- * If `M < N`,
- * - `Q` has a shape of `[..., M, M]`,
- * - `R` has a shape of `[..., M, N]`.
- * @throws If the rank of `x` is less than 2.
- *
- * @doc {heading:'Operations',
- * subheading:'Linear Algebra',
- * namespace:'linalg'}
- */
- function qr_(x, fullMatrices = false) {
- assert(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`);
- if (x.rank === 2) {
- return qr2d(x, fullMatrices);
- }
- else {
- // Rank > 2.
- // TODO(cais): Below we split the input into individual 2D tensors,
- // perform QR decomposition on them and then stack the results back
- // together. We should explore whether this can be parallelized.
- const outerDimsProd = x.shape.slice(0, x.shape.length - 2)
- .reduce((value, prev) => value * prev);
- const x2ds = unstack(reshape(x, [
- outerDimsProd, x.shape[x.shape.length - 2],
- x.shape[x.shape.length - 1]
- ]), 0);
- const q2ds = [];
- const r2ds = [];
- x2ds.forEach(x2d => {
- const [q2d, r2d] = qr2d(x2d, fullMatrices);
- q2ds.push(q2d);
- r2ds.push(r2d);
- });
- const q = reshape(stack(q2ds, 0), x.shape);
- const r = reshape(stack(r2ds, 0), x.shape);
- return [q, r];
- }
- }
- function qr2d(x, fullMatrices = false) {
- return ENGINE.tidy(() => {
- assert(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
- const m = x.shape[0];
- const n = x.shape[1];
- let q = eye(m); // Orthogonal transform so far.
- let r = clone(x); // Transformed matrix so far.
- const one2D = tensor2d([[1]], [1, 1]);
- let w = clone(one2D);
- const iters = m >= n ? n : m;
- for (let j = 0; j < iters; ++j) {
- // This tidy within the for-loop ensures we clean up temporary
- // tensors as soon as they are no longer needed.
- const rTemp = r;
- const wTemp = w;
- const qTemp = q;
- [w, r, q] = ENGINE.tidy(() => {
- // Find H = I - tau * w * w', to put zeros below R(j, j).
- const rjEnd1 = slice(r, [j, j], [m - j, 1]);
- const normX = norm(rjEnd1);
- const rjj = slice(r, [j, j], [1, 1]);
- // The sign() function returns 0 on 0, which causes division by zero.
- const s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
- const u1 = sub(rjj, mul(s, normX));
- const wPre = div(rjEnd1, u1);
- if (wPre.shape[0] === 1) {
- w = clone(one2D);
- }
- else {
- w = concat([
- one2D,
- slice(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
- ], 0);
- }
- const tau = neg(div(matMul(s, u1), normX));
- // -- R := HR, Q := QH.
- const rjEndAll = slice(r, [j, 0], [m - j, n]);
- const tauTimesW = mul(tau, w);
- const wT = transpose(w);
- if (j === 0) {
- r = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
- }
- else {
- const rTimesTau = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
- r = concat([slice(r, [0, 0], [j, n]), rTimesTau], 0);
- }
- const tawTimesWT = transpose(tauTimesW);
- const qAllJEnd = slice(q, [0, j], [m, q.shape[1] - j]);
- if (j === 0) {
- q = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
- }
- else {
- const qTimesTau = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
- q = concat([slice(q, [0, 0], [m, j]), qTimesTau], 1);
- }
- return [w, r, q];
- });
- dispose([rTemp, wTemp, qTemp]);
- }
- if (!fullMatrices && m > n) {
- q = slice(q, [0, 0], [m, n]);
- r = slice(r, [0, 0], [n, n]);
- }
- return [q, r];
- });
- }
- const qr = op({ qr_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- (function (Reduction) {
- Reduction[Reduction["NONE"] = 0] = "NONE";
- Reduction[Reduction["MEAN"] = 1] = "MEAN";
- Reduction[Reduction["SUM"] = 2] = "SUM";
- Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS";
- })(exports.Reduction || (exports.Reduction = {}));
-
- /**
- * Computes the weighted loss between two tensors.
- *
- * @param losses Tensor of shape `[batch_size, d1, ... dN]`.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `losses`, and must be broadcastable to `losses` (i.e., all
- * dimensions must be either `1`, or the same as the corresponding
- * `losses` dimension).
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function computeWeightedLoss_(losses, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
- }
- const weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);
- if (reduction === exports.Reduction.NONE) {
- return weightedLoss;
- }
- if (reduction === exports.Reduction.SUM) {
- return sum$1(weightedLoss);
- }
- if (reduction === exports.Reduction.MEAN) {
- if ($weights == null) {
- return mean(weightedLoss);
- }
- else {
- const broadcastFactor = $losses.size / $weights.size;
- const result = div(sum$1(weightedLoss), sum$1($weights));
- return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) :
- result;
- }
- }
- if (reduction === exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- if ($weights == null) {
- return div(sum$1(weightedLoss), scalar($losses.size));
- }
- else {
- const broadcastedWeights = mul($weights, ones$1($losses.shape));
- const numNonZeros = cast(sum$1(notEqual(broadcastedWeights, scalar(0))), 'float32');
- return div(sum$1(weightedLoss), numNonZeros);
- }
- }
- throw Error(`Unknown reduction: ${reduction}`);
- }
- const computeWeightedLoss = op({ computeWeightedLoss_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the absolute difference loss between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function absoluteDifference_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
- const $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'absoluteDifference');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: ');
- const losses = abs(sub($labels, $predictions));
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const absoluteDifference = op({ absoluteDifference_ });
-
- /**
- * Computes the cosine distance loss between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param axis The dimension along which the cosine distance is computed.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function cosineDistance_(labels, predictions, axis, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $labels = convertToTensor(labels, 'labels', 'cosineDistance');
- const $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'cosineDistance');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: ');
- const one = scalar(1);
- const losses = sub(one, sum$1(mul($labels, $predictions), axis, true));
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const cosineDistance = op({ cosineDistance_ });
-
- /**
- * Computes the Hinge loss between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function hingeLoss_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- let $labels = convertToTensor(labels, 'labels', 'hingeLoss');
- const $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'hingeLoss');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');
- const one = scalar(1);
- // Convert binary labels to (-1, 1)
- $labels = sub(mul(scalar(2), $labels), one);
- const losses = relu(sub(one, mul($labels, $predictions)));
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const hingeLoss = op({ hingeLoss_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the huber loss between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param delta Point where huber loss changes from quadratic to linear.
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`.
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function huberLoss_(labels, predictions, weights, delta = 1.0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $labels = convertToTensor(labels, 'labels', 'huberLoss');
- const $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'huberLoss');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');
- const deltaScalar = scalar(delta);
- const error = abs(sub($predictions, $labels));
- const quadratic = minimum(error, deltaScalar);
- const linear = sub(error, quadratic);
- const losses = add$1(mul(scalar(0.5), square(quadratic)), mul(deltaScalar, linear));
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const huberLoss = op({ huberLoss_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the log loss between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param epsilon A small increment to avoid taking log of zero
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function logLoss_(labels, predictions, weights, epsilon = 1e-7, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $labels = convertToTensor(labels, 'labels', 'logLoss');
- const $predictions = convertToTensor(predictions, 'predictions', 'logLoss');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'logLoss');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');
- const one = scalar(1);
- const epsilonScalar = scalar(epsilon);
- const l1 = neg(mul($labels, log(add$1($predictions, epsilonScalar))));
- const l2 = mul(sub(one, $labels), log(add$1(sub(one, $predictions), epsilonScalar)));
- const losses = sub(l1, l2);
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const logLoss = op({ logLoss_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the mean squared error between two tensors.
- *
- * @param labels The ground truth output tensor, same dimensions as
- * 'predictions'.
- * @param predictions The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
- */
- function meanSquaredError_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- const $labels = convertToTensor(labels, 'labels', 'meanSquaredError');
- const $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'meanSquaredError');
- }
- assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: ');
- const losses = squaredDifference($labels, $predictions);
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const meanSquaredError = op({ meanSquaredError_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function sigmoidCrossEntropyWithLogits_(labels, logits) {
- const $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
- const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
- assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
- /**
- * Implementation Details:
- *
- * For brevity, let `x = logits`, `z = labels`. The logistic loss is
- * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
- * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
- * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
- * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
- * = (1 - z) * x + log(1 + exp(-x))
- * = x - x * z + log(1 + exp(-x))
- *
- * For x < 0, to avoid overflow in exp(-x), we reformulate the above
- * x - x * z + log(1 + exp(-x))
- * = log(exp(x)) - x * z + log(1 + exp(-x))
- * = - x * z + log(1 + exp(x))
- *
- * Hence, to ensure stability and avoid overflow, the implementation uses
- * this equivalent formulation:
- * max(x, 0) - x * z + log(1 + exp(-abs(x)))
- */
- const maxOutput = relu($logits);
- const outputXTarget = mul($logits, $labels);
- const sigmoidOutput = log1p(exp(neg(abs($logits))));
- return add$1(sub(maxOutput, outputXTarget), sigmoidOutput);
- }
- /**
- * Computes the sigmoid cross entropy loss between two tensors.
- *
- * If labelSmoothing is nonzero, smooth the labels towards 1/2:
- *
- * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
- * + 0.5 * labelSmoothing
- *
- * @param multiClassLabels The ground truth output tensor of shape
- * [batch_size, num_classes], same dimensions as 'predictions'.
- * @param logits The predicted outputs.
- * @param weights Tensor whose rank is either 0, or the same rank as
- * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
- * must be either `1`, or the same as the corresponding `losses`
- * dimension).
- * @param labelSmoothing If greater than 0, then smooth the labels.
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
- */
- function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- let $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
- const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
- }
- assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
- if (labelSmoothing > 0) {
- const labelSmoothingScalar = scalar(labelSmoothing);
- const one = scalar(1);
- const half = scalar(0.5);
- $multiClassLabels =
- add$1(mul($multiClassLabels, sub(one, labelSmoothingScalar)), mul(half, labelSmoothingScalar));
- }
- const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const sigmoidCrossEntropy = op({ sigmoidCrossEntropy_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes softmax cross entropy between logits and labels.
- *
- * Measures the probability error in discrete classification tasks in which
- * the classes are mutually exclusive (each entry is in exactly one class).
- * For example, each CIFAR-10 image is labeled with one and only one label: an
- * image can be a dog or a truck, but not both.
- *
- * `NOTE`: While the classes are mutually exclusive, their probabilities need
- * not be. All that is required is that each row of labels is a valid
- * probability distribution. If they are not, the computation of the gradient
- * will be incorrect.
- *
- * `WARNING`: This op expects unscaled logits, since it performs a softmax on
- * logits internally for efficiency. Do not call this op with the output of
- * softmax, as it will produce incorrect results.
- *
- * logits and labels must have the same shape, e.g. [batch_size, num_classes]
- * and the same dtype.
- * @param labels The labels array.
- * @param logits The logits array.
- * @param dim The dimension softmax would be performed on. Defaults to `-1`
- * which indicates the last dimension.
- */
- function softmaxCrossEntropyWithLogits_(labels, logits, dim = -1) {
- if (dim === -1) {
- dim = logits.rank - 1;
- }
- if (dim !== logits.rank - 1) {
- throw Error(`Softmax cross entropy along a non-last dimension is not yet ` +
- `supported. Labels / logits was rank ${logits.rank} ` +
- `and dim was ${dim}`);
- }
- // Use a custom gradient for numerical stability.
- const customOp = customGrad((labels, logits, save) => {
- // Reference:
- // 1. http://cs231n.github.io/linear-classify/#softmax
- // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
- const keepDims = true;
- const lse = logSumExp(logits, [dim], keepDims);
- const logResult = sub(cast(logits, 'float32'), lse);
- save([labels, logResult]);
- const costVector = neg(mul(logResult, labels));
- const value = sum$1(costVector, [dim]);
- const gradFunc = (dy, saved) => {
- const [labels, logResult] = saved;
- const dyShape = expandShapeToKeepDim(dy.shape, [dim]);
- return [
- mul(reshape(dy, dyShape), sub(cast(labels, 'float32'), exp(logResult))),
- mul(reshape(dy, dyShape), sub(exp(logResult), cast(labels, 'float32'))),
- ];
- };
- return { value, gradFunc };
- });
- return customOp(labels, logits);
- }
- /**
- * Computes the softmax cross entropy loss between two tensors.
- *
- * If labelSmoothing is nonzero, smooth the labels towards 1/2:
- *
- * newOnehotLabels = onehotLabels * (1 - labelSmoothing)
- * + labelSmoothing / numClasses
- *
- * @param onehotLabels One hot encoded labels
- * [batch_size, num_classes], same dimensions as 'predictions'.
- * @param logits The predicted outputs.
- * @param weights Tensor whose rank is either 0, or 1, and must be
- * broadcastable to `loss` of shape [batch_size]
- * @param labelSmoothing If greater than 0, then smooth the labels.
- * @param reduction Type of reduction to apply to loss. Should be of type
- * `Reduction`
- *
- * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
- */
- function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
- let $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
- const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
- let $weights = null;
- if (weights != null) {
- $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
- }
- assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
- if (labelSmoothing > 0) {
- const labelSmoothingScalar = scalar(labelSmoothing);
- const one = scalar(1);
- const numClasses = scalar($onehotLabels.shape[1]);
- $onehotLabels =
- add$1(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
- }
- const losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
- return computeWeightedLoss(losses, $weights, reduction);
- }
- const softmaxCrossEntropy = op({ softmaxCrossEntropy_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const spectral = {
- fft,
- ifft,
- rfft,
- irfft
- };
- const signal = {
- hammingWindow,
- hannWindow,
- frame,
- stft,
- };
- const image = {
- flipLeftRight,
- resizeNearestNeighbor,
- resizeBilinear,
- rotateWithOffset,
- cropAndResize,
- nonMaxSuppression,
- nonMaxSuppressionAsync,
- nonMaxSuppressionWithScore,
- nonMaxSuppressionWithScoreAsync,
- nonMaxSuppressionPadded,
- nonMaxSuppressionPaddedAsync
- };
- const linalg = {
- bandPart,
- gramSchmidt,
- qr
- };
- const losses = {
- absoluteDifference,
- computeWeightedLoss,
- cosineDistance,
- hingeLoss,
- huberLoss,
- logLoss,
- meanSquaredError,
- sigmoidCrossEntropy,
- softmaxCrossEntropy
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
- class Optimizer extends Serializable {
- /**
- * Executes `f()` and minimizes the scalar output of `f()` by computing
- * gradients of y with respect to the list of trainable variables provided by
- * `varList`. If no list is provided, it defaults to all trainable variables.
- *
- * @param f The function to execute and whose output to minimize.
- * @param returnCost Whether to return the scalar cost value produced by
- * executing `f()`.
- * @param varList An optional list of variables to update. If specified, only
- * the trainable variables in varList will be updated by minimize. Defaults to
- * all trainable variables.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers'}
- */
- minimize(f, returnCost = false, varList) {
- const { value, grads } = this.computeGradients(f, varList);
- if (varList != null) {
- const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
- this.applyGradients(gradArray);
- }
- else {
- this.applyGradients(grads);
- }
- // Dispose gradients.
- dispose(grads);
- if (returnCost) {
- return value;
- }
- else {
- value.dispose();
- return null;
- }
- }
- /**
- * The number of iterations that this optimizer instance has been invoked for.
- */
- get iterations() {
- if (this.iterations_ == null) {
- this.iterations_ = 0;
- }
- return this.iterations_;
- }
- incrementIterations() {
- this.iterations_ = this.iterations + 1;
- }
- /**
- * Executes f() and computes the gradient of the scalar output of f() with
- * respect to the list of trainable variables provided by `varList`. If no
- * list is provided, it defaults to all trainable variables.
- *
- * @param f The function to execute and whose output to use for computing
- * gradients with respect to variables.
- * @param varList An optional list of variables to compute gradients with
- * respect to. If specified, only the trainable variables in varList will have
- * gradients computed with respect to. Defaults to all trainable variables.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers'}
- */
- computeGradients(f, varList) {
- return variableGrads(f, varList);
- }
- /**
- * Dispose the variables (if any) owned by this optimizer instance.
- */
- dispose() {
- if (this.iterations_ != null) {
- dispose(this.iterations_);
- }
- }
- async saveIterations() {
- if (this.iterations_ == null) {
- this.iterations_ = 0;
- }
- return {
- name: 'iter',
- // TODO(cais): Use 'int64' type when available.
- tensor: scalar(this.iterations_, 'int32')
- };
- }
- async getWeights() {
- throw new Error('getWeights() is not implemented for this optimizer yet.');
- }
- async setWeights(weightValues) {
- throw new Error(`setWeights() is not implemented for this optimizer class ` +
- `${this.getClassName()}`);
- }
- /**
- * Extract the first element of the weight values and set it
- * as the iterations counter variable of this instance of optimizer.
- *
- * @param weightValues
- * @returns Weight values with the first element consumed and excluded.
- */
- async extractIterations(weightValues) {
- this.iterations_ = (await weightValues[0].tensor.data())[0];
- return weightValues.slice(1);
- }
- }
- Object.defineProperty(Optimizer, Symbol.hasInstance, {
- value: (instance) => {
- return instance.minimize != null && instance.computeGradients != null &&
- instance.applyGradients != null;
- }
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doclink Optimizer */
- class AdadeltaOptimizer extends Optimizer {
- constructor(learningRate, rho, epsilon = null) {
- super();
- this.learningRate = learningRate;
- this.rho = rho;
- this.epsilon = epsilon;
- this.accumulatedGrads = [];
- this.accumulatedUpdates = [];
- if (epsilon == null) {
- this.epsilon = ENGINE.backend.epsilon();
- }
- }
- applyGradients(variableGradients) {
- const variableNames = Array.isArray(variableGradients) ?
- variableGradients.map(item => item.name) :
- Object.keys(variableGradients);
- variableNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- const trainable = false;
- if (this.accumulatedGrads[i] == null) {
- this.accumulatedGrads[i] = {
- originalName: `${name}/accum_grad`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- if (this.accumulatedUpdates[i] == null) {
- this.accumulatedUpdates[i] = {
- originalName: `${name}/accum_var`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const accumulatedGrad = this.accumulatedGrads[i].variable;
- const accumulatedUpdate = this.accumulatedUpdates[i].variable;
- tidy(() => {
- const newAccumulatedGrad = add$1(mul(accumulatedGrad, this.rho), mul(square(gradient), 1 - this.rho));
- const updates = mul(div(sqrt(add$1(accumulatedUpdate, this.epsilon)), sqrt(add$1(accumulatedGrad, this.epsilon))), gradient);
- const newAccumulatedUpdate = add$1(mul(accumulatedUpdate, this.rho), mul(square(updates), 1 - this.rho));
- accumulatedGrad.assign(newAccumulatedGrad);
- accumulatedUpdate.assign(newAccumulatedUpdate);
- const newValue = add$1(mul(updates, -this.learningRate), value);
- value.assign(newValue);
- });
- });
- this.incrementIterations();
- }
- dispose() {
- if (this.accumulatedUpdates != null) {
- dispose(this.accumulatedGrads.map(v => v.variable));
- dispose(this.accumulatedUpdates.map(v => v.variable));
- }
- }
- async getWeights() {
- // Order matters for Python compatibility.
- const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates];
- return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- const variableCount = weightValues.length / 2;
- const trainable = false;
- this.accumulatedGrads =
- weightValues.slice(0, variableCount).map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- this.accumulatedUpdates =
- weightValues.slice(variableCount, variableCount * 2)
- .map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'rho': this.rho,
- 'epsilon': this.epsilon
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['rho'], config['epsilon']);
- }
- }
- /** @nocollapse */
- AdadeltaOptimizer.className = 'Adadelta'; // Name matters for Python compatibility.
- registerClass(AdadeltaOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doclink Optimizer */
- class AdagradOptimizer extends Optimizer {
- constructor(learningRate, initialAccumulatorValue = 0.1) {
- super();
- this.learningRate = learningRate;
- this.initialAccumulatorValue = initialAccumulatorValue;
- this.accumulatedGrads = [];
- }
- applyGradients(variableGradients) {
- const variableNames = Array.isArray(variableGradients) ?
- variableGradients.map(item => item.name) :
- Object.keys(variableGradients);
- variableNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- if (this.accumulatedGrads[i] == null) {
- const trainable = false;
- this.accumulatedGrads[i] = {
- originalName: `${name}/accumulator`,
- variable: tidy(() => fill(value.shape, this.initialAccumulatorValue)
- .variable(trainable))
- };
- }
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const accumulatedGrad = this.accumulatedGrads[i].variable;
- tidy(() => {
- const newAccumulatedGrad = add$1(accumulatedGrad, square(gradient));
- accumulatedGrad.assign(newAccumulatedGrad);
- const newValue = add$1(mul(div(gradient, sqrt(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value);
- value.assign(newValue);
- });
- });
- this.incrementIterations();
- }
- dispose() {
- if (this.accumulatedGrads != null) {
- dispose(this.accumulatedGrads.map(v => v.variable));
- }
- }
- async getWeights() {
- // Order matters for Python compatibility.
- return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable })));
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- const trainable = false;
- this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'initialAccumulatorValue': this.initialAccumulatorValue,
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['initialAccumulatorValue']);
- }
- }
- /** @nocollapse */
- AdagradOptimizer.className = 'Adagrad'; // Note: Name matters for Python compatibility.
- registerClass(AdagradOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class AdamOptimizer extends Optimizer {
- constructor(learningRate, beta1, beta2, epsilon = null) {
- super();
- this.learningRate = learningRate;
- this.beta1 = beta1;
- this.beta2 = beta2;
- this.epsilon = epsilon;
- this.accumulatedFirstMoment = [];
- this.accumulatedSecondMoment = [];
- tidy(() => {
- // accB* will be updated by batch.
- this.accBeta1 = scalar(beta1).variable();
- this.accBeta2 = scalar(beta2).variable();
- });
- if (epsilon == null) {
- this.epsilon = ENGINE.backend.epsilon();
- }
- }
- applyGradients(variableGradients) {
- const varNames = Array.isArray(variableGradients) ?
- variableGradients.map(v => v.name) :
- Object.keys(variableGradients);
- tidy(() => {
- const oneMinusAccBeta1 = sub(1, this.accBeta1);
- const oneMinusAccBeta2 = sub(1, this.accBeta2);
- varNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- const trainable = false;
- if (this.accumulatedFirstMoment[i] == null) {
- this.accumulatedFirstMoment[i] = {
- originalName: `${name}/m`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- if (this.accumulatedSecondMoment[i] == null) {
- this.accumulatedSecondMoment[i] = {
- originalName: `${name}/v`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const firstMoment = this.accumulatedFirstMoment[i].variable;
- const secondMoment = this.accumulatedSecondMoment[i].variable;
- const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
- const newSecondMoment = add$1(mul(secondMoment, this.beta2), mul(square(gradient), 1 - this.beta2));
- const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
- const biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2);
- firstMoment.assign(newFirstMoment);
- secondMoment.assign(newSecondMoment);
- const newValue = add$1(mul(div(biasCorrectedFirstMoment, add$1(sqrt(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value);
- value.assign(newValue);
- });
- this.accBeta1.assign(mul(this.accBeta1, this.beta1));
- this.accBeta2.assign(mul(this.accBeta2, this.beta2));
- });
- this.incrementIterations();
- }
- dispose() {
- this.accBeta1.dispose();
- this.accBeta2.dispose();
- if (this.accumulatedFirstMoment != null) {
- dispose(this.accumulatedFirstMoment.map(v => v.variable));
- }
- if (this.accumulatedSecondMoment != null) {
- dispose(this.accumulatedSecondMoment.map(v => v.variable));
- }
- }
- async getWeights() {
- // Order matters for Python compatibility.
- const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
- return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- tidy(() => {
- this.accBeta1.assign(pow(this.beta1, this.iterations_ + 1));
- this.accBeta2.assign(pow(this.beta2, this.iterations_ + 1));
- });
- const variableCount = weightValues.length / 2;
- const trainable = false;
- this.accumulatedFirstMoment =
- weightValues.slice(0, variableCount).map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- this.accumulatedSecondMoment =
- weightValues.slice(variableCount, variableCount * 2)
- .map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'beta1': this.beta1,
- 'beta2': this.beta2,
- 'epsilon': this.epsilon,
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
- }
- }
- /** @nocollapse */
- AdamOptimizer.className = 'Adam'; // Note: Name matters for Python compatibility.
- registerClass(AdamOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class AdamaxOptimizer extends Optimizer {
- constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) {
- super();
- this.learningRate = learningRate;
- this.beta1 = beta1;
- this.beta2 = beta2;
- this.epsilon = epsilon;
- this.decay = decay;
- this.accumulatedFirstMoment = [];
- this.accumulatedWeightedInfNorm = [];
- tidy(() => {
- this.iteration = scalar(0).variable();
- this.accBeta1 = scalar(beta1).variable();
- });
- if (epsilon == null) {
- this.epsilon = ENGINE.backend.epsilon();
- }
- }
- applyGradients(variableGradients) {
- const variableNames = Array.isArray(variableGradients) ?
- variableGradients.map(item => item.name) :
- Object.keys(variableGradients);
- tidy(() => {
- const oneMinusAccBeta1 = sub(1, this.accBeta1);
- const lr = div(-this.learningRate, add$1(mul(this.iteration, this.decay), 1));
- variableNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- const trainable = false;
- if (this.accumulatedFirstMoment[i] == null) {
- this.accumulatedFirstMoment[i] = {
- originalName: `${name}/m`,
- variable: zerosLike(value).variable(trainable)
- };
- }
- if (this.accumulatedWeightedInfNorm[i] == null) {
- this.accumulatedWeightedInfNorm[i] = {
- originalName: `${name}/v`,
- variable: zerosLike(value).variable(trainable)
- };
- }
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const firstMoment = this.accumulatedFirstMoment[i].variable;
- const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable;
- const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
- const ut0 = mul(weightedInfNorm, this.beta2);
- const ut1 = abs(gradient);
- const newWeightedInfNorm = maximum(ut0, ut1);
- firstMoment.assign(newFirstMoment);
- weightedInfNorm.assign(newWeightedInfNorm);
- const newValue = add$1(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add$1(newWeightedInfNorm, this.epsilon))), value);
- value.assign(newValue);
- });
- this.iteration.assign(add$1(this.iteration, 1));
- this.accBeta1.assign(mul(this.accBeta1, this.beta1));
- });
- this.incrementIterations();
- }
- dispose() {
- this.accBeta1.dispose();
- this.iteration.dispose();
- if (this.accumulatedFirstMoment != null) {
- dispose(this.accumulatedFirstMoment.map(v => v.variable));
- }
- if (this.accumulatedWeightedInfNorm != null) {
- dispose(this.accumulatedWeightedInfNorm.map(v => v.variable));
- }
- }
- async getWeights() {
- throw new Error('getWeights() is not implemented for Adamax yet.');
- }
- async setWeights(weightValues) {
- throw new Error('setWeights() is not implemented for Adamax yet.');
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'beta1': this.beta1,
- 'beta2': this.beta2,
- 'epsilon': this.epsilon,
- 'decay': this.decay
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
- }
- }
- /** @nocollapse */
- AdamaxOptimizer.className = 'Adamax'; // Note: Name matters for Python compatbility.
- registerClass(AdamaxOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doclink Optimizer */
- class SGDOptimizer extends Optimizer {
- constructor(learningRate) {
- super();
- this.learningRate = learningRate;
- this.setLearningRate(learningRate);
- }
- applyGradients(variableGradients) {
- const varNames = Array.isArray(variableGradients) ?
- variableGradients.map(v => v.name) :
- Object.keys(variableGradients);
- varNames.forEach((name, i) => {
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const value = ENGINE.registeredVariables[name];
- tidy(() => {
- const newValue = add$1(mul(this.c, gradient), value);
- value.assign(newValue);
- });
- });
- this.incrementIterations();
- }
- /**
- * Sets the learning rate of the optimizer.
- */
- setLearningRate(learningRate) {
- this.learningRate = learningRate;
- if (this.c != null) {
- this.c.dispose();
- }
- this.c = keep(scalar(-learningRate));
- }
- dispose() {
- this.c.dispose();
- }
- async getWeights() {
- return [await this.saveIterations()];
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- if (weightValues.length !== 0) {
- throw new Error('SGD optimizer does not have settable weights.');
- }
- }
- getConfig() {
- return { 'learningRate': this.learningRate };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate']);
- }
- }
- /** @nocollapse */
- SGDOptimizer.className = 'SGD'; // Note: Name matters for Python compatibility.
- registerClass(SGDOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doclink Optimizer */
- class MomentumOptimizer extends SGDOptimizer {
- constructor(learningRate, momentum, useNesterov = false) {
- super(learningRate);
- this.learningRate = learningRate;
- this.momentum = momentum;
- this.useNesterov = useNesterov;
- this.accumulations = [];
- this.m = scalar(this.momentum);
- }
- applyGradients(variableGradients) {
- const variableNames = Array.isArray(variableGradients) ?
- variableGradients.map(item => item.name) :
- Object.keys(variableGradients);
- variableNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- if (this.accumulations[i] == null) {
- const trainable = false;
- this.accumulations[i] = {
- originalName: `${name}/momentum`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- const accumulation = this.accumulations[i].variable;
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- tidy(() => {
- let newValue;
- const newAccumulation = add$1(mul(this.m, accumulation), gradient);
- if (this.useNesterov) {
- newValue = add$1(mul(this.c, add$1(gradient, mul(newAccumulation, this.m))), value);
- }
- else {
- newValue = add$1(mul(this.c, newAccumulation), value);
- }
- accumulation.assign(newAccumulation);
- value.assign(newValue);
- });
- });
- this.incrementIterations();
- }
- dispose() {
- this.m.dispose();
- if (this.accumulations != null) {
- dispose(this.accumulations.map(v => v.variable));
- }
- }
- /**
- * Sets the momentum of the optimizer.
- *
- * @param momentum
- */
- setMomentum(momentum) {
- this.momentum = momentum;
- }
- async getWeights() {
- // Order matters for Python compatibility.
- return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable })));
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- const trainable = false;
- this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'momentum': this.momentum,
- 'useNesterov': this.useNesterov
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
- }
- }
- /** @nocollapse */
- MomentumOptimizer.className = 'Momentum'; // Name matters for Python compatibility.
- registerClass(MomentumOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** @doclink Optimizer */
- class RMSPropOptimizer extends Optimizer {
- constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) {
- super();
- this.learningRate = learningRate;
- this.decay = decay;
- this.momentum = momentum;
- this.epsilon = epsilon;
- this.accumulatedMeanSquares = [];
- this.accumulatedMoments = [];
- this.accumulatedMeanGrads = [];
- this.centered = centered;
- if (epsilon == null) {
- this.epsilon = ENGINE.backend.epsilon();
- }
- if (learningRate == null) {
- throw new Error(`learningRate for RMSPropOptimizer must be defined.`);
- }
- }
- applyGradients(variableGradients) {
- const variableNames = Array.isArray(variableGradients) ?
- variableGradients.map(item => item.name) :
- Object.keys(variableGradients);
- variableNames.forEach((name, i) => {
- const value = ENGINE.registeredVariables[name];
- const trainable = false;
- if (this.accumulatedMeanSquares[i] == null) {
- this.accumulatedMeanSquares[i] = {
- originalName: `${name}/rms`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- if (this.accumulatedMoments[i] == null) {
- this.accumulatedMoments[i] = {
- originalName: `${name}/momentum`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- if (this.accumulatedMeanGrads[i] == null && this.centered) {
- this.accumulatedMeanGrads[i] = {
- originalName: `${name}/mg`,
- variable: tidy(() => zerosLike(value).variable(trainable))
- };
- }
- const gradient = Array.isArray(variableGradients) ?
- variableGradients[i].tensor :
- variableGradients[name];
- if (gradient == null) {
- return;
- }
- const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable;
- const accumulatedMoments = this.accumulatedMoments[i].variable;
- tidy(() => {
- const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square(gradient), 1 - this.decay));
- if (this.centered) {
- const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable;
- // Centered gradient
- const newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay));
- const gradContribution = div(mul(gradient, this.learningRate), sqrt(sub(newAccumulatedMeanSquare, add$1(square(newAccumulatedMeanGrad), this.epsilon))));
- const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), gradContribution);
- accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
- accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
- accumulatedMoments.assign(newAccumulatedMoments);
- const newValue = sub(value, newAccumulatedMoments);
- value.assign(newValue);
- }
- else {
- // Plain gradient
- const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square(gradient), 1 - this.decay));
- const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), div(mul(gradient, this.learningRate), sqrt(add$1(newAccumulatedMeanSquare, this.epsilon))));
- accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
- accumulatedMoments.assign(newAccumulatedMoments);
- const newValue = sub(value, newAccumulatedMoments);
- value.assign(newValue);
- }
- });
- });
- this.incrementIterations();
- }
- dispose() {
- if (this.accumulatedMeanSquares != null) {
- dispose(this.accumulatedMeanSquares.map(v => v.variable));
- }
- if (this.accumulatedMeanGrads != null && this.centered) {
- dispose(this.accumulatedMeanGrads.map(v => v.variable));
- }
- if (this.accumulatedMoments != null) {
- dispose(this.accumulatedMoments.map(v => v.variable));
- }
- }
- async getWeights() {
- // Order matters for Python compatibility.
- const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
- if (this.centered) {
- variables.push(...this.accumulatedMeanGrads);
- }
- return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
- }
- async setWeights(weightValues) {
- weightValues = await this.extractIterations(weightValues);
- const variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
- const trainable = false;
- this.accumulatedMeanSquares =
- weightValues.slice(0, variableCount).map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- this.accumulatedMoments =
- weightValues.slice(variableCount, variableCount * 2)
- .map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- if (this.centered) {
- this.accumulatedMeanGrads =
- weightValues.slice(variableCount * 2, variableCount * 3)
- .map(v => ({
- originalName: v.name,
- variable: v.tensor.variable(trainable)
- }));
- }
- }
- getConfig() {
- return {
- 'learningRate': this.learningRate,
- 'decay': this.decay,
- 'momentum': this.momentum,
- 'epsilon': this.epsilon,
- 'centered': this.centered
- };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
- }
- }
- /** @nocollapse */
- RMSPropOptimizer.className = 'RMSProp'; // Note: Name matters for Python compatibility.
- registerClass(RMSPropOptimizer);
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class OptimizerConstructors {
- /**
- * Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
- *
- * ```js
- * // Fit a quadratic function by learning the coefficients a, b, c.
- * const xs = tf.tensor1d([0, 1, 2, 3]);
- * const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
- *
- * const a = tf.scalar(Math.random()).variable();
- * const b = tf.scalar(Math.random()).variable();
- * const c = tf.scalar(Math.random()).variable();
- *
- * // y = a * x^2 + b * x + c.
- * const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
- * const loss = (pred, label) => pred.sub(label).square().mean();
- *
- * const learningRate = 0.01;
- * const optimizer = tf.train.sgd(learningRate);
- *
- * // Train the model.
- * for (let i = 0; i < 10; i++) {
- * optimizer.minimize(() => loss(f(xs), ys));
- * }
- *
- * // Make predictions.
- * console.log(
- * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
- * const preds = f(xs).dataSync();
- * preds.forEach((pred, i) => {
- * console.log(`x: ${i}, pred: ${pred}`);
- * });
- * ```
- *
- * @param learningRate The learning rate to use for the SGD algorithm.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static sgd(learningRate) {
- return new SGDOptimizer(learningRate);
- }
- /**
- * Constructs a `tf.MomentumOptimizer` that uses momentum gradient
- * descent.
- *
- * See
- * [http://proceedings.mlr.press/v28/sutskever13.pdf](
- * http://proceedings.mlr.press/v28/sutskever13.pdf)
- *
- * @param learningRate The learning rate to use for the Momentum gradient
- * descent algorithm.
- * @param momentum The momentum to use for the momentum gradient descent
- * algorithm.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static momentum(learningRate, momentum, useNesterov = false) {
- return new MomentumOptimizer(learningRate, momentum, useNesterov);
- }
- /**
- * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
- * descent. This implementation uses plain momentum and is not centered
- * version of RMSProp.
- *
- * See
- * [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
- * http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
- *
- * @param learningRate The learning rate to use for the RMSProp gradient
- * descent algorithm.
- * @param decay The discounting factor for the history/coming gradient.
- * @param momentum The momentum to use for the RMSProp gradient descent
- * algorithm.
- * @param epsilon Small value to avoid zero denominator.
- * @param centered If true, gradients are normalized by the estimated
- * variance of the gradient.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) {
- return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
- }
- /**
- * Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
- * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
- *
- * @param learningRate The learning rate to use for the Adam gradient
- * descent algorithm.
- * @param beta1 The exponential decay rate for the 1st moment estimates.
- * @param beta2 The exponential decay rate for the 2nd moment estimates.
- * @param epsilon A small constant for numerical stability.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) {
- return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
- }
- /**
- * Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
- * See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
- *
- * @param learningRate The learning rate to use for the Adadelta gradient
- * descent algorithm.
- * @param rho The learning rate decay over each update.
- * @param epsilon A constant epsilon used to better condition the grad
- * update.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static adadelta(learningRate = .001, rho = .95, epsilon = null) {
- return new AdadeltaOptimizer(learningRate, rho, epsilon);
- }
- /**
- * Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
- * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
- *
- * @param learningRate The learning rate to use for the Adamax gradient
- * descent algorithm.
- * @param beta1 The exponential decay rate for the 1st moment estimates.
- * @param beta2 The exponential decay rate for the 2nd moment estimates.
- * @param epsilon A small constant for numerical stability.
- * @param decay The learning rate decay over each update.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) {
- return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
- }
- /**
- * Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
- * See
- * [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
- * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
- * or
- * [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
- * http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
- *
- * @param learningRate The learning rate to use for the Adagrad gradient
- * descent algorithm.
- * @param initialAccumulatorValue Starting value for the accumulators, must be
- * positive.
- *
- * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
- */
- static adagrad(learningRate, initialAccumulatorValue = 0.1) {
- return new AdagradOptimizer(learningRate, initialAccumulatorValue);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // tslint:disable-next-line:no-unused-expression
- [MomentumOptimizer, SGDOptimizer, AdadeltaOptimizer, AdagradOptimizer,
- RMSPropOptimizer, AdamaxOptimizer, AdamOptimizer];
- const train = {
- sgd: OptimizerConstructors.sgd,
- momentum: OptimizerConstructors.momentum,
- adadelta: OptimizerConstructors.adadelta,
- adagrad: OptimizerConstructors.adagrad,
- rmsprop: OptimizerConstructors.rmsprop,
- adamax: OptimizerConstructors.adamax,
- adam: OptimizerConstructors.adam
- };
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const delayCallback = (() => {
- if (typeof requestAnimationFrame !== 'undefined') {
- return requestAnimationFrame;
- }
- else if (typeof setImmediate !== 'undefined') {
- return setImmediate;
- }
- return (f) => f(); // no delays
- })();
- /**
- * Returns a promise that resolve when a requestAnimationFrame has completed.
- *
- * On Node.js this uses setImmediate instead of requestAnimationFrame.
- *
- * This is simply a sugar method so that users can do the following:
- * `await tf.nextFrame();`
- *
- * @doc {heading: 'Performance', subheading: 'Timing'}
- */
- function nextFrame() {
- return new Promise(resolve => delayCallback(() => resolve()));
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Returns the image center in pixels.
- function getImageCenter(center, imageHeight, imageWidth) {
- const centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
- const centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
- return [centerX, centerY];
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gets the new shape of the input Tensor after it's been reshaped
- * to:
- * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
- * inputShape[1], ..., inputShape[N-1]]
- *
- * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getReshaped(inputShape, blockShape, prod, batchToSpace = true) {
- let reshaped = [];
- if (batchToSpace) {
- reshaped = reshaped.concat(blockShape.slice(0));
- reshaped.push(inputShape[0] / prod);
- reshaped = reshaped.concat(inputShape.slice(1));
- }
- else {
- reshaped = reshaped.concat(inputShape[0]);
- const spatialLength = blockShape.length;
- for (let i = 0; i < spatialLength; ++i) {
- reshaped =
- reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
- }
- reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
- }
- return reshaped;
- }
- /**
- * Gets the permutation that will transpose the dimensions of the
- * reshaped tensor to shape:
- *
- * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
- * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) {
- const permuted = [];
- if (batchToSpace) {
- permuted.push(blockShapeRank);
- for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
- if (i <= 2 * blockShapeRank) {
- permuted.push(i);
- permuted.push(i - (blockShapeRank + 1));
- }
- else {
- permuted.push(i);
- }
- }
- }
- else {
- const permutedBeforeBatch = [];
- const permutedAfterBatch = [];
- for (let i = 1; i < reshapedRank; ++i) {
- if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
- permutedAfterBatch.push(i);
- }
- else {
- permutedBeforeBatch.push(i);
- }
- }
- permuted.push(...permutedBeforeBatch);
- permuted.push(0);
- permuted.push(...permutedAfterBatch);
- }
- return permuted;
- }
- /**
- * Gets the shape of the reshaped and permuted input Tensor before any cropping
- * is applied. The new shape will be:
- *
- * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
- * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) {
- const reshapedPermuted = [];
- if (batchToSpace) {
- reshapedPermuted.push(inputShape[0] / prod);
- }
- else {
- reshapedPermuted.push(inputShape[0] * prod);
- }
- for (let i = 1; i < inputShape.length; ++i) {
- if (i <= blockShape.length) {
- if (batchToSpace) {
- reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
- }
- else {
- reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
- }
- }
- else {
- reshapedPermuted.push(inputShape[i]);
- }
- }
- return reshapedPermuted;
- }
- /**
- * Converts the crops argument into the beginning coordinates of a slice
- * operation.
- */
- function getSliceBeginCoords(crops, blockShape) {
- const sliceBeginCoords = [0];
- for (let i = 0; i < blockShape; ++i) {
- sliceBeginCoords.push(crops[i][0]);
- }
- return sliceBeginCoords;
- }
- /**
- * Converts the crops argument into the size of a slice operation. When
- * combined with getSliceBeginCoords this function allows the reshaped and
- * permuted Tensor to be cropped to its final output shape of:
- *
- * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
- * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
- * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getSliceSize(uncroppedShape, crops, blockShape) {
- const sliceSize = uncroppedShape.slice(0, 1);
- for (let i = 0; i < blockShape; ++i) {
- sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
- }
- return sliceSize;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
- const SELU_SCALE = 1.0507009873554804934193349852946;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ERF_P = 0.3275911;
- const ERF_A1 = 0.254829592;
- const ERF_A2 = -0.284496736;
- const ERF_A3 = 1.421413741;
- const ERF_A4 = -1.453152027;
- const ERF_A5 = 1.061405429;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function warn(...msg) {
- if (!env().getBool('IS_TEST')) {
- console.warn(...msg);
- }
- }
- function log$1(...msg) {
- if (!env().getBool('IS_TEST')) {
- console.log(...msg);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Merges real and imaginary Float32Arrays into a single complex Float32Array.
- *
- * The memory layout is interleaved as follows:
- * real: [r0, r1, r2]
- * imag: [i0, i1, i2]
- * complex: [r0, i0, r1, i1, r2, i2]
- *
- * This is the inverse of splitRealAndImagArrays.
- *
- * @param real The real values of the complex tensor values.
- * @param imag The imag values of the complex tensor values.
- * @returns A complex tensor as a Float32Array with merged values.
- */
- function mergeRealAndImagArrays(real, imag) {
- if (real.length !== imag.length) {
- throw new Error(`Cannot merge real and imag arrays of different lengths. real:` +
- `${real.length}, imag: ${imag.length}.`);
- }
- const result = new Float32Array(real.length * 2);
- for (let i = 0; i < result.length; i += 2) {
- result[i] = real[i / 2];
- result[i + 1] = imag[i / 2];
- }
- return result;
- }
- /**
- * Splits a complex Float32Array into real and imag parts.
- *
- * The memory layout is interleaved as follows:
- * complex: [r0, i0, r1, i1, r2, i2]
- * real: [r0, r1, r2]
- * imag: [i0, i1, i2]
- *
- * This is the inverse of mergeRealAndImagArrays.
- *
- * @param complex The complex tensor values.
- * @returns An object with real and imag Float32Array components of the complex
- * tensor.
- */
- function splitRealAndImagArrays(complex) {
- const real = new Float32Array(complex.length / 2);
- const imag = new Float32Array(complex.length / 2);
- for (let i = 0; i < complex.length; i += 2) {
- real[i / 2] = complex[i];
- imag[i / 2] = complex[i + 1];
- }
- return { real, imag };
- }
- /**
- * Extracts even indexed complex values in the given array.
- * @param complex The complex tensor values
- */
- function complexWithEvenIndex(complex) {
- const len = Math.ceil(complex.length / 4);
- const real = new Float32Array(len);
- const imag = new Float32Array(len);
- for (let i = 0; i < complex.length; i += 4) {
- real[Math.floor(i / 4)] = complex[i];
- imag[Math.floor(i / 4)] = complex[i + 1];
- }
- return { real, imag };
- }
- /**
- * Extracts odd indexed comple values in the given array.
- * @param complex The complex tensor values
- */
- function complexWithOddIndex(complex) {
- const len = Math.floor(complex.length / 4);
- const real = new Float32Array(len);
- const imag = new Float32Array(len);
- for (let i = 2; i < complex.length; i += 4) {
- real[Math.floor(i / 4)] = complex[i];
- imag[Math.floor(i / 4)] = complex[i + 1];
- }
- return { real, imag };
- }
- /**
- * Get the map representing a complex value in the given array.
- * @param complex The complex tensor values.
- * @param index An index of the target complex value.
- */
- function getComplexWithIndex(complex, index) {
- const real = complex[index * 2];
- const imag = complex[index * 2 + 1];
- return { real, imag };
- }
- /**
- * Insert a given complex value into the TypedArray.
- * @param data The array in which the complex value is inserted.
- * @param c The complex value to be inserted.
- * @param index An index of the target complex value.
- */
- function assignToTypedArray(data, real, imag, index) {
- data[index * 2] = real;
- data[index * 2 + 1] = imag;
- }
- /**
- * Make the list of exponent terms used by FFT.
- */
- function exponents(n, inverse) {
- const real = new Float32Array(n / 2);
- const imag = new Float32Array(n / 2);
- for (let i = 0; i < Math.ceil(n / 2); i++) {
- const x = (inverse ? 2 : -2) * Math.PI * (i / n);
- real[i] = Math.cos(x);
- imag[i] = Math.sin(x);
- }
- return { real, imag };
- }
- /**
- * Make the exponent term used by FFT.
- */
- function exponent(k, n, inverse) {
- const x = (inverse ? 2 : -2) * Math.PI * (k / n);
- const real = Math.cos(x);
- const imag = Math.sin(x);
- return { real, imag };
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function castTensor(x, dtype, backend) {
- if (dtype === 'complex64') {
- if (x.dtype === 'complex64') {
- return x.clone();
- }
- const zerosTensor = zeros(x.shape);
- const floatX = cast(x, 'float32');
- const result = backend.complex(floatX, zerosTensor);
- zerosTensor.dispose();
- floatX.dispose();
- return result;
- }
- if (!hasEncodingLoss(x.dtype, dtype)) {
- // We don't change the underlying data, since we cast to higher
- // precision.
- return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype);
- }
- if (x.dtype === 'complex64') {
- const real = backend.real(x);
- const result = cast(real, dtype);
- real.dispose();
- return result;
- }
- if (dtype === 'int32') {
- return backend.int(x);
- }
- else if (dtype === 'bool') {
- const zero = scalar(0, x.dtype);
- const result = backend.notEqual(x, zero);
- zero.dispose();
- return result;
- }
- else {
- throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
- }
- }
- function reshapeTensor(x, shape) {
- return ENGINE.makeTensorFromDataId(x.dataId, shape, x.dtype);
- }
- function linspaceImpl(start, stop, num) {
- const step = (stop - start) / (num - 1);
- const values = makeZerosTypedArray(num, 'float32');
- values[0] = start;
- for (let i = 1; i < values.length; i++) {
- values[i] = values[i - 1] + step;
- }
- return tensor1d(values, 'float32');
- }
-
- var backend_util = /*#__PURE__*/Object.freeze({
- __proto__: null,
- slice_util: slice_util,
- segment_util: segment_util,
- castTensor: castTensor,
- reshapeTensor: reshapeTensor,
- linspaceImpl: linspaceImpl,
- upcastType: upcastType,
- axesAreInnerMostDims: axesAreInnerMostDims,
- combineLocations: combineLocations,
- computeOutAndReduceShapes: computeOutAndReduceShapes,
- expandShapeToKeepDim: expandShapeToKeepDim,
- assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
- getAxesPermutation: getAxesPermutation,
- getUndoAxesPermutation: getUndoAxesPermutation,
- getInnerMostAxes: getInnerMostAxes,
- getBroadcastDims: getBroadcastDims,
- getReductionAxes: getReductionAxes,
- assertAndGetBroadcastShape: assertAndGetBroadcastShape,
- assertParamsConsistent: assertParamsConsistent,
- computeOutShape: computeOutShape$1,
- computeDilation2DInfo: computeDilation2DInfo,
- computePool2DInfo: computePool2DInfo,
- computePool3DInfo: computePool3DInfo,
- computeConv2DInfo: computeConv2DInfo,
- computeConv3DInfo: computeConv3DInfo,
- computeDefaultPad: computeDefaultPad,
- tupleValuesAreOne: tupleValuesAreOne,
- eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
- convertConv2DDataFormat: convertConv2DDataFormat,
- getFusedDyActivation: getFusedDyActivation,
- getFusedBiasGradient: getFusedBiasGradient,
- applyActivation: applyActivation,
- shouldFuse: shouldFuse,
- PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
- computeOptimalWindowSize: computeOptimalWindowSize,
- getImageCenter: getImageCenter,
- getReshaped: getReshaped,
- getPermuted: getPermuted,
- getReshapedPermuted: getReshapedPermuted,
- getSliceBeginCoords: getSliceBeginCoords,
- getSliceSize: getSliceSize,
- prepareAndValidate: prepareAndValidate,
- validateUpdateShape: validateUpdateShape,
- validateInput: validateInput,
- calculateShapes: calculateShapes,
- SELU_SCALEALPHA: SELU_SCALEALPHA,
- SELU_SCALE: SELU_SCALE,
- ERF_P: ERF_P,
- ERF_A1: ERF_A1,
- ERF_A2: ERF_A2,
- ERF_A3: ERF_A3,
- ERF_A4: ERF_A4,
- ERF_A5: ERF_A5,
- warn: warn,
- log: log$1,
- mergeRealAndImagArrays: mergeRealAndImagArrays,
- splitRealAndImagArrays: splitRealAndImagArrays,
- complexWithEvenIndex: complexWithEvenIndex,
- complexWithOddIndex: complexWithOddIndex,
- getComplexWithIndex: getComplexWithIndex,
- assignToTypedArray: assignToTypedArray,
- exponents: exponents,
- exponent: exponent,
- prepareSplitSize: prepareSplitSize
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // TODO(annxingyuan): Use this helper in WASM Split kernel once intermediate
- // kernels have been modularized in WebGL and CPU
- // https://github.com/tensorflow/tfjs/issues/2822.
- /** Shared implementation of the split kernel across WebGL and CPU. */
- function split$1(x, sizeSplits, axis) {
- const begin = new Array(x.rank).fill(0);
- const size = x.shape.slice();
- return sizeSplits.map(s => {
- const sliceSize = [...size];
- sliceSize[axis] = s;
- const sliceT = slice(x, begin, sliceSize);
- begin[axis] += s;
- return sliceT;
- });
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function tile$1(xBuf, reps) {
- const newShape = new Array(xBuf.rank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = xBuf.shape[i] * reps[i];
- }
- const result = buffer(newShape, xBuf.dtype);
- for (let i = 0; i < result.values.length; ++i) {
- const newLoc = result.indexToLoc(i);
- const originalLoc = new Array(xBuf.rank);
- for (let j = 0; j < originalLoc.length; j++) {
- originalLoc[j] = newLoc[j] % xBuf.shape[j];
- }
- const originalIndex = xBuf.locToIndex(originalLoc);
- result.values[i] = xBuf.values[originalIndex];
- }
- return result.toTensor();
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function topkImpl(x, xShape, xDtype, k, sorted) {
- // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
- const lastDim = xShape[xShape.length - 1];
- const [batch, size] = [x.length / lastDim, lastDim];
- const allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
- const allTopKIndices = getTypedArrayFromDType('int32', batch * k);
- for (let b = 0; b < batch; b++) {
- const offset = b * size;
- const vals = x.subarray(offset, offset + size);
- const valAndInd = [];
- for (let i = 0; i < vals.length; i++) {
- valAndInd.push({ value: vals[i], index: i });
- }
- valAndInd.sort((a, b) => b.value - a.value);
- const outOffset = b * k;
- const topKVals = allTopKVals.subarray(outOffset, outOffset + k);
- const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
- for (let i = 0; i < k; i++) {
- topKVals[i] = valAndInd[i].value;
- topKIndices[i] = valAndInd[i].index;
- }
- }
- // Reshape back to the original input shape, except that the last
- // dimension is k.
- const outputShape = xShape.slice();
- outputShape[outputShape.length - 1] = k;
- return [
- tensor(allTopKVals, outputShape, xDtype),
- tensor(allTopKIndices, outputShape, 'int32')
- ];
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- var kernel_impls = /*#__PURE__*/Object.freeze({
- __proto__: null,
- nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl,
- nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl,
- nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl,
- split: split$1,
- tile: tile$1,
- topkImpl: topkImpl,
- whereImpl: whereImpl
- });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const absGradConfig = {
- kernelName: Abs,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(dy, step(cast(x, 'float32'), -1)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const acosGradConfig = {
- kernelName: Acos,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return {
- x: () => {
- const a = square(cast(x, 'float32'));
- const b = sqrt(sub(scalar(1), a));
- return neg(div(dy, b));
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const acoshGradConfig = {
- kernelName: Acosh,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return {
- x: () => {
- const a = sqrt(sub(square(cast(x, 'float32')), 1));
- return div(dy, a);
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const addGradConfig = {
- kernelName: Add,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- let res = dy;
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, a.shape);
- };
- const derB = () => {
- let res = dy;
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, b.shape);
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const addNGradConfig = {
- kernelName: AddN,
- saveAllInputs: true,
- gradFunc: (dy, saved) => {
- const ders = {};
- saved.forEach((_, i) => {
- ders[i] = () => dy.clone();
- });
- return ders;
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const argMaxGradConfig = {
- kernelName: ArgMax,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => zerosLike(x) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const argMinGradConfig = {
- kernelName: ArgMin,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => zerosLike(x) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const asinGradConfig = {
- kernelName: Asin,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, sqrt(sub(scalar(1), square(cast(x, 'float32'))))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const asinhGradConfig = {
- kernelName: Asinh,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return {
- x: () => {
- const a = sqrt(add$1(scalar(1), square(cast(x, 'float32'))));
- return div(dy, a);
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const atan2GradConfig = {
- kernelName: Atan2,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- const d = add$1(square(a), square(b));
- let res = mul(dy, div(b, d));
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, a.shape);
- };
- const derB = () => {
- const d = add$1(square(a), square(b));
- let res = neg(mul(dy, div(a, d)));
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, b.shape);
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const atanGradConfig = {
- kernelName: Atan,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, add$1(square(cast(x, 'float32')), 1)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const atanhGradConfig = {
- kernelName: Atanh,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, sub(scalar(1), square(cast(x, 'float32')))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the backprop of a 3d avg pool.
- *
- * @param dy The dy error, of rank 5 of shape
- * [batchSize, depth, height, width, channels].
- * assumed.
- * @param input The original input image, of rank 5 or rank4 of shape
- * [batchSize, depth, height, width, channels].
- * @param filterSize The filter size:
- * `[filterDepth, filterHeight, filterWidth]`.
- * `filterSize` is a single number,
- * then `filterDepth == filterHeight == filterWidth`.
- * @param strides The strides of the pooling:
- * `[strideDepth, strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param dilations Deprecated, this field will be gone in v3.0.0. The dilation
- * rates: `[dilationDepth, dilationHeight, dilationWidth]`
- * in which we sample input values across the depth, height and width
- * dimensions in dilated pooling.
- * Defaults to `[1, 1, 1]`. If `dilations` is a single number,
- * then `dilationDepth == dilationHeight == dilationWidth`.
- * If it is greater than 1, then all values of `strides` must be 1.
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
- * rounding mode used when computing output dimensions if pad is a
- * number. If none is provided, it will not round and error if the output
- * is of fractional size.
- */
- function avgPool3dBackprop_(dy, input, filterSize, strides, dilations = [1, 1, 1], pad, dimRoundingMode) {
- const $dy = convertToTensor(dy, 'dy', 'avgPool3dBackprop');
- const $input = convertToTensor(input, 'input', 'avgPool3dBackprop');
- let dy5D = $dy;
- let input5D = $input;
- let reshapedTo5D = false;
- if ($input.rank === 4) {
- reshapedTo5D = true;
- dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
- input5D = reshape($input, [
- 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
- ]);
- }
- assert(dy5D.rank === 5, () => `Error in avgPool3dBackprop: dy must be rank 5 but got rank ` +
- `${dy5D.rank}.`);
- assert(input5D.rank === 5, () => `Error in avgPool3dBackprop: input must be rank 5 but got rank ` +
- `${input5D.rank}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool3dBackprop: Either strides or dilations ' +
- `must be 1. Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in maxPool3dBackprop: pad must be an integer when ` +
- `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = backend => {
- const convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- return backend.avgPool3dBackprop(dy5D, input5D, convInfo);
- };
- const inputs = { dy: dy5D, input: input5D };
- const attrs = { filterSize, strides, dilations, pad, dimRoundingMode };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, AvgPool3DBackprop, attrs);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const avgPool3dBackprop = op({ avgPool3dBackprop_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const avgPool3DGradConfig = {
- kernelName: AvgPool3D,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { filterSize, strides, dilations, pad, dimRoundingMode } = attrs;
- const $dilations = dilations == null ? [1, 1, 1] : dilations;
- return {
- x: () => avgPool3dBackprop(dy, x, filterSize, strides, $dilations, pad, dimRoundingMode)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the backprop of an 2D avg pool.
- *
- * @param dy The dy error, of rank 4 or rank 3 of shape
- * [batchSize, height, width, channels]. If rank 3, batch of 1 is
- * assumed.
- * @param input The input image, of rank 4 or rank 3 of shape
- * [batchSize, height, width, channels]. If rank 3, batch of 1 is
- * assumed.
- * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- */
- function avgPoolBackprop_(dy, input, filterSize, strides, pad) {
- const $dy = convertToTensor(dy, 'dy', 'avgPoolBackprop');
- const $input = convertToTensor(input, 'input', 'avgPoolBackprop');
- assert($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
- let input4D = $input;
- let dy4D = $dy;
- let reshapedTo4D = false;
- if ($input.rank === 3) {
- reshapedTo4D = true;
- input4D =
- reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
- dy4D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
- }
- assert(dy4D.rank === 4, () => `Error in avgPoolBackprop: dy must be rank 4 but got rank ` +
- `${dy4D.rank}.`);
- assert(input4D.rank === 4, () => `Error in avgPoolBackprop: input must be rank 4 but got rank ` +
- `${input4D.rank}.`);
- const forward = backend => {
- const convInfo = computePool2DInfo(input4D.shape, filterSize, strides, 1 /* dilations */, pad);
- return backend.avgPoolBackprop(dy4D, input4D, convInfo);
- };
- const inputs = { dy: dy4D, input: input4D };
- const attrs = { filterSize, strides, pad };
- const res = ENGINE.runKernelFunc(forward, inputs, null, AvgPoolBackprop, attrs);
- if (reshapedTo4D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
- }
- return res;
- }
- const avgPoolBackprop = op({ avgPoolBackprop_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const avgPoolGradConfig = {
- kernelName: AvgPool,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { filterSize, strides, pad } = attrs;
- return {
- x: () => avgPoolBackprop(dy, x, filterSize, strides, pad)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const batchMatMulGradConfig = {
- kernelName: BatchMatMul,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved, attrs) => {
- const [a, b] = saved;
- const { transposeA, transposeB } = attrs;
- if (!transposeA && !transposeB) {
- return {
- a: () => matMul(dy, b, false, true),
- b: () => matMul(a, dy, true, false)
- };
- }
- else if (!transposeA && transposeB) {
- return {
- a: () => matMul(dy, b, false, false),
- b: () => matMul(dy, a, true, false)
- };
- }
- else if (transposeA && !transposeB) {
- return {
- a: () => matMul(b, dy, false, true),
- b: () => matMul(a, dy, false, false)
- };
- }
- else {
- return {
- a: () => matMul(b, dy, true, true),
- b: () => matMul(dy, a, true, true)
- };
- }
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const batchToSpaceNDGradConfig = {
- kernelName: BatchToSpaceND,
- gradFunc: (dy, saved, attrs) => {
- const { blockShape, crops } = attrs;
- return { x: () => spaceToBatchND(dy, blockShape, crops) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const broadcastToGradConfig = {
- kernelName: BroadcastTo,
- gradFunc: (dy, saved, attrs) => {
- const broadCastToAttrs = attrs;
- const inputShape = broadCastToAttrs.inputShape;
- const outputShape = broadCastToAttrs.shape;
- const reps = Array.from(outputShape);
- for (let i = inputShape.length - 1; i >= 0; i--) {
- if (inputShape[i] === outputShape[i]) {
- reps[i] = 1;
- }
- else if (inputShape[i] !== 1) {
- throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
- }
- }
- const axes = [];
- for (let i = 0; i < reps.length; i++) {
- if (reps[i] > 1) {
- axes.push(i);
- }
- }
- return { x: () => sum$1(dy, axes, true /* keepDims */) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const castGradConfig = {
- kernelName: Cast,
- gradFunc: (dy) => {
- return { x: () => dy.clone() };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ceilGradConfig = {
- kernelName: Ceil,
- gradFunc: (dy) => {
- // TODO(manrajgrover): Return null for gradients when backprop supports it.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const clipByValueGradConfig = {
- kernelName: ClipByValue,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { clipValueMin, clipValueMax } = attrs;
- return {
- x: () => where(logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax)), dy, zerosLike(dy)),
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const concatGradConfig = {
- kernelName: Concat,
- saveAllInputs: true,
- gradFunc: (dy, saved, attrs) => {
- const shapes = saved.map(t => t.shape);
- const { axis } = attrs;
- const $axis = parseAxisParam(axis, saved[0].shape)[0];
- const sizeSplits = shapes.map(s => s[$axis]);
- const derTensors = split(dy, sizeSplits, $axis);
- return derTensors.map(t => () => t);
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const conv2DGradConfig = {
- kernelName: Conv2D,
- inputsToSave: ['x', 'filter'],
- gradFunc: (dy, saved, attrs) => {
- const [x4D, $filter] = saved;
- const { dilations, strides, pad, dataFormat } = attrs;
- assert(tupleValuesAreOne(dilations), () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +
- `are not yet supported in gradients. Got dilations '${dilations}'`);
- return {
- x: () => conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad, dataFormat),
- filter: () => conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad, dataFormat)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const conv2DBackpropInputGradConfig = {
- kernelName: Conv2DBackpropInput,
- inputsToSave: ['dy', 'filter'],
- gradFunc: (ddx, saved, attrs) => {
- const [dy, filter] = saved;
- const { strides, pad, dataFormat, dimRoundingMode } = attrs;
- return {
- dy: () => conv2d(ddx, filter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode),
- filter: () => conv2DBackpropFilter(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the derivative of the filter of a 3D convolution.
- *
- * @param x The input tensor, of rank 5 or rank 4 of shape
- * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is
- * assumed.
- * @param dy The dy image, of rank 5 or rank 4, of shape
- * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is
- * assumed.
- * @param filterShape The shape of the filter, length 5,
- * [filterDepth, filterHeight, filterWidth, inDepth, outDepth].
- * @param strides The strides of the convolution: [strideDepth, strideHeight,
- * strideWidth].
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- */
- function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
- let x5D = x;
- if (x.rank === 4) {
- x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
- }
- let dy5D = dy;
- if (dy5D.rank === 4) {
- dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
- }
- assert(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +
- `${x5D.shape}.`);
- assert(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +
- `${dy5D.shape}.`);
- assert(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +
- `${filterShape}.`);
- assert(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input ${x5D.shape[4]}) must ` +
- `match input depth in filter (${filterShape[3]}.`);
- assert(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +
- `match output depth for filter (${filterShape[4]}).`);
- const forward = backend => {
- const dilations = 1;
- const convInfo = computeConv3DInfo(x5D.shape, filterShape, strides, dilations, pad);
- return backend.conv3dDerFilter(x5D, dy5D, convInfo);
- };
- const inputs = { x: x5D, y: dy5D };
- const attrs = { strides, pad };
- return ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropFilterV2, attrs);
- }
- const conv3DBackpropFilter = op({ conv3DBackpropFilter_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const conv3DGradConfig = {
- kernelName: Conv3D,
- inputsToSave: ['x', 'filter'],
- gradFunc: (dy, saved, attrs) => {
- const { dilations, strides, pad } = attrs;
- assert(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' +
- `not yet supported in gradients. Got dilations '${dilations}'`);
- const [x5D, $filter] = saved;
- return {
- x: () => conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad),
- filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const cosGradConfig = {
- kernelName: Cos,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(neg(sin(cast(x, 'float32'))), dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const coshGradConfig = {
- kernelName: Cosh,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(sinh(cast(x, 'float32')), dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const cumsumGradConfig = {
- kernelName: Cumsum,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { axis, exclusive, reverse } = attrs;
- return {
- x: () => {
- const permutation = getAxesPermutation([axis], x.rank);
- let out = cumsum(dy, axis, exclusive, !reverse);
- if (permutation != null) {
- out = transpose(out, permutation);
- }
- return out;
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const depthwiseConv2dNativeGradConfig = {
- kernelName: DepthwiseConv2dNative,
- inputsToSave: ['x', 'filter'],
- gradFunc: (dy, saved, attrs) => {
- const { dilations, strides, pad, dimRoundingMode } = attrs;
- const $dilations = dilations == null ? [1, 1] : dilations;
- assert(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' +
- `greater than 1 are not yet supported. Got dilations ` +
- `'${$dilations}'`);
- const [x, filter] = saved;
- assert(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` +
- `rank 4, but got rank ${x.rank}.`);
- assert(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` +
- `rank 4, but got rank ${filter.rank}.`);
- assert(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` +
- `channels (${x.shape[3]}) must match the inChannels dimension ` +
- `in filter ${filter.shape[2]}.`);
- assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' +
- `dilations must be 1. Got strides ${strides} and dilations ` +
- `'${$dilations}'.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in depthwiseConv2d: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
- return {
- x: () => depthwiseConv2dNativeBackpropInput(x.shape, dy, filter, convInfo),
- filter: () => depthwiseConv2dNativeBackpropFilter(x, dy, filter.shape, convInfo),
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const dilation2dGradConfig = {
- kernelName: Dilation2D,
- inputsToSave: ['x', 'filter'],
- gradFunc: (dy, saved, attrs) => {
- const [x, filter] = saved;
- const inputInputs = { x, filter, dy };
- const filterInputs = { x, filter, dy };
- return {
- x: () => ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs),
- filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const divGradConfig = {
- kernelName: Div,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- const res = div(dy, cast(b, 'float32'));
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(res, reduceAxes), a.shape);
- }
- return res;
- };
- const derB = () => {
- let res = mul(dy, cast(a, 'float32'));
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = reshape(sum$1(res, reduceAxes), b.shape);
- }
- const tmp = square(b);
- return neg(div(res, cast(tmp, 'float32')));
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const eluGradConfig = {
- kernelName: Elu,
- outputsToSave: [true],
- gradFunc: (dy, saved) => {
- const [y] = saved;
- const backPropKernelFunc = (backend) => {
- return backend.eluDer(dy, y);
- };
- const inputs = { dy, y };
- return {
- x: () => ENGINE.runKernelFunc(backPropKernelFunc, inputs, null /* grad */, EluGrad)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const erfGradConfig = {
- kernelName: Erf,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- const a = mul(exp(neg(square(x))), 2 / Math.sqrt(Math.PI));
- return { x: () => mul(dy, a) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const expGradConfig = {
- kernelName: Exp,
- outputsToSave: [true],
- gradFunc: (dy, saved) => {
- const [y] = saved;
- return { x: () => mul(dy, y) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const expm1GradConfig = {
- kernelName: Expm1,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(dy, exp(x)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const floorGradConfig = {
- kernelName: Floor,
- gradFunc: (dy) => {
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const floorDivGradConfig = {
- kernelName: FloorDiv,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- const res = div(dy, cast(b, 'float32'));
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(res, reduceAxes), a.shape);
- }
- return res;
- };
- const derB = () => {
- let res = mul(dy, cast(a, 'float32'));
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = reshape(sum$1(res, reduceAxes), b.shape);
- }
- const tmp = square(b);
- return neg(div(res, cast(tmp, 'float32')));
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const fusedBatchNormGradConfig = {
- kernelName: FusedBatchNorm,
- inputsToSave: ['x', 'mean', 'variance', 'scale'],
- gradFunc: (dy, saved, attrs) => {
- const { varianceEpsilon } = attrs;
- const [x, mean, variance, scale] = saved;
- const scaleValue = scale == null ? scalar(1) : scale;
- const reductionAxes = getReductionAxes(mean.shape, x.shape);
- const tileShape = [];
- if (mean.rank === 1) {
- for (let i = 0; i < x.shape.length - 1; ++i) {
- tileShape.push(x.shape[i]);
- }
- tileShape.push(1);
- }
- const xMinusMean = sub(x, mean);
- const dyTimesScaleValue = mul(dy, scaleValue);
- const oneOverSqrtVariance = rsqrt(add$1(variance, scalar(varianceEpsilon)));
- const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
- const derX = () => {
- if (mean.rank === 1) {
- return reshape(mul(mul(dy, tile(reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
- }
- else {
- return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
- }
- };
- const derMean = () => {
- let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
- if (mean.rank === 1) {
- meanDer = sum$1(meanDer, reductionAxes);
- }
- return reshape(meanDer, mean.shape);
- };
- const derVariance = () => {
- let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
- if (mean.rank === 1) {
- varianceDer = sum$1(varianceDer, reductionAxes);
- }
- return reshape(varianceDer, mean.shape);
- };
- const derScale = () => {
- const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
- let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
- if (mean.rank === 1) {
- scaleDer = sum$1(scaleDer, reductionAxes);
- }
- return reshape(scaleDer, mean.shape);
- };
- const derOffset = () => {
- let offsetDer = dy;
- if (mean.rank === 1) {
- offsetDer = sum$1(offsetDer, reductionAxes);
- }
- return reshape(offsetDer, mean.shape);
- };
- return {
- x: derX,
- mean: derMean,
- variance: derVariance,
- scale: derScale,
- offset: derOffset
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const gatherGradConfig = {
- kernelName: GatherV2,
- inputsToSave: ['x', 'indices'],
- gradFunc: (dy, saved, attrs) => {
- const [x, indices] = saved;
- const { axis } = attrs;
- const parsedAxis = parseAxisParam(axis, x.shape)[0];
- const derX = () => {
- const paramsShape = x.shape;
- const indicesSize = indices.size;
- const outerShape = paramsShape.slice(0, parsedAxis);
- const outerDims = outerShape.length;
- const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
- const innerDims = innerShape.length;
- const outerAxesIndices = arrayRange(0, outerDims);
- const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
- const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
- const values = reshape(dy, valuesShape);
- const reshapedIndices = reshape(indices, [indicesSize]);
- const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
- const valuesTranspose = transpose(values, transposeDims);
- let paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
- const invertTransposeDims = getUndoAxesPermutation(transposeDims);
- paramsGrad = transpose(paramsGrad, invertTransposeDims);
- return paramsGrad;
- };
- return { x: derX, indices: () => indices };
- }
- };
- function arrayRange(start, stop) {
- const result = [];
- for (let i = start; i < stop; ++i) {
- result.push(i);
- }
- return result;
- }
- function arrayConcat(arrays) {
- const result = [];
- for (let i = 0; i < arrays.length; ++i) {
- for (let j = 0; j < arrays[i].length; ++j) {
- result.push(arrays[i][j]);
- }
- }
- return result;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const greaterEqualGradConfig = {
- kernelName: GreaterEqual,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- return { a: () => zerosLike(a), b: () => zerosLike(b) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const identityGradConfig = {
- kernelName: Identity,
- gradFunc: (dy) => {
- return { x: () => cast(dy, 'float32') };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isFiniteGradConfig = {
- kernelName: IsFinite,
- gradFunc: (dy) => {
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isInfGradConfig = {
- kernelName: IsInf,
- gradFunc: (dy) => {
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isNanGradConfig = {
- kernelName: IsNan,
- gradFunc: (dy) => {
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const log1pGradConfig = {
- kernelName: Log1p,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, add$1(x, 1)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const logGradConfig = {
- kernelName: Log,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, cast(x, 'float32')) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const logSoftmaxGradConfig = {
- kernelName: LogSoftmax,
- inputsToSave: [],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const [value] = saved;
- const { axis } = attrs;
- return {
- logits: () => {
- const keepDims = true;
- const softmax = exp(value);
- return sub(dy, mul(sum$1(dy, axis, keepDims), softmax));
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
- const forward = backend => backend.LRNGrad(dy, x, y, depthRadius, bias, alpha, beta);
- const inputs = { x, y, dy };
- const attrs = { depthRadius, bias, alpha, beta };
- return ENGINE.runKernelFunc(forward, inputs, null /* grad */, LRNBackprop, attrs);
- }
- const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const lrnGradConfig = {
- kernelName: LRN,
- inputsToSave: ['x'],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const [x, y] = saved;
- const { depthRadius, bias, alpha, beta } = attrs;
- return {
- x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gradient helper function for the min and max operations.
- */
- function gradForMinAndMax(dy, y, xOrig, origAxes, permutedAxes) {
- if (y.rank < xOrig.rank) {
- y = reshape(y, expandShapeToKeepDim(y.shape, origAxes));
- }
- if (dy.rank < xOrig.rank) {
- dy = reshape(dy, expandShapeToKeepDim(dy.shape, origAxes));
- }
- return {
- x: () => {
- const dx = mul(dy, cast(equal(xOrig, y), dy.dtype));
- return permutedAxes == null ? dx : transpose(dx, permutedAxes);
- }
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxGradConfig = {
- kernelName: Max,
- inputsToSave: ['x'],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const maxAttrs = attrs;
- const { reductionIndices } = maxAttrs;
- const [x, y] = saved;
- const origAxes = parseAxisParam(reductionIndices, x.shape);
- const permutedAxes = getAxesPermutation(origAxes, x.rank);
- const maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes);
- return {
- x: () => {
- let out = maxGrad['x']();
- if (permutedAxes != null) {
- out = transpose(out);
- }
- return out;
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maximumGradConfig = {
- kernelName: Maximum,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const derA = () => mul(dy, cast(greaterEqual(a, b), 'float32'));
- const derB = () => mul(dy, cast(less(a, b), 'float32'));
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the backprop of a 3d max pool.
- *
- * @param dy The dy error, of rank 5 of shape
- * [batchSize, depth, height, width, channels].
- * assumed.
- * @param input The original input image, of rank 5 or rank 4 of shape
- * [batchSize, depth, height, width, channels].
- * @param output The original output image, of rank 5 of shape
- * [batchSize, outDepth, outHeight, outWidth, channels].
- * @param filterSize The filter size:
- * `[filterDepth, filterHeight, filterWidth]`.
- * `filterSize` is a single number,
- * then `filterDepth == filterHeight == filterWidth`.
- * @param strides The strides of the pooling:
- * `[strideDepth, strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param dilations Deprecated, this field will be gone in v3.0.0.
- * The dilation rates: `[dilationDepth, dilationHeight, dilationWidth]`
- * in which we sample input values across the depth, height and width
- * dimensions in dilated pooling.
- * Defaults to `[1, 1, 1]`. If `dilations` is a single number,
- * then `dilationDepth == dilationHeight == dilationWidth`.
- * If it is greater than 1, then all values of `strides` must be 1.
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
- * rounding mode used when computing output dimensions if pad is a
- * number. If none is provided, it will not round and error if the output
- * is of fractional size.
- */
- function maxPool3dBackprop_(dy, input, output, filterSize, strides, dilations = [1, 1, 1], pad, dimRoundingMode) {
- const $dy = convertToTensor(dy, 'dy', 'maxPool3dBackprop');
- const $input = convertToTensor(input, 'input', 'maxPool3dBackprop');
- const $output = convertToTensor(output, 'output', 'maxPool3dBackprop');
- let dy5D = $dy;
- let input5D = $input;
- let output5D = $output;
- let reshapedTo5D = false;
- if ($input.rank === 4) {
- reshapedTo5D = true;
- dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
- input5D = reshape($input, [
- 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
- ]);
- output5D = reshape($output, [
- 1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]
- ]);
- }
- assert(dy5D.rank === 5, () => `Error in maxPool3dBackprop: dy must be rank 5 but got rank ` +
- `${dy5D.rank}.`);
- assert(input5D.rank === 5, () => `Error in maxPool3dBackprop: input must be rank 5 but got rank ` +
- `${input5D.rank}.`);
- assert(output5D.rank === 5, () => `Error in maxPool3dBackprop: output must be rank 5 but got rank ` +
- `${output5D.rank}.`);
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool3dBackprop: Either strides or dilations ' +
- `must be 1. Got strides ${strides} and dilations '${dilations}'`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in maxPool3dBackprop: pad must be an integer when ` +
- `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = backend => {
- const convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- return backend.maxPool3dBackprop(dy5D, input5D, output5D, convInfo);
- };
- const inputs = { dy: dy5D, input: input5D, output: output5D };
- const attrs = { filterSize, strides, dilations, pad, dimRoundingMode };
- const res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, MaxPool3DBackprop, attrs);
- if (reshapedTo5D) {
- return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
- }
- return res;
- }
- const maxPool3dBackprop = op({ maxPool3dBackprop_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxPool3DGradConfig = {
- kernelName: MaxPool3D,
- inputsToSave: ['x'],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const [x, y] = saved;
- const { filterSize, strides, dilations, pad, dimRoundingMode } = attrs;
- const $dilations = dilations == null ? [1, 1, 1] : dilations;
- return {
- x: () => maxPool3dBackprop(dy, x, y, filterSize, strides, $dilations, pad, dimRoundingMode)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the backprop of a 2D max pool.
- *
- * @param dy The dy error, of rank 4 or rank 3 of shape
- * [batchSize, height, width, channels]. If rank 3, batch of 1 is
- * assumed.
- * @param input The original input image, of rank 4, of shape
- * [batchSize, height, width, channels].
- * @param output The original output image, of rank 4, of shape
- * [batchSize, outHeight, outWidth, channels].
- * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
- * `filterSize` is a single number, then `filterHeight == filterWidth`.
- * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
- * `strides` is a single number, then `strideHeight == strideWidth`.
- * @param pad A string from: 'same', 'valid'. The type of padding algorithm
- * used in the forward prop of the op.
- * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
- * rounding mode used when computing output dimensions if pad is a
- * number. If none is provided, it will not round and error if the output
- * is of fractional size.
- */
- function maxPoolBackprop_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
- const $dy = convertToTensor(dy, 'dy', 'maxPoolBackprop');
- const $input = convertToTensor(input, 'input', 'maxPoolBackprop');
- const $output = convertToTensor(output, 'output', 'maxPoolBackprop');
- assert($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy ` +
- `(${$dy.rank})`);
- assert($dy.rank === 4, () => `Error in maxPoolBackprop: dy must be rank 4 but got rank ` +
- `${$dy.rank}.`);
- assert($input.rank === 4, () => `Error in maxPoolBackprop: input must be rank 4 but got rank ` +
- `${$input.rank}.`);
- if (dimRoundingMode != null) {
- assert(isInt(pad), () => `Error in maxPoolBackprop: pad must be an integer when using, ` +
- `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
- }
- const forward = backend => {
- const convInfo = computePool2DInfo($input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
- return backend.maxPoolBackprop($dy, $input, $output, convInfo);
- };
- const inputs = { dy: $dy, input: $input, output: $output };
- const attrs = { filterSize, strides, pad, dimRoundingMode };
- return ENGINE.runKernelFunc(forward, inputs, null, MaxPoolBackprop, attrs);
- }
- const maxPoolBackprop = op({ maxPoolBackprop_ });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxPoolGradConfig = {
- kernelName: MaxPool,
- inputsToSave: ['x'],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const [x, y] = saved;
- const { filterSize, strides, pad } = attrs;
- return {
- x: () => maxPoolBackprop(dy, x, y, filterSize, strides, pad)
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const minGradConfig = {
- kernelName: Min,
- inputsToSave: ['x'],
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const minAttrs = attrs;
- const { axis } = minAttrs;
- const [x, y] = saved;
- const origAxes = parseAxisParam(axis, x.shape);
- const permutedAxes = getAxesPermutation(origAxes, x.rank);
- const minGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes);
- return {
- x: () => {
- let out = minGrad['x']();
- if (permutedAxes != null) {
- out = transpose(out);
- }
- return out;
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const minimumGradConfig = {
- kernelName: Minimum,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const derA = () => mul(dy, cast(lessEqual(a, b), 'float32'));
- const derB = () => mul(dy, cast(greater(a, b), 'float32'));
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const mirrorPadGradConfig = {
- kernelName: MirrorPad,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- // Pad introduces values around the original tensor, so the gradient
- // slices the original shape out of the gradient.
- const x = saved[0];
- const { paddings } = attrs;
- const begin = paddings.map(p => p[0]);
- return { x: () => slice(dy, begin, x.shape) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const modGradConfig = {
- kernelName: Mod,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(dy, reduceAxes), a.shape);
- }
- return dy;
- };
- const derB = () => {
- const res = mul(dy, neg(floor(div(a, b))));
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(res, reduceAxes), b.shape);
- }
- return res;
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const multiplyGradConfig = {
- kernelName: Multiply,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- const res = mul(dy, cast(b, 'float32'));
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(res, reduceAxes), a.shape);
- }
- return res;
- };
- const derB = () => {
- const res = mul(dy, cast(a, 'float32'));
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- return reshape(sum$1(res, reduceAxes), b.shape);
- }
- return res;
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const negateGradConfig = {
- kernelName: Negate,
- gradFunc: (dy) => {
- return { x: () => neg(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const oneHotGradConfig = {
- kernelName: OneHot,
- inputsToSave: ['indices'],
- gradFunc: (dy, saved) => {
- const indices = saved[0];
- return { indices: () => zeros(indices.shape, 'float32') };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const onesLikeGradConfig = {
- kernelName: OnesLike,
- gradFunc: (dy) => {
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const padV2GradConfig = {
- kernelName: PadV2,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- // Pad introduces values around the original tensor, so the gradient
- // slices the original shape out of the gradient.
- const x = saved[0];
- const { paddings } = attrs;
- const begin = paddings.map(p => p[0]);
- return { x: () => slice(dy, begin, x.shape) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const powGradConfig = {
- kernelName: Pow,
- inputsToSave: ['a', 'b'],
- outputsToSave: [true],
- gradFunc: (dy, saved) => {
- const [a, b, y] = saved;
- const base = a;
- const exp = b;
- const outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
- const derBase = () => {
- const expFloat = cast(exp, 'float32');
- let res = mul(dy, mul(expFloat, pow(base, sub(expFloat, scalar(1)))));
- const reduceAxes = getReductionAxes(base.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, base.shape);
- };
- const derExp = () => {
- const condition = greater(base, 0);
- const logBase = where(condition, log(base), zerosLike(base));
- let res = mul(dy, mul(y, logBase));
- const reduceAxes = getReductionAxes(exp.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, exp.shape);
- };
- return { a: derBase, b: derExp };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const preluGradConfig = {
- kernelName: Prelu,
- inputsToSave: ['x', 'alpha'],
- gradFunc: (dy, saved) => {
- const [x, alpha] = saved;
- const mask = greater(x, 0);
- return {
- x: () => where(mask, dy, mul(dy, alpha)),
- alpha: () => {
- let res = where(mask, zerosLike(dy), mul(dy, x));
- const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, alpha.shape);
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const reciprocalGradConfig = {
- kernelName: Reciprocal,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, neg(square(x))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const relu6GradConfig = {
- kernelName: Relu6,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- const mask = mul(lessEqual(x, 6), step(x));
- return { x: () => mul(dy, cast(mask, 'float32')) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const reluGradConfig = {
- kernelName: Relu,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(dy, cast(step(x), 'float32')) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const reshapeGradConfig = {
- kernelName: Reshape,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => reshape(dy, x.shape) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const resizeBilinearGradConfig = {
- kernelName: ResizeBilinear,
- inputsToSave: ['images'],
- gradFunc: (dy, saved, attrs) => {
- const [images] = saved;
- const backPropKernelFunc = (backend) => {
- const { alignCorners } = attrs;
- return backend.resizeBilinearBackprop(dy, images, alignCorners);
- };
- const inputs = { images };
- const imagesDer = () => ENGINE.runKernelFunc(backPropKernelFunc, inputs, null /* gradient */, ResizeBilinearGrad, attrs);
- return { images: imagesDer };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const resizeNearestNeighborGradConfig = {
- kernelName: ResizeNearestNeighbor,
- inputsToSave: ['images'],
- gradFunc: (dy, saved, attrs) => {
- const [images] = saved;
- const backPropKernelFunc = (backend) => {
- const { alignCorners } = attrs;
- return backend.resizeNearestNeighborBackprop(dy, images, alignCorners);
- };
- const inputs = { images };
- const imagesDer = () => ENGINE.runKernelFunc(backPropKernelFunc, inputs, null /* gradient */, ResizeNearestNeighborGrad, attrs);
- return { images: imagesDer };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const reverseGradConfig = {
- kernelName: Reverse,
- gradFunc: (dy, saved, attrs) => {
- const { dims } = attrs;
- const axes = parseAxisParam(dims, dy.shape);
- return { x: () => reverse(dy, axes) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const roundGradConfig = {
- kernelName: Round,
- gradFunc: (dy) => {
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const rsqrtGradConfig = {
- kernelName: Rsqrt,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => neg(div(dy, mul(pow(x, 1.5), 2))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const selectV2PoolGradConfig = {
- kernelName: SelectV2,
- inputsToSave: ['condition'],
- gradFunc: (dy, saved) => {
- const [condition] = saved;
- return {
- // TODO(julianoks): Return null for condition gradient
- // when backprop supports it.
- condition: () => cast(zerosLike(condition), 'float32'),
- t: () => mul(dy, cast(condition, dy.dtype)),
- e: () => mul(dy, cast(logicalNot(condition), dy.dtype))
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const seluGradConfig = {
- kernelName: Selu,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return {
- x: () => {
- const mask = greater(x, scalar(0));
- const scaleAlpha = scalar(SELU_SCALEALPHA);
- const scale = scalar(SELU_SCALE);
- const greaterThanZeroDer = mul(dy, scale);
- const lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp(cast(x, 'float32')));
- return where(mask, greaterThanZeroDer, lessEqualZeroDer);
- }
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sigmoidGradConfig = {
- kernelName: Sigmoid,
- outputsToSave: [true],
- gradFunc: (dy, saved) => {
- const [y] = saved;
- return { x: () => mul(dy, mul(y, sub(scalar(1), y))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const signGradConfig = {
- kernelName: Sign,
- gradFunc: (dy) => {
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sinGradConfig = {
- kernelName: Sin,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(cos(cast(x, 'float32')), dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sinhGradConfig = {
- kernelName: Sinh,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(cosh(cast(x, 'float32')), dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sliceGradConfig = {
- kernelName: Slice,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { begin, size } = attrs;
- const inputShape = x.shape;
- const [begin_, size_] = parseSliceParams(x, begin, size);
- // Create an Nx2 padding where the first column represents how many
- // zeros are prepended (at start) for each dimension, and the second
- // column indicates how many zeros are appended (at end).
- // The number of zeros to append is the shape of the input
- // elementwise-subtracted by both the begin vector and sizes vector.
- const paddings = [];
- for (let i = 0; i < dy.rank; i++) {
- paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
- }
- return { x: () => pad(dy, paddings) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const softmaxGradConfig = {
- kernelName: Softmax,
- outputsToSave: [true],
- gradFunc: (dy, saved, attrs) => {
- const [y] = saved;
- const { dim } = attrs;
- const keepDims = true;
- const dyTimesY = mul(dy, y);
- return {
- logits: () => sub(dyTimesY, mul(sum$1(dyTimesY, [dim], keepDims), y))
- };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const softplusGradConfig = {
- kernelName: Softplus,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(dy, sigmoid(x)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const spaceToBatchNDGradConfig = {
- kernelName: SpaceToBatchND,
- gradFunc: (dy, saved, attrs) => {
- const { blockShape, paddings } = attrs;
- return { x: () => batchToSpaceND(dy, blockShape, paddings) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const splitVGradConfig = {
- kernelName: SplitV,
- gradFunc: (dy, saved, attrs) => {
- const { axis } = attrs;
- return { x: () => concat(dy, axis) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sqrtGradConfig = {
- kernelName: Sqrt,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, mul(sqrt(cast(x, 'float32')), 2)) };
- }
- };
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const squareGradConfig = {
- kernelName: Square,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => mul(dy, mul(cast(x, 'float32'), 2)) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const squaredDifferenceGradConfig = {
- kernelName: SquaredDifference,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const two = scalar(2);
- const derA = () => mul(dy, mul(two, sub(a, b)));
- const derB = () => mul(dy, mul(two, sub(b, a)));
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const stepGradConfig = {
- kernelName: Step,
- gradFunc: (dy) => {
- // TODO(manrajgrover): Return null for gradients when backprop supports
- // it.
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const subGradConfig = {
- kernelName: Sub,
- inputsToSave: ['a', 'b'],
- gradFunc: (dy, saved) => {
- const [a, b] = saved;
- const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const derA = () => {
- let res = dy;
- const reduceAxes = getReductionAxes(a.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(res, a.shape);
- };
- const derB = () => {
- let res = dy;
- const reduceAxes = getReductionAxes(b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = sum$1(res, reduceAxes);
- }
- return reshape(neg(res), b.shape);
- };
- return { a: derA, b: derB };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sumGradConfig = {
- kernelName: Sum,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const expandedDyShape = x.shape.slice();
- const { axis } = attrs;
- const axes = parseAxisParam(axis, x.shape);
- axes.forEach(axis => {
- expandedDyShape[axis] = 1;
- });
- const expandedDy = reshape(dy, expandedDyShape);
- const derX = mul(expandedDy, ones$1(x.shape, 'float32'));
- return { x: () => derX };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const tanGradConfig = {
- kernelName: Tan,
- inputsToSave: ['x'],
- gradFunc: (dy, saved) => {
- const [x] = saved;
- return { x: () => div(dy, square(cos(x))) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const tanhGradConfig = {
- kernelName: Tanh,
- outputsToSave: [true],
- gradFunc: (dy, saved) => {
- const [y] = saved;
- return { x: () => mul(sub(scalar(1), square(y)), dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const tileGradConfig = {
- kernelName: Tile,
- inputsToSave: ['x'],
- gradFunc: (dy, saved, attrs) => {
- const [x] = saved;
- const { reps } = attrs;
- const derX = () => {
- let xGrad = zerosLike(x);
- // TODO(cais): Maybe reduce memory footprint by avoiding repeated
- // slicing.
- if (x.rank === 1) {
- for (let i = 0; i < reps[0]; ++i) {
- xGrad = add$1(xGrad, slice(dy, [i * x.shape[0]], [x.shape[0]]));
- }
- }
- else if (x.rank === 2) {
- for (let i = 0; i < reps[0]; ++i) {
- for (let j = 0; j < reps[1]; ++j) {
- xGrad = add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1]], [
- x.shape[0], x.shape[1]
- ]));
- }
- }
- }
- else if (x.rank === 3) {
- for (let i = 0; i < reps[0]; ++i) {
- for (let j = 0; j < reps[1]; ++j) {
- for (let k = 0; k < reps[2]; ++k) {
- xGrad =
- add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
- }
- }
- }
- }
- else if (x.rank === 4) {
- for (let i = 0; i < reps[0]; ++i) {
- for (let j = 0; j < reps[1]; ++j) {
- for (let k = 0; k < reps[2]; ++k) {
- for (let l = 0; l < reps[3]; ++l) {
- xGrad =
- add$1(xGrad, slice(dy, [
- i * x.shape[0], j * x.shape[1], k * x.shape[2],
- l * x.shape[3]
- ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
- }
- }
- }
- }
- }
- else {
- throw new Error(`Gradient for tile operation is not implemented for rank-` +
- `${x.rank} tensors yet.`);
- }
- return xGrad;
- };
- return { x: derX };
- },
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const transposeGradConfig = {
- kernelName: Transpose,
- gradFunc: (dy, saved, attrs) => {
- const transposeAttrs = attrs;
- const { perm } = transposeAttrs;
- const undoPerm = getUndoAxesPermutation(perm);
- return { x: () => transpose(dy, undoPerm) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const unpackGradConfig = {
- kernelName: Unpack,
- gradFunc: (dy, saved, attrs) => {
- const unpackAttrs = attrs;
- const { axis } = unpackAttrs;
- return { value: () => stack(dy, axis) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const unsortedSegmentSumGradConfig = {
- kernelName: UnsortedSegmentSum,
- inputsToSave: ['segmentIds'],
- gradFunc: (dy, saved) => {
- const [segmentIds] = saved;
- const derX = () => {
- return gatherDropNegatives(dy, segmentIds);
- };
- return { x: derX };
- }
- };
- function gatherDropNegatives(x, indices) {
- // Helper function for unsorted segment ops. Gathers params for
- // positive segment ids and gathers 0 for inputs with negative segment id.
- // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
- const zeroClippedIndices = maximum(indices, zerosLike(indices));
- const gathered = gather(x, zeroClippedIndices);
- let isPositive = greaterEqual(indices, scalar(0, 'int32'));
- const numIters = gathered.rank - isPositive.rank;
- for (let i = 0; i < numIters; ++i) {
- isPositive = expandDims(isPositive, i + 1);
- }
- isPositive = logicalAnd(isPositive, ones$1(gathered.shape, 'bool'));
- const zeroSlice = zerosLike(gathered);
- return where(isPositive, gathered, zeroSlice);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const zerosLikeGradConfig = {
- kernelName: ZerosLike,
- gradFunc: (dy) => {
- return { x: () => zerosLike(dy) };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Export all kernel configs here so that the package can auto register them
- const gradConfigs = [
- absGradConfig,
- acosGradConfig,
- acoshGradConfig,
- addGradConfig,
- addNGradConfig,
- argMaxGradConfig,
- argMinGradConfig,
- asinGradConfig,
- asinhGradConfig,
- atan2GradConfig,
- atanGradConfig,
- atanhGradConfig,
- avgPool3DGradConfig,
- avgPoolGradConfig,
- batchMatMulGradConfig,
- batchToSpaceNDGradConfig,
- broadcastToGradConfig,
- castGradConfig,
- ceilGradConfig,
- clipByValueGradConfig,
- concatGradConfig,
- conv2DBackpropInputGradConfig,
- conv2DGradConfig,
- conv3DGradConfig,
- cosGradConfig,
- coshGradConfig,
- cumsumGradConfig,
- depthwiseConv2dNativeGradConfig,
- dilation2dGradConfig,
- divGradConfig,
- eluGradConfig,
- erfGradConfig,
- expGradConfig,
- expm1GradConfig,
- floorDivGradConfig,
- floorGradConfig,
- fusedBatchNormGradConfig,
- gatherGradConfig,
- greaterEqualGradConfig,
- identityGradConfig,
- isFiniteGradConfig,
- isInfGradConfig,
- isNanGradConfig,
- log1pGradConfig,
- logGradConfig,
- logSoftmaxGradConfig,
- lrnGradConfig,
- maxGradConfig,
- maxGradConfig,
- maximumGradConfig,
- maxPool3DGradConfig,
- maxPoolGradConfig,
- minGradConfig,
- minimumGradConfig,
- mirrorPadGradConfig,
- modGradConfig,
- multiplyGradConfig,
- negateGradConfig,
- oneHotGradConfig,
- onesLikeGradConfig,
- padV2GradConfig,
- padV2GradConfig,
- powGradConfig,
- preluGradConfig,
- reciprocalGradConfig,
- relu6GradConfig,
- reluGradConfig,
- reshapeGradConfig,
- resizeBilinearGradConfig,
- resizeNearestNeighborGradConfig,
- reverseGradConfig,
- roundGradConfig,
- rsqrtGradConfig,
- selectV2PoolGradConfig,
- seluGradConfig,
- sigmoidGradConfig,
- signGradConfig,
- sinGradConfig,
- sinhGradConfig,
- sliceGradConfig,
- softmaxGradConfig,
- softplusGradConfig,
- spaceToBatchNDGradConfig,
- spaceToBatchNDGradConfig,
- splitVGradConfig,
- splitVGradConfig,
- sqrtGradConfig,
- squaredDifferenceGradConfig,
- squareGradConfig,
- stepGradConfig,
- subGradConfig,
- sumGradConfig,
- tanGradConfig,
- tanhGradConfig,
- tileGradConfig,
- transposeGradConfig,
- unpackGradConfig,
- unsortedSegmentSumGradConfig,
- zerosLikeGradConfig
- ];
- for (const gradientConfig of gradConfigs) {
- registerGradient(gradientConfig);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.abs = function () {
- this.throwIfDisposed();
- return abs(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.acos = function () {
- this.throwIfDisposed();
- return acos(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.acosh = function () {
- this.throwIfDisposed();
- return acosh(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.addStrict = function (x) {
- this.throwIfDisposed();
- return addStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.add = function (b) {
- this.throwIfDisposed();
- return add$1(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.all = function (axis, keepDims) {
- this.throwIfDisposed();
- return all(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.any = function (axis, keepDims) {
- this.throwIfDisposed();
- return any(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.argMax = function (axis) {
- this.throwIfDisposed();
- return argMax(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.argMin = function (axis) {
- this.throwIfDisposed();
- return argMin(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.asScalar = function () {
- this.throwIfDisposed();
- assert(this.size === 1, () => 'The array must have only 1 element.');
- return reshape(this, []);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Casts a `tf.Tensor` to a specified dtype.
- *
- * @param dtype Data-type to cast the tensor to.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.asType = function (dtype) {
- this.throwIfDisposed();
- return cast(this, dtype);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Converts a `tf.Tensor` to a `tf.Tensor1D`.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.as1D = function () {
- this.throwIfDisposed();
- return reshape(this, [this.size]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor2D`.
- *
- * @param rows Number of rows in `tf.Tensor2D`.
- * @param columns Number of columns in `tf.Tensor2D`.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.as2D = function (rows, columns) {
- this.throwIfDisposed();
- return reshape(this, [rows, columns]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor3D`.
- *
- * @param rows Number of rows in `tf.Tensor3D`.
- * @param columns Number of columns in `tf.Tensor3D`.
- * @param depth Depth of `tf.Tensor3D`.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.as3D = function (rows, columns, depth) {
- this.throwIfDisposed();
- return reshape(this, [rows, columns, depth]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor4D`.
- *
- * @param rows Number of rows in `tf.Tensor4D`.
- * @param columns Number of columns in `tf.Tensor4D`.
- * @param depth Depth of `tf.Tensor4D`.
- * @param depth2 4th dimension of `tf.Tensor4D`.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.as4D = function (rows, columns, depth, depth2) {
- this.throwIfDisposed();
- return reshape(this, [rows, columns, depth, depth2]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor5D`.
- *
- * @param rows Number of rows in `tf.Tensor5D`.
- * @param columns Number of columns in `tf.Tensor5D`.
- * @param depth Depth of `tf.Tensor5D`.
- * @param depth2 4th dimension of `tf.Tensor5D`.
- * @param depth3 5th dimension of 'tf.Tensor5D'
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.as5D = function (rows, columns, depth, depth2, depth3) {
- this.throwIfDisposed();
- return reshape(this, [rows, columns, depth, depth2, depth3]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.asin = function () {
- this.throwIfDisposed();
- return asin(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.asinh = function () {
- this.throwIfDisposed();
- return asinh(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.atan = function () {
- this.throwIfDisposed();
- return atan(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.atan2 = function (b) {
- this.throwIfDisposed();
- return atan2(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.atanh = function () {
- this.throwIfDisposed();
- return atanh(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return avgPool(this, filterSize, strides, pad, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.batchToSpaceND = function (blockShape, crops) {
- this.throwIfDisposed();
- return batchToSpaceND(this, blockShape, crops);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) {
- this.throwIfDisposed();
- return batchNorm(this, mean, variance, offset, scale, varianceEpsilon);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.broadcastTo = function (shape) {
- this.throwIfDisposed();
- return broadcastTo(this, shape);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.cast = function (dtype) {
- this.throwIfDisposed();
- return cast(this, dtype);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.ceil = function () {
- this.throwIfDisposed();
- return ceil(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.clipByValue = function (min, max) {
- this.throwIfDisposed();
- return clipByValue(this, min, max);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.concat = function (x, axis) {
- this.throwIfDisposed();
- if (x instanceof Tensor) {
- x = [x];
- }
- return concat([this, ...x], axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
- this.throwIfDisposed();
- return conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return conv2dTranspose(this, filter, outputShape, strides, pad, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
- this.throwIfDisposed();
- return conv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.cos = function () {
- this.throwIfDisposed();
- return cos(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.cosh = function () {
- this.throwIfDisposed();
- return cosh(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.cumsum = function (axis, exclusive, reverse) {
- this.throwIfDisposed();
- return cumsum(this, axis, exclusive, reverse);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.depthToSpace = function (blockSize, dataFormat) {
- this.throwIfDisposed();
- return depthToSpace(this, blockSize, dataFormat);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated Use `depthwiseConv2d` instead.
- */
- Tensor.prototype.depthwiseConv2D = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
- deprecationWarn('depthwiseConv2D is deprecated, use depthwiseConv2d instead');
- this.throwIfDisposed();
- return depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.depthwiseConv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
- this.throwIfDisposed();
- return depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.dilation2d = function (filter, strides, pad, dilations, dataFormat) {
- this.throwIfDisposed();
- return dilation2d(this, filter, strides, pad, dilations, dataFormat);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.divNoNan = function (b) {
- this.throwIfDisposed();
- return divNoNan(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.divStrict = function (x) {
- this.throwIfDisposed();
- return divStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.div = function (b) {
- this.throwIfDisposed();
- return div(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.dot = function (b) {
- this.throwIfDisposed();
- return dot(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.elu = function () {
- this.throwIfDisposed();
- return elu(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.equalStrict = function (x) {
- this.throwIfDisposed();
- return equalStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.equal = function (b) {
- this.throwIfDisposed();
- return equal(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.erf = function () {
- this.throwIfDisposed();
- return erf(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.exp = function () {
- this.throwIfDisposed();
- return exp(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.expandDims = function (axis) {
- this.throwIfDisposed();
- return expandDims(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.expm1 = function () {
- this.throwIfDisposed();
- return expm1(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.fft = function () {
- this.throwIfDisposed();
- return fft(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Flatten a Tensor to a 1D array.
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.flatten = function () {
- this.throwIfDisposed();
- return reshape(this, [this.size]);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.floor = function () {
- this.throwIfDisposed();
- return floor(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.floorDiv = function (b) {
- this.throwIfDisposed();
- return floorDiv(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.gather = function (indices, axis) {
- this.throwIfDisposed();
- return gather(this, indices, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.greaterEqualStrict = function (x) {
- this.throwIfDisposed();
- return greaterEqualStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.greaterEqual = function (b) {
- this.throwIfDisposed();
- return greaterEqual(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.greaterStrict = function (x) {
- this.throwIfDisposed();
- return greaterStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.greater = function (b) {
- this.throwIfDisposed();
- return greater(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.ifft = function () {
- this.throwIfDisposed();
- return ifft(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.irfft = function () {
- this.throwIfDisposed();
- return irfft(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.isFinite = function () {
- this.throwIfDisposed();
- return isFinite$1(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.isInf = function () {
- this.throwIfDisposed();
- return isInf(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.isNaN = function () {
- this.throwIfDisposed();
- return isNaN$1(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.leakyRelu = function (alpha) {
- this.throwIfDisposed();
- return leakyRelu(this, alpha);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.lessEqualStrict = function (x) {
- this.throwIfDisposed();
- return lessEqualStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.lessEqual = function (b) {
- this.throwIfDisposed();
- return lessEqual(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.lessStrict = function (x) {
- this.throwIfDisposed();
- return lessStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.less = function (b) {
- this.throwIfDisposed();
- return less(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.localResponseNormalization = function (depthRadius, bias, alpha, beta) {
- this.throwIfDisposed();
- return localResponseNormalization(this, depthRadius, bias, alpha, beta);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logSigmoid = function () {
- this.throwIfDisposed();
- return logSigmoid(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logSoftmax = function (axis) {
- this.throwIfDisposed();
- return logSoftmax(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logSumExp = function (axis, keepDims) {
- this.throwIfDisposed();
- return logSumExp(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.log = function () {
- this.throwIfDisposed();
- return log(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.log1p = function () {
- this.throwIfDisposed();
- return log1p(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logicalAnd = function (b) {
- this.throwIfDisposed();
- return logicalAnd(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logicalNot = function () {
- this.throwIfDisposed();
- return logicalNot(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logicalOr = function (b) {
- this.throwIfDisposed();
- return logicalOr(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.logicalXor = function (b) {
- this.throwIfDisposed();
- return logicalXor(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.matMul = function (b, transposeA, transposeB) {
- this.throwIfDisposed();
- return matMul(this, b, transposeA, transposeB);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return maxPool(this, filterSize, strides, pad, dimRoundingMode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.max = function (axis, keepDims) {
- this.throwIfDisposed();
- return max(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.maximumStrict = function (x) {
- this.throwIfDisposed();
- return maximumStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.maximum = function (b) {
- this.throwIfDisposed();
- return maximum(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.mean = function (axis, keepDims) {
- this.throwIfDisposed();
- return mean(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.min = function (axis, keepDims) {
- this.throwIfDisposed();
- return min(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.minimumStrict = function (x) {
- this.throwIfDisposed();
- return minimumStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.minimum = function (b) {
- this.throwIfDisposed();
- return minimum(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.mirrorPad = function (paddings, mode) {
- this.throwIfDisposed();
- return mirrorPad(this, paddings, mode);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.modStrict = function (x) {
- this.throwIfDisposed();
- return modStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.mod = function (b) {
- this.throwIfDisposed();
- return mod(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.mulStrict = function (x) {
- this.throwIfDisposed();
- return mulStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.mul = function (b) {
- this.throwIfDisposed();
- return mul(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.neg = function () {
- this.throwIfDisposed();
- return neg(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.norm = function (ord, axis, keepDims) {
- this.throwIfDisposed();
- return norm(this, ord, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.notEqualStrict = function (x) {
- this.throwIfDisposed();
- return notEqualStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.notEqual = function (b) {
- this.throwIfDisposed();
- return notEqual(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.oneHot = function (depth, onValue = 1, offValue = 0) {
- this.throwIfDisposed();
- return oneHot(this, depth, onValue, offValue);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.onesLike = function () {
- this.throwIfDisposed();
- return onesLike(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.pad = function (paddings, constantValue) {
- this.throwIfDisposed();
- return pad(this, paddings, constantValue);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides) {
- this.throwIfDisposed();
- return pool(this, windowShape, poolingType, padding, dilationRate, strides);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.powStrict = function (exp) {
- this.throwIfDisposed();
- return powStrict(this, exp);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.pow = function (exp) {
- this.throwIfDisposed();
- return pow(this, exp);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.prelu = function (alpha) {
- this.throwIfDisposed();
- return prelu(this, alpha);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.prod = function (axis, keepDims) {
- this.throwIfDisposed();
- return prod(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.reciprocal = function () {
- this.throwIfDisposed();
- return reciprocal(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.relu = function () {
- this.throwIfDisposed();
- return relu(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.relu6 = function () {
- this.throwIfDisposed();
- return relu6(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reshapes the tensor into the shape of the provided tensor.
- *
- * @param x The tensor of required shape.
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.reshapeAs = function (x) {
- this.throwIfDisposed();
- return reshape(this, x.shape);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.reshape = function (shape) {
- this.throwIfDisposed();
- return reshape(this, shape);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.resizeBilinear = function (newShape2D, alignCorners) {
- this.throwIfDisposed();
- return resizeBilinear(this, newShape2D, alignCorners);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.resizeNearestNeighbor = function (newShape2D, alignCorners) {
- this.throwIfDisposed();
- return resizeNearestNeighbor(this, newShape2D, alignCorners);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.reverse = function (axis) {
- this.throwIfDisposed();
- return reverse(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.rfft = function () {
- this.throwIfDisposed();
- return rfft(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.round = function () {
- this.throwIfDisposed();
- return round(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.rsqrt = function () {
- this.throwIfDisposed();
- return rsqrt(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.selu = function () {
- this.throwIfDisposed();
- return selu(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
- this.throwIfDisposed();
- return separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sigmoid = function () {
- this.throwIfDisposed();
- return sigmoid(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sign = function () {
- this.throwIfDisposed();
- return sign(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sin = function () {
- this.throwIfDisposed();
- return sin(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sinh = function () {
- this.throwIfDisposed();
- return sinh(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.slice = function (begin, size) {
- this.throwIfDisposed();
- return slice(this, begin, size);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.softmax = function (dim) {
- this.throwIfDisposed();
- return softmax(this, dim);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.softplus = function () {
- this.throwIfDisposed();
- return softplus(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.spaceToBatchND = function (blockShape, paddings) {
- this.throwIfDisposed();
- return spaceToBatchND(this, blockShape, paddings);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.split = function (numOrSizeSplits, axis) {
- this.throwIfDisposed();
- return split(this, numOrSizeSplits, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sqrt = function () {
- this.throwIfDisposed();
- return sqrt(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.square = function () {
- this.throwIfDisposed();
- return square(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.squaredDifference = function (b) {
- this.throwIfDisposed();
- return squaredDifference(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.squaredDifferenceStrict = function (x) {
- this.throwIfDisposed();
- return squaredDifferenceStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.squeeze = function (axis) {
- this.throwIfDisposed();
- return squeeze(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.stack = function (x, axis) {
- this.throwIfDisposed();
- const tensorsToBeStacked = x instanceof Tensor ? [this, x] : [this, ...x];
- return stack(tensorsToBeStacked, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.step = function (alpha) {
- this.throwIfDisposed();
- return step(this, alpha);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
- this.throwIfDisposed();
- return stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * @deprecated strict variants of ops have been deprecated
- */
- Tensor.prototype.subStrict = function (x) {
- this.throwIfDisposed();
- return subStrict(this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sub = function (b) {
- this.throwIfDisposed();
- return sub(this, b);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.sum = function (axis, keepDims) {
- this.throwIfDisposed();
- return sum$1(this, axis, keepDims);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.tan = function () {
- this.throwIfDisposed();
- return tan(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.tanh = function () {
- this.throwIfDisposed();
- return tanh$1(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.tile = function (reps) {
- this.throwIfDisposed();
- return tile(this, reps);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Casts the array to type `bool`
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.toBool = function () {
- this.throwIfDisposed();
- return cast(this, 'bool');
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Casts the array to type `float32`
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.toFloat = function () {
- this.throwIfDisposed();
- return cast(this, 'float32');
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Casts the array to type `int32`
- *
- * @doc {heading: 'Tensors', subheading: 'Classes'}
- */
- Tensor.prototype.toInt = function () {
- this.throwIfDisposed();
- return cast(this, 'int32');
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.topk = function (k, sorted) {
- this.throwIfDisposed();
- return topk(this, k, sorted);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.transpose = function (perm) {
- this.throwIfDisposed();
- return transpose(this, perm);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.unique = function (axis) {
- this.throwIfDisposed();
- return unique(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.unsortedSegmentSum = function (segmentIds, numSegments) {
- this.throwIfDisposed();
- return unsortedSegmentSum(this, segmentIds, numSegments);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.unstack = function (axis) {
- this.throwIfDisposed();
- return unstack(this, axis);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.where = function (condition, x) {
- this.throwIfDisposed();
- return where(condition, this, x);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- Tensor.prototype.zerosLike = function () {
- this.throwIfDisposed();
- return zerosLike(this);
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- let _epsilon;
- /**
- * Returns the value of the fuzz factor used in numeric expressions.
- */
- function epsilon() {
- if (_epsilon == null) {
- _epsilon = backend().epsilon();
- }
- return _epsilon;
- }
- /**
- * Sets the value of the fuzz factor used in numeric expressions.
- * @param e New value of epsilon.
- */
- function setEpsilon(e) {
- _epsilon = e;
- }
- /**
- * Returns the default image data format convention.
- */
- function imageDataFormat() {
- return 'channelsLast';
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Explicit error types.
- *
- * See the following link for more information about why the code includes
- * calls to setPrototypeOf:
- *
- * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
- */
- // tslint:enable
- /**
- * Equivalent of Python's AttributeError.
- */
- class AttributeError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, AttributeError.prototype);
- }
- }
- /**
- * Equivalent of Python's RuntimeError.
- */
- class RuntimeError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, RuntimeError.prototype);
- }
- }
- /**
- * Equivalent of Python's ValueError.
- */
- class ValueError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, ValueError.prototype);
- }
- }
- /**
- * Equivalent of Python's NotImplementedError.
- */
- class NotImplementedError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, NotImplementedError.prototype);
- }
- }
- /**
- * Equivalent of Python's AssertionError.
- */
- class AssertionError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, AssertionError.prototype);
- }
- }
- /**
- * Equivalent of Python's IndexError.
- */
- class IndexError extends Error {
- constructor(message) {
- super(message);
- // Set the prototype explicitly.
- Object.setPrototypeOf(this, IndexError.prototype);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // tslint:enable
- /**
- * If `value` is an Array, equivalent to Python's `value * numValues`.
- * If `value` is not an Array, equivalent to Python's `[value] * numValues`
- */
- // tslint:disable-next-line:no-any
- function pyListRepeat(value, numValues) {
- if (Array.isArray(value)) {
- // tslint:disable-next-line:no-any
- let newArray = [];
- for (let i = 0; i < numValues; i++) {
- newArray = newArray.concat(value);
- }
- return newArray;
- }
- else {
- const newArray = new Array(numValues);
- newArray.fill(value);
- return newArray;
- }
- }
- function assert$1(val, message) {
- if (!val) {
- throw new AssertionError(message);
- }
- }
- /**
- * Count the number of elements of the `array` that are equal to `reference`.
- */
- function count(array, refernce) {
- let counter = 0;
- for (const item of array) {
- if (item === refernce) {
- counter++;
- }
- }
- return counter;
- }
- /**
- * If an array is of length 1, just return the first element. Otherwise, return
- * the full array.
- * @param tensors
- */
- function singletonOrArray(xs) {
- if (xs.length === 1) {
- return xs[0];
- }
- return xs;
- }
- /**
- * Normalizes a list/tensor into a list.
- *
- * If a tensor is passed, we return
- * a list of size 1 containing the tensor.
- *
- * @param x target object to be normalized.
- */
- // tslint:disable-next-line:no-any
- function toList(x) {
- if (Array.isArray(x)) {
- return x;
- }
- return [x];
- }
- /**
- * Generate a UID for a list
- */
- // tslint:disable-next-line:no-any
- function objectListUid(objs) {
- const objectList = toList(objs);
- let retVal = '';
- for (const obj of objectList) {
- if (obj.id == null) {
- throw new ValueError(`Object ${obj} passed to objectListUid without an id`);
- }
- if (retVal !== '') {
- retVal = retVal + ', ';
- }
- retVal = `${retVal}${Math.abs(obj.id)}`;
- }
- return retVal;
- }
- /**
- * Converts string to snake-case.
- * @param name
- */
- function toSnakeCase(name) {
- const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
- const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
- /*
- If the class is private the name starts with "_" which is not secure
- for creating scopes. We prefix the name with "private" in this case.
- */
- if (insecure[0] !== '_') {
- return insecure;
- }
- return 'private' + insecure;
- }
- function toCamelCase(identifier) {
- // quick return for empty string or single character strings
- if (identifier.length <= 1) {
- return identifier;
- }
- // Check for the underscore indicating snake_case
- if (identifier.indexOf('_') === -1) {
- return identifier;
- }
- return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
- }
- // tslint:disable-next-line:no-any
- let _GLOBAL_CUSTOM_OBJECTS = {};
- function serializeKerasObject(instance) {
- if (instance === null || instance === undefined) {
- return null;
- }
- const dict = {};
- dict['className'] = instance.getClassName();
- dict['config'] = instance.getConfig();
- return dict;
- }
- /**
- * Replace ndarray-style scalar objects in serialization objects with numbers.
- *
- * Background: In some versions of tf.keras, certain scalar values in the HDF5
- * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
- * where in `num` is a plain number. This method converts such serialization
- * to a `number`.
- *
- * @param config The keras-format serialization object to be processed
- * (in place).
- */
- function convertNDArrayScalarsInConfig(config) {
- if (config == null || typeof config !== 'object') {
- return;
- }
- else if (Array.isArray(config)) {
- config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));
- }
- else {
- const fields = Object.keys(config);
- for (const field of fields) {
- const value = config[field];
- if (value != null && typeof value === 'object') {
- if (!Array.isArray(value) && value['type'] === 'ndarray' &&
- typeof value['value'] === 'number') {
- config[field] = value['value'];
- }
- else {
- convertNDArrayScalarsInConfig(value);
- }
- }
- }
- }
- }
- /**
- * Deserialize a saved Keras Object
- * @param identifier either a string ID or a saved Keras dictionary
- * @param moduleObjects a list of Python class names to object constructors
- * @param customObjects a list of Python class names to object constructors
- * @param printableModuleName debug text for the object being reconstituted
- * @param fastWeightInit Optional flag to use fast weight initialization
- * during deserialization. This is applicable to cases in which
- * the initialization will be immediately overwritten by loaded weight
- * values. Default: `false`.
- * @returns a TensorFlow.js Layers object
- */
- // tslint:disable:no-any
- function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
- // tslint:enable
- if (typeof identifier === 'string') {
- const functionName = identifier;
- let fn;
- if (functionName in customObjects) {
- fn = customObjects[functionName];
- }
- else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
- fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
- }
- else {
- fn = moduleObjects[functionName];
- if (fn == null) {
- throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
- `This may be due to one of the following reasons:\n` +
- `1. The ${printableModuleName} is defined in Python, in which ` +
- `case it needs to be ported to TensorFlow.js or your JavaScript ` +
- `code.\n` +
- `2. The custom ${printableModuleName} is defined in JavaScript, ` +
- `but is not registered properly with ` +
- `tf.serialization.registerClass().`);
- // TODO(cais): Add link to tutorial page on custom layers.
- }
- }
- return fn;
- }
- else {
- // In this case we are dealing with a Keras config dictionary.
- const config = identifier;
- if (config['className'] == null || config['config'] == null) {
- throw new ValueError(`${printableModuleName}: Improper config format: ` +
- `${JSON.stringify(config)}.\n` +
- `'className' and 'config' must set.`);
- }
- const className = config['className'];
- let cls, fromConfig;
- if (className in customObjects) {
- [cls, fromConfig] = customObjects[className];
- }
- else if (className in _GLOBAL_CUSTOM_OBJECTS) {
- [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS['className'];
- }
- else if (className in moduleObjects) {
- [cls, fromConfig] = moduleObjects[className];
- }
- if (cls == null) {
- throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
- `This may be due to one of the following reasons:\n` +
- `1. The ${printableModuleName} is defined in Python, in which ` +
- `case it needs to be ported to TensorFlow.js or your JavaScript ` +
- `code.\n` +
- `2. The custom ${printableModuleName} is defined in JavaScript, ` +
- `but is not registered properly with ` +
- `tf.serialization.registerClass().`);
- // TODO(cais): Add link to tutorial page on custom layers.
- }
- if (fromConfig != null) {
- // Porting notes: Instead of checking to see whether fromConfig accepts
- // customObjects, we create a customObjects dictionary and tack it on to
- // config['config'] as config['config'].customObjects. Objects can use it,
- // if they want.
- // tslint:disable-next-line:no-any
- const customObjectsCombined = {};
- for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
- customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
- }
- for (const key of Object.keys(customObjects)) {
- customObjectsCombined[key] = customObjects[key];
- }
- // Add the customObjects to config
- const nestedConfig = config['config'];
- nestedConfig['customObjects'] = customObjectsCombined;
- const backupCustomObjects = { ..._GLOBAL_CUSTOM_OBJECTS };
- for (const key of Object.keys(customObjects)) {
- _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
- }
- convertNDArrayScalarsInConfig(config['config']);
- const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
- _GLOBAL_CUSTOM_OBJECTS = { ...backupCustomObjects };
- return returnObj;
- }
- else {
- // Then `cls` may be a function returning a class.
- // In this case by convention `config` holds
- // the kwargs of the function.
- const backupCustomObjects = { ..._GLOBAL_CUSTOM_OBJECTS };
- for (const key of Object.keys(customObjects)) {
- _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
- }
- // In python this is **config['config'], for tfjs-layers we require
- // classes that use this fall-through construction method to take
- // a config interface that mimics the expansion of named parameters.
- const returnObj = new cls(config['config']);
- _GLOBAL_CUSTOM_OBJECTS = { ...backupCustomObjects };
- return returnObj;
- }
- }
- }
- /**
- * Compares two numbers for sorting.
- * @param a
- * @param b
- */
- function numberCompare(a, b) {
- return (a < b) ? -1 : ((a > b) ? 1 : 0);
- }
- /**
- * Comparison of two numbers for reverse sorting.
- * @param a
- * @param b
- */
- function reverseNumberCompare(a, b) {
- return -1 * numberCompare(a, b);
- }
- /**
- * Convert a string into the corresponding DType.
- * @param dtype
- * @returns An instance of DType.
- */
- function stringToDType(dtype) {
- switch (dtype) {
- case 'float32':
- return 'float32';
- default:
- throw new ValueError(`Invalid dtype: ${dtype}`);
- }
- }
- /**
- * Test the element-by-element equality of two Arrays of strings.
- * @param xs First array of strings.
- * @param ys Second array of strings.
- * @returns Wether the two arrays are all equal, element by element.
- */
- function stringsEqual(xs, ys) {
- if (xs == null || ys == null) {
- return xs === ys;
- }
- if (xs.length !== ys.length) {
- return false;
- }
- for (let i = 0; i < xs.length; ++i) {
- if (xs[i] !== ys[i]) {
- return false;
- }
- }
- return true;
- }
- /**
- * Get the unique elements of an array.
- * @param xs Array.
- * @returns An Array consisting of the unique elements in `xs`.
- */
- function unique$1(xs) {
- if (xs == null) {
- return xs;
- }
- const out = [];
- // TODO(cais): Maybe improve performance by sorting.
- for (const x of xs) {
- if (out.indexOf(x) === -1) {
- out.push(x);
- }
- }
- return out;
- }
- /**
- * Determine if an Object is empty (i.e., does not have own properties).
- * @param obj Object
- * @returns Whether the Object is empty.
- * @throws ValueError: If object is `null` or `undefined`.
- */
- function isObjectEmpty(obj) {
- if (obj == null) {
- throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
- }
- for (const key in obj) {
- if (obj.hasOwnProperty(key)) {
- return false;
- }
- }
- return true;
- }
- /**
- * Helper function used to build type union/enum run-time checkers.
- * @param values The list of allowed values.
- * @param label A string name for the type
- * @param value The value to test.
- * @throws ValueError: If the value is not in values nor `undefined`/`null`.
- */
- function checkStringTypeUnionValue(values, label, value) {
- if (value == null) {
- return;
- }
- if (values.indexOf(value) < 0) {
- throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
- }
- }
- /**
- * Helper function for verifying the types of inputs.
- *
- * Ensures that the elements of `x` are all of type `expectedType`.
- * Also verifies that the length of `x` is within bounds.
- *
- * @param x Object to test.
- * @param expectedType The string expected type of all of the elements in the
- * Array.
- * @param minLength Return false if x.length is less than this.
- * @param maxLength Return false if x.length is greater than this.
- * @returns true if and only if `x` is an `Array` with
- * length >= `minLength` and <= `maxLength`.
- */
- // tslint:disable:no-any
- function checkArrayTypeAndLength(x, expectedType, minLength = 0, maxLength = Infinity) {
- assert$1(minLength >= 0);
- assert$1(maxLength >= minLength);
- return (Array.isArray(x) && x.length >= minLength && x.length <= maxLength &&
- x.every(e => typeof e === expectedType));
- }
- // tslint:enable:no-any
- /**
- * Assert that a value or an array of value are positive integer.
- *
- * @param value The value being asserted on. May be a single number or an array
- * of numbers.
- * @param name Name of the value, used to make the error message.
- */
- function assertPositiveInteger(value, name) {
- if (Array.isArray(value)) {
- assert(value.length > 0, () => `${name} is unexpectedly an empty array.`);
- value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
- }
- else {
- assert(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
- `${formatAsFriendlyString(value)}.`);
- }
- }
- /**
- * Format a value into a display-friendly, human-readable fashion.
- *
- * - `null` is formatted as `'null'`
- * - Strings are formated with flanking pair of quotes.
- * - Arrays are formatted with flanking pair of square brackets.
- *
- * @param value The value to display.
- * @return Formatted string.
- */
- // tslint:disable-next-line:no-any
- function formatAsFriendlyString(value) {
- if (value === null) {
- return 'null';
- }
- else if (Array.isArray(value)) {
- return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';
- }
- else if (typeof value === 'string') {
- return `"${value}"`;
- }
- else {
- return `${value}`;
- }
- }
- /**
- * Returns a function `f2` (decorator) which wraps the original function
- * `f`. `f2` guarantees that `f` can be called at most once
- * every `waitMs` ms. If `f2` is called more often, it will return
- * the last returned result of `f`.
- *
- * @param f The original function `f` to wrap.
- * @param waitMs The time between two consecutive calls to `f` in ms.
- */
- function debounce(f, waitMs) {
- let lastTime = now();
- let lastResult;
- const f2 = (...args) => {
- const now$1 = now();
- if (now$1 - lastTime < waitMs) {
- return lastResult;
- }
- lastTime = now$1;
- lastResult = f(...args);
- return lastResult;
- };
- return f2;
- }
- /**
- * Returns the fusable activation given a layers identifier.
- *
- * @param activationName The layers identifier string.
- * @return The name of the fusable activation.
- */
- function mapActivationToFusedKernel(activationName) {
- if (activationName === 'relu') {
- return 'relu';
- }
- if (activationName === 'linear') {
- return 'linear';
- }
- if (activationName === 'elu') {
- return 'elu';
- }
- return null;
- }
- /**
- * Returns the cartesian product of sets of values.
- * This works the same as itertools.product in Python.
- *
- * Example:
- *
- * filters = [128, 256, 512]
- * paddings = ['same', 'valid']
- *
- * product = [ [128, 'same'], [128, 'valid'], [256, 'same'], [256, 'valid'],
- * [512, 'same'], [512, 'valid']]
- *
- * @param arrayOfValues List/array of values.
- * @return The cartesian product.
- */
- function getCartesianProductOfValues(...arrayOfValues) {
- assert$1(arrayOfValues.length > 0, 'arrayOfValues is empty');
- for (const values of arrayOfValues) {
- assert$1(Array.isArray(values), 'one of the values is not an array');
- assert$1(values.length > 0, 'one of the values is empty');
- }
- return arrayOfValues.reduce((products, values) => {
- if (products.length === 0) {
- return values.map(value => [value]);
- }
- return values
- .map(value => {
- return products.map((prevValue) => [...prevValue, value]);
- })
- .reduce((flattenedProduct, unflattenedProduct) => {
- return flattenedProduct.concat(unflattenedProduct);
- }, []);
- }, []);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Helper function used by many of the Constraints to find the L2Norms.
- */
- function calcL2Norms(w, axis) {
- return tidy(() => sqrt(sum$1(mul(w, w), axis, true)));
- }
- /**
- * Base class for functions that impose constraints on weight values
- *
- * @doc {
- * heading: 'Constraints',
- * subheading: 'Classes',
- * namespace: 'constraints'
- * }
- */
- class Constraint extends Serializable {
- getConfig() {
- return {};
- }
- }
- class MaxNorm extends Constraint {
- constructor(args) {
- super();
- this.defaultMaxValue = 2;
- this.defaultAxis = 0;
- this.maxValue =
- args.maxValue != null ? args.maxValue : this.defaultMaxValue;
- this.axis = args.axis != null ? args.axis : this.defaultAxis;
- }
- apply(w) {
- return tidy(() => {
- const norms = calcL2Norms(w, this.axis);
- const desired = clipByValue(norms, 0, this.maxValue);
- return mul(w, div(desired, add$1(epsilon(), norms)));
- });
- }
- getConfig() {
- return { maxValue: this.maxValue, axis: this.axis };
- }
- }
- /** @nocollapse */
- MaxNorm.className = 'MaxNorm';
- registerClass(MaxNorm);
- class UnitNorm extends Constraint {
- constructor(args) {
- super();
- this.defaultAxis = 0;
- this.axis = args.axis != null ? args.axis : this.defaultAxis;
- }
- apply(w) {
- return tidy(() => div(w, add$1(epsilon(), calcL2Norms(w, this.axis))));
- }
- getConfig() {
- return { axis: this.axis };
- }
- }
- /** @nocollapse */
- UnitNorm.className = 'UnitNorm';
- registerClass(UnitNorm);
- class NonNeg extends Constraint {
- apply(w) {
- return relu(w);
- }
- }
- /** @nocollapse */
- NonNeg.className = 'NonNeg';
- registerClass(NonNeg);
- class MinMaxNorm extends Constraint {
- constructor(args) {
- super();
- this.defaultMinValue = 0.0;
- this.defaultMaxValue = 1.0;
- this.defaultRate = 1.0;
- this.defaultAxis = 0;
- this.minValue =
- args.minValue != null ? args.minValue : this.defaultMinValue;
- this.maxValue =
- args.maxValue != null ? args.maxValue : this.defaultMaxValue;
- this.rate = args.rate != null ? args.rate : this.defaultRate;
- this.axis = args.axis != null ? args.axis : this.defaultAxis;
- }
- apply(w) {
- return tidy(() => {
- const norms = calcL2Norms(w, this.axis);
- const desired = add$1(mul(this.rate, clipByValue(norms, this.minValue, this.maxValue)), mul(1.0 - this.rate, norms));
- return mul(w, div(desired, add$1(epsilon(), norms)));
- });
- }
- getConfig() {
- return {
- minValue: this.minValue,
- maxValue: this.maxValue,
- rate: this.rate,
- axis: this.axis
- };
- }
- }
- /** @nocollapse */
- MinMaxNorm.className = 'MinMaxNorm';
- registerClass(MinMaxNorm);
- // Maps the JavaScript-like identifier keys to the corresponding registry
- // symbols.
- const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
- 'maxNorm': 'MaxNorm',
- 'minMaxNorm': 'MinMaxNorm',
- 'nonNeg': 'NonNeg',
- 'unitNorm': 'UnitNorm'
- };
- function serializeConstraint(constraint) {
- return serializeKerasObject(constraint);
- }
- function deserializeConstraint(config, customObjects = {}) {
- return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
- }
- function getConstraint(identifier) {
- if (identifier == null) {
- return null;
- }
- if (typeof identifier === 'string') {
- const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
- CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
- identifier;
- const config = { className, config: {} };
- return deserializeConstraint(config);
- }
- else if (identifier instanceof Constraint) {
- return identifier;
- }
- else {
- return deserializeConstraint(identifier);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * MaxNorm weight constraint.
- *
- * Constrains the weights incident to each hidden unit
- * to have a norm less than or equal to a desired value.
- *
- * References
- * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- * Srivastava, Hinton, et al.
- * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
- *
- * @doc {heading: 'Constraints',namespace: 'constraints'}
- */
- function maxNorm(args) {
- return new MaxNorm(args);
- }
- /**
- * Constrains the weights incident to each hidden unit to have unit norm.
- *
- * @doc {heading: 'Constraints', namespace: 'constraints'}
- */
- function unitNorm(args) {
- return new UnitNorm(args);
- }
- /**
- * Constains the weight to be non-negative.
- *
- * @doc {heading: 'Constraints', namespace: 'constraints'}
- */
- function nonNeg() {
- return new NonNeg();
- }
- /** @doc {heading: 'Constraints', namespace: 'constraints'} */
- function minMaxNorm(config) {
- return new MinMaxNorm(config);
- }
-
- var exports_constraints = /*#__PURE__*/Object.freeze({
- __proto__: null,
- maxNorm: maxNorm,
- unitNorm: unitNorm,
- nonNeg: nonNeg,
- minMaxNorm: minMaxNorm
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
- const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
- const VALID_POOL_MODE_VALUES = ['max', 'avg'];
- const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
- const VALID_SAMPLE_WEIGHT_MODES = ['temporal'];
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // A map from the requested scoped name of a Tensor to the number of Tensors
- // wanting that name so far. This allows enforcing name uniqueness by appending
- // an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
- const nameMap = new Map();
- function checkDataFormat(value) {
- checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
- }
- function checkPaddingMode(value) {
- checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
- }
- function checkPoolMode(value) {
- checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
- }
- const _nameScopeStack = [];
- const _nameScopeDivider = '/';
- /**
- * Enter namescope, which can be nested.
- */
- function nameScope(name, fn) {
- _nameScopeStack.push(name);
- try {
- const val = fn();
- _nameScopeStack.pop();
- return val;
- }
- catch (e) {
- _nameScopeStack.pop();
- throw e;
- }
- }
- /**
- * Get the current namescope as a flat, concatenated string.
- */
- function currentNameScopePrefix() {
- if (_nameScopeStack.length === 0) {
- return '';
- }
- else {
- return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
- }
- }
- /**
- * Get the name a Tensor (or Variable) would have if not uniqueified.
- * @param tensorName
- * @return Scoped name string.
- */
- function getScopedTensorName(tensorName) {
- if (!isValidTensorName(tensorName)) {
- throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
- }
- return currentNameScopePrefix() + tensorName;
- }
- /**
- * Get unique names for Tensors and Variables.
- * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
- * `getScopedTensorName()`.
- * @return A unique version of the given fully scoped name.
- * If this is the first time that the scoped name is seen in this session,
- * then the given `scopedName` is returned unaltered. If the same name is
- * seen again (producing a collision), an incrementing suffix is added to the
- * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
- */
- function getUniqueTensorName(scopedName) {
- if (!isValidTensorName(scopedName)) {
- throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
- }
- if (!nameMap.has(scopedName)) {
- nameMap.set(scopedName, 0);
- }
- const index = nameMap.get(scopedName);
- nameMap.set(scopedName, nameMap.get(scopedName) + 1);
- if (index > 0) {
- const result = `${scopedName}_${index}`;
- // Mark the composed name as used in case someone wants
- // to call getUniqueTensorName("name_1").
- nameMap.set(result, 1);
- return result;
- }
- else {
- return scopedName;
- }
- }
- const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
- /**
- * Determine whether a string is a valid tensor name.
- * @param name
- * @returns A Boolean indicating whether `name` is a valid tensor name.
- */
- function isValidTensorName(name) {
- return !!name.match(tensorNameRegex);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Determine if a number is an integer.
- */
- function isInteger(x) {
- return x === parseInt(x.toString(), 10);
- }
- /**
- * Calculate the product of an array of numbers.
- * @param array The array to calculate the product over.
- * @param begin Beginning index, inclusive.
- * @param end Ending index, exclusive.
- * @return The product.
- */
- function arrayProd(array, begin, end) {
- if (begin == null) {
- begin = 0;
- }
- if (end == null) {
- end = array.length;
- }
- let prod = 1;
- for (let i = begin; i < end; ++i) {
- prod *= array[i];
- }
- return prod;
- }
- /**
- * A helper function transforms the two input types to an instance of Tensor1D,
- * so the return value can be fed directly into various TF.js Core functions.
- * @param array
- */
- function toArray1D(array) {
- array = Array.isArray(array) ? new Float32Array(array) : array;
- return tensor1d(array);
- }
- /**
- * Compute minimum value.
- * @param array
- * @return minimum value.
- */
- function min$1(array) {
- return min(toArray1D(array)).dataSync()[0];
- }
- /**
- * Compute maximum value.
- * @param array
- * @return maximum value
- */
- function max$1(array) {
- return max(toArray1D(array)).dataSync()[0];
- }
- /**
- * Compute sum of array.
- * @param array
- * @return The sum.
- */
- function sum$2(array) {
- return sum$1(toArray1D(array)).dataSync()[0];
- }
- /**
- * Compute mean of array.
- * @param array
- * @return The mean.
- */
- function mean$2(array) {
- return sum$2(array) / array.length;
- }
- /**
- * Compute variance of array.
- * @param array
- * @return The variance.
- */
- function variance(array) {
- const demeaned = sub(toArray1D(array), scalar(mean$2(array)));
- const sumSquare = sum$1(mul(demeaned, demeaned)).dataSync()[0];
- return sumSquare / array.length;
- }
- /**
- * Compute median of array.
- * @param array
- * @return The median value.
- */
- function median(array) {
- const arraySorted = array.slice().sort((a, b) => a - b);
- const lowIdx = Math.floor((arraySorted.length - 1) / 2);
- const highIdx = Math.ceil((arraySorted.length - 1) / 2);
- if (lowIdx === highIdx) {
- return arraySorted[lowIdx];
- }
- return (arraySorted[lowIdx] + arraySorted[highIdx]) / 2;
- }
- /**
- * Generate an array of integers in [begin, end).
- * @param begin Beginning integer, inclusive.
- * @param end Ending integer, exclusive.
- * @returns Range array.
- * @throws ValueError, iff `end` < `begin`.
- */
- function range$1(begin, end) {
- if (end < begin) {
- throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
- }
- const out = [];
- for (let i = begin; i < end; ++i) {
- out.push(i);
- }
- return out;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // tslint:enable
- /* Setting and getting backend from deeplearn.js. */
- // Default deeplearn.js backend is WebGL (GPU).
- let backend$1 = 'webgl';
- function setBackend$1(requestedBackend) {
- setBackend(requestedBackend);
- backend$1 = requestedBackend;
- }
- function getBackend$1() {
- return backend$1;
- }
- /**
- * Indicates whether the backend is operating symbolically.
- *
- * This function will be used to determine how to interpret user code. If
- * it returns true, calls to the backend construct a symbolic graph; if
- * it returns false, calls to the backend execute immediately.
- */
- function isBackendSymbolic() {
- return false;
- }
- /**
- * Get the number of elements in a Tensor.
- * @param x The Tensor.
- * @return Number of elements in `x`.
- */
- function countParams(x) {
- const shape = x.shape;
- if (shape.length > 0) {
- return shape.reduce((a, b) => a * b);
- }
- else {
- // Scalar.
- return 1;
- }
- }
- /**
- * Casts a tensor to a different dtype and returns it.
- * @param x Input tensor.
- * @param dtype String: 'float32'|'int32'|'bool'.
- * @returns Tensor of the specified `dtype`.
- */
- function cast$1(x, dtype) {
- return x.asType(dtype);
- }
- /**
- * Adds a 1-sized dimension at index "axis".
- * @param x Input tensor.
- * @param axis Position where to add the new axis.
- * @returns Result of the dimension expansion.
- */
- function expandDims$1(x, axis = -1) {
- const outShape = x.shape.slice();
- if (axis < 0) {
- axis = outShape.length + axis + 1;
- }
- outShape.splice(axis, 0, 1);
- return x.reshape(outShape);
- }
- /**
- * Repeats a 2D tensor.
- *
- * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
- * will have shape `[samples, 2, dim]`.
- *
- * @param x Input tensor.
- * @param n Integer, number of times to repeat.
- * @returns The result of the repeat operation.
- * @throws ValueError: If input tensor is not 2D.
- */
- function repeat(x, n) {
- return tidy(() => {
- if (x.shape.length !== 2) {
- throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
- `rank-${x.shape.length} tensor.`);
- }
- const y = expandDims$1(x, 1);
- return tile$2(y, [1, n, 1]);
- });
- }
- /**
- * Flatten a Tensor into 1D.
- * @param x Input tensor.
- * @return The result of the flattening `x`.
- */
- function flatten$1(x) {
- const newShape = [arrayProd(x.shape)];
- return x.reshape(newShape);
- }
- /**
- * Turn a nD tensor into a 2D tensor with same 0th dimension.
- * In other words, it flattens each data samples of a batch.
- *
- * @param x The tensor to flatten. The rank of this tensor is required to be 2
- * or higher.
- * @return The result of the flattening.
- */
- function batchFlatten(x) {
- if (x.rank <= 1) {
- throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
- }
- const newShape = [x.shape[0], arrayProd(x.shape, 1)];
- return x.reshape(newShape);
- }
- /**
- * Do slicing along the first axis.
- * @param array input `tf.Tensor`.
- * @param start starting index, inclusive.
- * @param size size of the slice along the first axis.
- * @returns result of the slicing.
- * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
- */
- function sliceAlongFirstAxis(array, start, size) {
- return tidy(() => {
- switch (array.rank) {
- case 1:
- return slice1d(array, start, size);
- case 2:
- return slice2d(array, [start, 0], [size, array.shape[1]]);
- case 3:
- return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
- case 4:
- return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
- case 5:
- return slice(array, [start, 0, 0, 0, 0], [
- size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]
- ]);
- case 6:
- return slice(array, [start, 0, 0, 0, 0, 0], [
- size, array.shape[1], array.shape[2], array.shape[3], array.shape[4],
- array.shape[5]
- ]);
- default:
- throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
- `${array.rank}`);
- }
- });
- }
- /**
- * Do slicing along the last axis.
- * @param array input `tf.Tensor`.
- * @param start starting index, inclusive.
- * @param size size of the slice along the last axis.
- * @returns result of the slicing.
- * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
- */
- function sliceAlongLastAxis(array, start, size) {
- return tidy(() => {
- switch (array.rank) {
- case 1:
- return slice1d(array, start, size);
- case 2:
- return slice2d(array, [0, start], [array.shape[0], size]);
- case 3:
- return slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]);
- case 4:
- return slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]);
- default:
- throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
- `${array.rank}`);
- }
- });
- }
- /**
- * Do slicing along the sepcified axis.
- * @param array input `tf.Tensor`.
- * @param start starting index, inclusive.
- * @param size of the slice along the chosen axis.
- * @param choose an axis.
- * @returns result of the slicing.
- * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
- */
- function sliceAlongAxis(array, start, size, axis) {
- return tidy(() => {
- switch (array.rank) {
- case 1:
- return slice1d(array, start, size);
- case 2:
- switch (axis) {
- case 1:
- return sliceAlongFirstAxis(array, start, size);
- case 2:
- return sliceAlongLastAxis(array, start, size);
- default:
- throw new ValueError(`The axis is not within the rank of the tensor ` +
- `${axis}`);
- }
- case 3:
- switch (axis) {
- case 1:
- return sliceAlongFirstAxis(array, start, size);
- case 2:
- return slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
- case 3:
- return sliceAlongLastAxis(array, start, size);
- default:
- throw new ValueError(`The axis is not within the rank of the tensor ` +
- `${axis}`);
- }
- case 4:
- switch (axis) {
- case 1:
- return sliceAlongFirstAxis(array, start, size);
- case 2:
- return slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
- case 3:
- return slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
- case 4:
- return sliceAlongLastAxis(array, start, size);
- default:
- throw new ValueError(`The axis is not within the rank of the tensor ` +
- `${axis}`);
- }
- default:
- throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
- `${array.rank}`);
- }
- });
- }
- /**
- * Concatenates a list of tensors alongside the specified axis.
- * @param tensors `Array` of tensors to concatenate.
- * @param axis Concatenation axis.
- * @returns The result of the concatenation.
- */
- function concatenate(tensors, axis = -1) {
- let rank;
- if (axis < 0) {
- rank = tensors[0].rank;
- if (rank !== 0) {
- axis = rank;
- }
- else {
- axis = 0;
- }
- }
- if (axis === tensors[0].rank) {
- // Porting Note: This is necessary because tfc.concat() requires axis to be
- // in the interval [-rank, rank).
- axis = -1;
- }
- // Porting Note: Sparse concat is not supported yet.
- return concat(tensors, axis);
- }
- /**
- * Concatenate two arrays along the first dimension.
- * @param a The 1st `tf.Tensor` to concatenate.
- * @param b The 2nd `tf.Tensor` to concatenate.
- * @returns Result of the concatenation.
- * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
- */
- function concatAlongFirstAxis(a, b) {
- switch (a.rank) {
- case 1:
- return concat1d([a, b]);
- case 2:
- return concat2d([a, b], 0);
- case 3:
- return concat3d([a, b], 0);
- case 4:
- return concat4d([a, b], 0);
- default:
- throw new ValueError(`concatAlongFirstAxis() received an unsupported ` +
- `tensor rank: ${a.rank}`);
- }
- }
- /**
- * Creates a tensor by tiling `x` by `n`.
- * @param x A tensor.
- * @param n An Array of integers or a single integer. If an Array, the length
- * must be the same as the number of dimensions in `x`. If a single integer,
- * it will be treated as an Array of length 1.
- */
- function tile$2(x, n) {
- if (!Array.isArray(n)) {
- n = [n];
- }
- if (x.rank !== n.length) {
- throw new ValueError(`The length of input n (${n.length}) does not match ` +
- `the number of dimensions in input x (${x.rank})`);
- }
- return tile(x, n);
- }
- /* Creation of random tensors. */
- /**
- * Get a tensor with normal distribution of values.
- *
- * @param shape Shape of the tensor.
- * @param mean mean value of the normal distribution.
- * @param stddev standard deviation of the normal distribution.
- * @param dtype
- * @param seed
- * @return The normal tensor.
- */
- function randomNormal$1(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
- return randomNormal(shape, mean, stddev, dtype, seed);
- }
- /* Linear Algebra */
- /**
- * Multiply two tensors and returns the result as a tensor.
- *
- * For 2D tensors, this is equivalent to matrix multiplication (matMul).
- * For tensors of higher ranks, it follows the Theano behavior,
- * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
- *
- * For N dimensions it is a sum product over the last axis of x and the
- * second-to-last of y:
- *
- * @param a A tensor of at least rank 2.
- * @param b A tensor of at least rank 2.
- * @param activation (optional) A string identifying the activation
- * function.
- * @return Result of the dot operation.
- */
- function dot$1(a, b, activation, bias) {
- if ((a.rank < 2) || (b.rank < 2)) {
- throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
- ` but got x shape = ${a.shape} and y shape = ${b.shape}`);
- }
- if (b.rank >= 3) {
- const xLastDim = a.shape.slice(-1)[0];
- const ySecondLastDim = b.shape.slice(-2)[0];
- if (xLastDim !== ySecondLastDim) {
- throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
- ` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
- ` y shape = ${b.shape}`);
- }
- }
- // Handle basic 2D x 2D case.
- if ((a.rank === 2) && (b.rank === 2)) {
- const transposeA = false;
- const transposeB = false;
- // tfc.fused.matMul only fuses certain activation functions. Unsupported
- // activation functions are treated as 'linear' activations, which is
- // equivalent to a no-op.
- return matMul$1({
- a,
- b: b,
- transposeA,
- transposeB,
- bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
- activation
- });
- }
- else {
- // Reshape x into the analogous 2D Tensor.
- const aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
- const aLastDim = aFirstDims.pop();
- a = a.reshape([-1, aLastDim]);
- // Reshape y into the analogous 2D Tensor, and keep track of the
- // required dimensions to reproduce the output shape.
- const bShape = b.shape.slice();
- const bLastDim = bShape.pop();
- const ySecondLastDim = bShape.pop();
- const yOtherDims = [...bShape, bLastDim];
- // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
- // where r is the rank of y.
- const perm = Array.from({ length: b.rank }, (_, i) => {
- if (i === 0) {
- return b.rank - 2;
- }
- else if (i <= b.rank - 2) {
- return i - 1;
- }
- return i;
- });
- b = b.transpose(perm).reshape([ySecondLastDim, -1]);
- // Multiply x and y as 2D Tensors, and then reshape back to original.
- const outputShape = [...aFirstDims, ...yOtherDims];
- const transposeA = false;
- const transposeB = false;
- return matMul$1({
- a,
- b,
- transposeA,
- transposeB,
- bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
- activation
- })
- .reshape(outputShape);
- }
- }
- /**
- * Compute the sign Tensor of an input Tensor.
- *
- * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
- * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
- * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
- *
- * @param x Input `tf.Tensor`.
- * @return The sign `tf.Tensor`.
- */
- function sign$1(x) {
- // TODO(cais): Move to the core.
- return tidy(() => {
- const zerosLikeX = zerosLike(x);
- const onesLikeX = onesLike(x);
- return where(equal(x, zerosLikeX), zerosLikeX, where(greater(x, zerosLike(x)), onesLikeX, mul(-1, onesLikeX)));
- });
- }
- /**
- * Computes the one-hot representation of an integer tensor.
- * @param indices nD integer tensor of shape
- * `(batch_size, dim1, dim2, ... dim(n-1))`
- * @param numClasses Integer, number of classes to consider.
- * @returns (n + 1)D one hot representation of the input
- * with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
- */
- function oneHot$1(indices, numClasses) {
- return tidy(() => {
- if (indices.rank !== 1) {
- throw new Error('Only 1D one-hot tensors are supported in the ' +
- 'deeplearn backend, at present.');
- }
- indices = indices.toInt();
- return oneHot(indices, numClasses).toFloat();
- });
- }
- /* Elementary math functions. */
- /**
- * Retrieves the elements of indices `indices` in the tensor `reference`.
- * @param reference A tensor.
- * @param indices An integer tensor of indices or an `Array` of integers.
- * @param axis Axis along which to perform the gather operation.
- * @returns The result of the gathering as a tensor.
- */
- function gather$1(reference, indices, axis) {
- return tidy(() => {
- if (Array.isArray(indices)) {
- indices = tensor1d(indices, 'int32');
- }
- else {
- indices = indices.toInt();
- }
- return gather(reference, indices, axis);
- });
- }
- /**
- * Element-wise square.
- * @param x Input tensor.
- * @return element-wise x^2
- */
- function square$1(x) {
- return mul(x, x);
- }
- /**
- * Element-wise exponentiation.
- *
- * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
- * takes advatnage of the backend's (e.g., TensorFlow's) automatic
- * conversion to tensor. Here we allow `a` to be either a number or a tensor.
- *
- * @param x The base tensor.
- * @param a The exponent, tensor or number. If a number, it is rounded to the
- * nearest integer and converted to a tensor.
- * @returns A tensor of the same shape as `x`.
- */
- function pow$1(x, a) {
- return tidy(() => {
- if (typeof (a) === 'number') {
- a = scalar(Math.round(a), 'int32');
- }
- if (a.dtype !== 'int32') {
- throw new NotImplementedError(`Non-int32 dtype (${a.dtype}) is not supported by pow() yet`);
- }
- return pow(x, a);
- });
- }
- /**
- * Reshapes bias tensor according to rank of x.
- */
- function reshapeBias(xRank, bias, dataFormat) {
- const biasShape = bias.shape;
- if (bias.rank !== 1 && bias.rank !== xRank) {
- throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
- `; expected it to be 1 or ${xRank}`);
- }
- if (xRank === 5) {
- if (dataFormat === 'channelsFirst') {
- if (biasShape.length === 1) {
- return bias.reshape([1, biasShape[0], 1, 1, 1]);
- }
- else {
- return bias.reshape([1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
- }
- }
- else if (dataFormat === 'channelsLast') {
- if (biasShape.length === 1) {
- return bias.reshape([1, 1, 1, 1, biasShape[0]]);
- }
- else {
- return bias.reshape([1].concat(biasShape));
- }
- }
- }
- else if (xRank === 4) {
- if (dataFormat === 'channelsFirst') {
- if (biasShape.length === 1) {
- return bias.reshape([1, biasShape[0], 1, 1]);
- }
- else {
- return bias.reshape([1, biasShape[2], biasShape[0], biasShape[1]]);
- }
- }
- else if (dataFormat === 'channelsLast') {
- if (biasShape.length === 1) {
- return bias.reshape([1, 1, 1, biasShape[0]]);
- }
- else {
- return bias.reshape([1].concat(biasShape));
- }
- }
- }
- else if (xRank === 3) {
- if (dataFormat === 'channelsFirst') {
- if (biasShape.length === 1) {
- return bias.reshape([1, biasShape[0], 1]);
- }
- else {
- return bias.reshape([1, biasShape[1], biasShape[0]]);
- }
- }
- else if (dataFormat === 'channelsLast') {
- if (biasShape.length === 1) {
- return bias.reshape([1, 1, biasShape[0]]);
- }
- else {
- return bias.reshape([1].concat(biasShape));
- }
- }
- }
- else if (xRank < 3) {
- return bias;
- }
- throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`);
- }
- /* Neural-network operations. */
- /**
- * Add a bias to a tensor.
- *
- * @param x The tensor to add the bias to.
- * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
- * @return Result of the bias adding.
- * @throws ValueError: If the rank of `bias` is incorrect.
- */
- function biasAdd(x, bias, dataFormat) {
- return tidy(() => {
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- checkDataFormat(dataFormat);
- return x.add(reshapeBias(x.rank, bias, dataFormat));
- });
- }
- /**
- * Exponential linear unit (ELU).
- * @param x A tensor or variable to compute the activation function for.
- * @param alpha: A scalar, a scaling factor for the negative section.
- * @return Output of the ELU operation.
- */
- function elu$1(x, alpha = 1) {
- // TODO(cais): Add support for alpha values other than 1.
- if (alpha !== 1) {
- throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
- `yet.`);
- }
- return elu(x);
- }
- /**
- * Softsign of a tensor.
- *
- * Defined as x / (abs(x) + 1), element-wise.
- *
- * @param x: Input.
- * @returns Output.
- */
- function softsign(x) {
- return tidy(() => div(x, abs(x).add(1)));
- }
- /**
- * Sets entries in `x` to zero at random, while scaling the entire tensor.
- *
- * @param x input tensor.
- * @param level fraction of the entries in the tensor that will be set to 0.
- * @param noiseShape shape of randomly generated keep/drop flags, must be
- * broadcastable to the shape of `x`. Optional.
- * @param seed random seed to ensure determinism. Optional.
- * @returns Result of the dropout operation.
- */
- function dropout$1(x, level, noiseShape, seed) {
- return tidy(() => dropout(x, level, noiseShape, seed));
- }
- /**
- * Element-wise, segment-wise linear approximation of sigmoid.
- *
- * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
- * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
- *
- * @param x Input tensor.
- * @returns Output tensor.
- */
- function hardSigmoid(x) {
- return tidy(() => {
- const y = add$1(.5, mul(.2, x));
- return clipByValue(y, 0, 1);
- });
- }
- /**
- * Invoke `x` in the training phase, and `alt` otherwise.
- *
- * Porting Note: We do not create placeholder tensors for the `training`
- * boolean flag here, because there is no such thing in the TF.js imperative
- * backend.
- *
- * @param x The function to invoke iff `training` is `true`.
- * @param alt The function to invoke iff `training` is `false`.
- * @param training Boolean flag for whether training phase is active.
- * @returns The return value of `x()` if `training` is `true`, or the return
- * value of `alt()` if `training` is `false`.
- */
- function inTrainPhase(x, alt, training = false) {
- return training ? x() : alt();
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
- const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
- // We can't easily extract a string[] from the string union type, but we can
- // recapitulate the list, enforcing at compile time that the values are valid
- // and that we have the right number of them.
- /**
- * A string array of valid Initializer class names.
- *
- * This is guaranteed to match the `InitializerClassName` union type.
- */
- const initializerClassNames = [
- 'Zeros', 'Ones', 'Constant', 'RandomNormal', 'RandomUniform',
- 'TruncatedNormal', 'VarianceScaling', 'Orthogonal', 'Identity'
- ];
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- function checkFanMode(value) {
- checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
- }
- function checkDistribution(value) {
- checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
- }
- /**
- * Initializer base class.
- *
- * @doc {
- * heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
- */
- class Initializer extends Serializable {
- fromConfigUsesCustomObjects() {
- return false;
- }
- getConfig() {
- return {};
- }
- }
- class Zeros extends Initializer {
- apply(shape, dtype) {
- return zeros(shape, dtype);
- }
- }
- /** @nocollapse */
- Zeros.className = 'Zeros';
- registerClass(Zeros);
- class Ones extends Initializer {
- apply(shape, dtype) {
- return ones$1(shape, dtype);
- }
- }
- /** @nocollapse */
- Ones.className = 'Ones';
- registerClass(Ones);
- class Constant extends Initializer {
- constructor(args) {
- super();
- if (typeof args !== 'object') {
- throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
- }
- if (args.value === undefined) {
- throw new ValueError(`config must have value set but got ${args}`);
- }
- this.value = args.value;
- }
- apply(shape, dtype) {
- return tidy(() => mul(scalar(this.value), ones$1(shape, dtype)));
- }
- getConfig() {
- return {
- value: this.value,
- };
- }
- }
- /** @nocollapse */
- Constant.className = 'Constant';
- registerClass(Constant);
- class RandomUniform extends Initializer {
- constructor(args) {
- super();
- this.DEFAULT_MINVAL = -0.05;
- this.DEFAULT_MAXVAL = 0.05;
- this.minval = args.minval || this.DEFAULT_MINVAL;
- this.maxval = args.maxval || this.DEFAULT_MAXVAL;
- this.seed = args.seed;
- }
- apply(shape, dtype) {
- return randomUniform(shape, this.minval, this.maxval, dtype);
- }
- getConfig() {
- return { minval: this.minval, maxval: this.maxval, seed: this.seed };
- }
- }
- /** @nocollapse */
- RandomUniform.className = 'RandomUniform';
- registerClass(RandomUniform);
- class RandomNormal extends Initializer {
- constructor(args) {
- super();
- this.DEFAULT_MEAN = 0.;
- this.DEFAULT_STDDEV = 0.05;
- this.mean = args.mean || this.DEFAULT_MEAN;
- this.stddev = args.stddev || this.DEFAULT_STDDEV;
- this.seed = args.seed;
- }
- apply(shape, dtype) {
- dtype = dtype || 'float32';
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
- }
- return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
- }
- getConfig() {
- return { mean: this.mean, stddev: this.stddev, seed: this.seed };
- }
- }
- /** @nocollapse */
- RandomNormal.className = 'RandomNormal';
- registerClass(RandomNormal);
- class TruncatedNormal extends Initializer {
- constructor(args) {
- super();
- this.DEFAULT_MEAN = 0.;
- this.DEFAULT_STDDEV = 0.05;
- this.mean = args.mean || this.DEFAULT_MEAN;
- this.stddev = args.stddev || this.DEFAULT_STDDEV;
- this.seed = args.seed;
- }
- apply(shape, dtype) {
- dtype = dtype || 'float32';
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
- }
- return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
- }
- getConfig() {
- return { mean: this.mean, stddev: this.stddev, seed: this.seed };
- }
- }
- /** @nocollapse */
- TruncatedNormal.className = 'TruncatedNormal';
- registerClass(TruncatedNormal);
- class Identity$1 extends Initializer {
- constructor(args) {
- super();
- this.gain = args.gain != null ? args.gain : 1.0;
- }
- apply(shape, dtype) {
- return tidy(() => {
- if (shape.length !== 2 || shape[0] !== shape[1]) {
- throw new ValueError('Identity matrix initializer can only be used for' +
- ' 2D square matrices.');
- }
- else {
- return mul(this.gain, eye(shape[0]));
- }
- });
- }
- getConfig() {
- return { gain: this.gain };
- }
- }
- /** @nocollapse */
- Identity$1.className = 'Identity';
- registerClass(Identity$1);
- /**
- * Computes the number of input and output units for a weight shape.
- * @param shape Shape of weight.
- * @param dataFormat data format to use for convolution kernels.
- * Note that all kernels in Keras are standardized on the
- * CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST).
- * @return An length-2 array: fanIn, fanOut.
- */
- function computeFans(shape, dataFormat = 'channelsLast') {
- let fanIn;
- let fanOut;
- checkDataFormat(dataFormat);
- if (shape.length === 2) {
- fanIn = shape[0];
- fanOut = shape[1];
- }
- else if ([3, 4, 5].indexOf(shape.length) !== -1) {
- if (dataFormat === 'channelsFirst') {
- const receptiveFieldSize = arrayProd(shape, 2);
- fanIn = shape[1] * receptiveFieldSize;
- fanOut = shape[0] * receptiveFieldSize;
- }
- else if (dataFormat === 'channelsLast') {
- const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
- fanIn = shape[shape.length - 2] * receptiveFieldSize;
- fanOut = shape[shape.length - 1] * receptiveFieldSize;
- }
- }
- else {
- const shapeProd = arrayProd(shape);
- fanIn = Math.sqrt(shapeProd);
- fanOut = Math.sqrt(shapeProd);
- }
- return [fanIn, fanOut];
- }
- class VarianceScaling extends Initializer {
- /**
- * Constructor of VarianceScaling.
- * @throws ValueError for invalid value in scale.
- */
- constructor(args) {
- super();
- if (args.scale < 0.0) {
- throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
- }
- this.scale = args.scale == null ? 1.0 : args.scale;
- this.mode = args.mode == null ? 'fanIn' : args.mode;
- checkFanMode(this.mode);
- this.distribution =
- args.distribution == null ? 'normal' : args.distribution;
- checkDistribution(this.distribution);
- this.seed = args.seed;
- }
- apply(shape, dtype) {
- const fans = computeFans(shape);
- const fanIn = fans[0];
- const fanOut = fans[1];
- let scale = this.scale;
- if (this.mode === 'fanIn') {
- scale /= Math.max(1, fanIn);
- }
- else if (this.mode === 'fanOut') {
- scale /= Math.max(1, fanOut);
- }
- else {
- scale /= Math.max(1, (fanIn + fanOut) / 2);
- }
- if (this.distribution === 'normal') {
- const stddev = Math.sqrt(scale);
- dtype = dtype || 'float32';
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`);
- }
- return truncatedNormal(shape, 0, stddev, dtype, this.seed);
- }
- else {
- const limit = Math.sqrt(3 * scale);
- return randomUniform(shape, -limit, limit, dtype);
- }
- }
- getConfig() {
- return {
- scale: this.scale,
- mode: this.mode,
- distribution: this.distribution,
- seed: this.seed
- };
- }
- }
- /** @nocollapse */
- VarianceScaling.className = 'VarianceScaling';
- registerClass(VarianceScaling);
- class GlorotUniform extends VarianceScaling {
- /**
- * Constructor of GlorotUniform
- * @param scale
- * @param mode
- * @param distribution
- * @param seed
- */
- constructor(args) {
- super({
- scale: 1.0,
- mode: 'fanAvg',
- distribution: 'uniform',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, GlorotUniform is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- GlorotUniform.className = 'GlorotUniform';
- registerClass(GlorotUniform);
- class GlorotNormal extends VarianceScaling {
- /**
- * Constructor of GlorotNormal.
- * @param scale
- * @param mode
- * @param distribution
- * @param seed
- */
- constructor(args) {
- super({
- scale: 1.0,
- mode: 'fanAvg',
- distribution: 'normal',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, GlorotNormal is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- GlorotNormal.className = 'GlorotNormal';
- registerClass(GlorotNormal);
- class HeNormal extends VarianceScaling {
- constructor(args) {
- super({
- scale: 2.0,
- mode: 'fanIn',
- distribution: 'normal',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, HeNormal is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- HeNormal.className = 'HeNormal';
- registerClass(HeNormal);
- class HeUniform extends VarianceScaling {
- constructor(args) {
- super({
- scale: 2.0,
- mode: 'fanIn',
- distribution: 'uniform',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, HeUniform is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- HeUniform.className = 'HeUniform';
- registerClass(HeUniform);
- class LeCunNormal extends VarianceScaling {
- constructor(args) {
- super({
- scale: 1.0,
- mode: 'fanIn',
- distribution: 'normal',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, LeCunNormal is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- LeCunNormal.className = 'LeCunNormal';
- registerClass(LeCunNormal);
- class LeCunUniform extends VarianceScaling {
- constructor(args) {
- super({
- scale: 1.0,
- mode: 'fanIn',
- distribution: 'uniform',
- seed: args == null ? null : args.seed
- });
- }
- getClassName() {
- // In Python Keras, LeCunUniform is not a class, but a helper method
- // that creates a VarianceScaling object. Use 'VarianceScaling' as
- // class name to be compatible with that.
- return VarianceScaling.className;
- }
- }
- /** @nocollapse */
- LeCunUniform.className = 'LeCunNormal';
- registerClass(LeCunUniform);
- class Orthogonal extends Initializer {
- constructor(args) {
- super();
- this.DEFAULT_GAIN = 1;
- this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
- this.seed = args.seed;
- if (this.seed != null) {
- throw new NotImplementedError('Random seed is not implemented for Orthogonal Initializer yet.');
- }
- }
- apply(shape, dtype) {
- return tidy(() => {
- if (shape.length < 2) {
- throw new NotImplementedError('Shape must be at least 2D.');
- }
- if (shape[0] * shape[1] > 2000) {
- console.warn(`Orthogonal initializer is being called on a matrix with more ` +
- `than 2000 (${shape[0] * shape[1]}) elements: ` +
- `Slowness may result.`);
- }
- // TODO(cais): Add seed support.
- const normalizedShape = shape[0] > shape[1] ? [shape[1], shape[0]] : shape;
- const a = randomNormal$1(normalizedShape, 0, 1, 'float32');
- let q = linalg.gramSchmidt(a);
- if (shape[0] > shape[1]) {
- q = q.transpose();
- }
- return mul(this.gain, q);
- });
- }
- getConfig() {
- return {
- gain: this.gain,
- seed: this.seed,
- };
- }
- }
- /** @nocollapse */
- Orthogonal.className = 'Orthogonal';
- registerClass(Orthogonal);
- // Maps the JavaScript-like identifier keys to the corresponding registry
- // symbols.
- const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
- 'constant': 'Constant',
- 'glorotNormal': 'GlorotNormal',
- 'glorotUniform': 'GlorotUniform',
- 'heNormal': 'HeNormal',
- 'heUniform': 'HeUniform',
- 'identity': 'Identity',
- 'leCunNormal': 'LeCunNormal',
- 'leCunUniform': 'LeCunUniform',
- 'ones': 'Ones',
- 'orthogonal': 'Orthogonal',
- 'randomNormal': 'RandomNormal',
- 'randomUniform': 'RandomUniform',
- 'truncatedNormal': 'TruncatedNormal',
- 'varianceScaling': 'VarianceScaling',
- 'zeros': 'Zeros'
- };
- function deserializeInitializer(config, customObjects = {}) {
- return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
- }
- function serializeInitializer(initializer) {
- return serializeKerasObject(initializer);
- }
- function getInitializer(identifier) {
- if (typeof identifier === 'string') {
- const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
- INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
- identifier;
- /* We have four 'helper' classes for common initializers that
- all get serialized as 'VarianceScaling' and shouldn't go through
- the deserializeInitializer pathway. */
- if (className === 'GlorotNormal') {
- return new GlorotNormal();
- }
- else if (className === 'GlorotUniform') {
- return new GlorotUniform();
- }
- else if (className === 'HeNormal') {
- return new HeNormal();
- }
- else if (className === 'HeUniform') {
- return new HeUniform();
- }
- else if (className === 'LeCunNormal') {
- return new LeCunNormal();
- }
- else if (className === 'LeCunUniform') {
- return new LeCunUniform();
- }
- else {
- const config = {};
- config['className'] = className;
- config['config'] = {};
- return deserializeInitializer(config);
- }
- }
- else if (identifier instanceof Initializer) {
- return identifier;
- }
- else {
- return deserializeInitializer(identifier);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Initializer that generates tensors initialized to 0.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function zeros$1() {
- return new Zeros();
- }
- /**
- * Initializer that generates tensors initialized to 1.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function ones$2() {
- return new Ones();
- }
- /**
- * Initializer that generates values initialized to some constant.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function constant(args) {
- return new Constant(args);
- }
- /**
- * Initializer that generates random values initialized to a uniform
- * distribution.
- *
- * Values will be distributed uniformly between the configured minval and
- * maxval.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function randomUniform$1(args) {
- return new RandomUniform(args);
- }
- /**
- * Initializer that generates random values initialized to a normal
- * distribution.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function randomNormal$2(args) {
- return new RandomNormal(args);
- }
- /**
- * Initializer that generates random values initialized to a truncated normal.
- * distribution.
- *
- * These values are similar to values from a `RandomNormal` except that values
- * more than two standard deviations from the mean are discarded and re-drawn.
- * This is the recommended initializer for neural network weights and filters.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function truncatedNormal$1(args) {
- return new TruncatedNormal(args);
- }
- /**
- * Initializer that generates the identity matrix.
- * Only use for square 2D matrices.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function identity(args) {
- return new Identity$1(args);
- }
- /**
- * Initializer capable of adapting its scale to the shape of weights.
- * With distribution=NORMAL, samples are drawn from a truncated normal
- * distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
- * - number of input units in the weight tensor, if mode = FAN_IN.
- * - number of output units, if mode = FAN_OUT.
- * - average of the numbers of input and output units, if mode = FAN_AVG.
- * With distribution=UNIFORM,
- * samples are drawn from a uniform distribution
- * within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
- *
- * @doc {heading: 'Initializers',namespace: 'initializers'}
- */
- function varianceScaling(config) {
- return new VarianceScaling(config);
- }
- /**
- * Glorot uniform initializer, also called Xavier uniform initializer.
- * It draws samples from a uniform distribution within [-limit, limit]
- * where `limit` is `sqrt(6 / (fan_in + fan_out))`
- * where `fan_in` is the number of input units in the weight tensor
- * and `fan_out` is the number of output units in the weight tensor
- *
- * Reference:
- * Glorot & Bengio, AISTATS 2010
- * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function glorotUniform(args) {
- return new GlorotUniform(args);
- }
- /**
- * Glorot normal initializer, also called Xavier normal initializer.
- * It draws samples from a truncated normal distribution centered on 0
- * with `stddev = sqrt(2 / (fan_in + fan_out))`
- * where `fan_in` is the number of input units in the weight tensor
- * and `fan_out` is the number of output units in the weight tensor.
- *
- * Reference:
- * Glorot & Bengio, AISTATS 2010
- * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function glorotNormal(args) {
- return new GlorotNormal(args);
- }
- /**
- * He normal initializer.
- *
- * It draws samples from a truncated normal distribution centered on 0
- * with `stddev = sqrt(2 / fanIn)`
- * where `fanIn` is the number of input units in the weight tensor.
- *
- * Reference:
- * He et al., http://arxiv.org/abs/1502.01852
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function heNormal(args) {
- return new HeNormal(args);
- }
- /**
- * He uniform initializer.
- *
- * It draws samples from a uniform distribution within [-limit, limit]
- * where `limit` is `sqrt(6 / fan_in)`
- * where `fanIn` is the number of input units in the weight tensor.
- *
- * Reference:
- * He et al., http://arxiv.org/abs/1502.01852
- *
- * @doc {heading: 'Initializers',namespace: 'initializers'}
- */
- function heUniform(args) {
- return new HeUniform(args);
- }
- /**
- * LeCun normal initializer.
- *
- * It draws samples from a truncated normal distribution centered on 0
- * with `stddev = sqrt(1 / fanIn)`
- * where `fanIn` is the number of input units in the weight tensor.
- *
- * References:
- * [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- * [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function leCunNormal(args) {
- return new LeCunNormal(args);
- }
- /**
- * LeCun uniform initializer.
- *
- * It draws samples from a uniform distribution in the interval
- * `[-limit, limit]` with `limit = sqrt(3 / fanIn)`,
- * where `fanIn` is the number of input units in the weight tensor.
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function leCunUniform(args) {
- return new LeCunUniform(args);
- }
- /**
- * Initializer that generates a random orthogonal matrix.
- *
- * Reference:
- * [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120)
- *
- * @doc {heading: 'Initializers', namespace: 'initializers'}
- */
- function orthogonal(args) {
- return new Orthogonal(args);
- }
-
- var exports_initializers = /*#__PURE__*/Object.freeze({
- __proto__: null,
- zeros: zeros$1,
- ones: ones$2,
- constant: constant,
- randomUniform: randomUniform$1,
- randomNormal: randomNormal$2,
- truncatedNormal: truncatedNormal$1,
- identity: identity,
- varianceScaling: varianceScaling,
- glorotUniform: glorotUniform,
- glorotNormal: glorotNormal,
- heNormal: heNormal,
- heUniform: heUniform,
- leCunNormal: leCunNormal,
- leCunUniform: leCunUniform,
- orthogonal: orthogonal
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Utilities related to persistent state in the backend.
- */
- /**
- * An ID to track `tf.SymbolicTensor`s and derived classes.
- * Required in different places in engine/topology.ts to identify unique
- * tensors.
- */
- let _nextUniqueTensorId = 0;
- function getNextUniqueTensorId() {
- return _nextUniqueTensorId++;
- }
- const _uidPrefixes = {};
- /**
- * Provides a unique UID given a string prefix.
- *
- * @param prefix
- */
- function getUid(prefix = '') {
- if (!(prefix in _uidPrefixes)) {
- _uidPrefixes[prefix] = 0;
- }
- _uidPrefixes[prefix] += 1;
- return prefix + _uidPrefixes[prefix].toString();
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // tslint:enable
- /**
- * Determine whether the input is an Array of Shapes.
- */
- function isArrayOfShapes(x) {
- return Array.isArray(x) && Array.isArray(x[0]);
- }
- /**
- * Special case of normalizing shapes to lists.
- *
- * @param x A shape or list of shapes to normalize into a list of Shapes.
- * @return A list of Shapes.
- */
- function normalizeShapeList(x) {
- if (x.length === 0) {
- return [];
- }
- if (!Array.isArray(x[0])) {
- return [x];
- }
- return x;
- }
- /**
- * Helper function to obtain exactly one Tensor.
- * @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
- * @return A single `tf.Tensor`. If `xs` is an `Array`, return the first one.
- * @throws ValueError: If `xs` is an `Array` and its length is not 1.
- */
- function getExactlyOneTensor(xs) {
- let x;
- if (Array.isArray(xs)) {
- if (xs.length !== 1) {
- throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
- }
- x = xs[0];
- }
- else {
- x = xs;
- }
- return x;
- }
- /**
- * Helper function to obtain exactly on instance of Shape.
- *
- * @param shapes Input single `Shape` or Array of `Shape`s.
- * @returns If input is a single `Shape`, return it unchanged. If the input is
- * an `Array` containing exactly one instance of `Shape`, return the instance.
- * Otherwise, throw a `ValueError`.
- * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
- * 1.
- */
- function getExactlyOneShape(shapes) {
- if (Array.isArray(shapes) && Array.isArray(shapes[0])) {
- if (shapes.length === 1) {
- shapes = shapes;
- return shapes[0];
- }
- else {
- throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
- }
- }
- else {
- return shapes;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Count the elements in an Array of LayerVariables.
- *
- * @param weights: The LayerVariables of which the constituent numbers are to
- * be counted.
- * @returns A count of the elements in all the LayerVariables
- */
- function countParamsInWeights(weights) {
- let count = 0;
- for (const weight of weights) {
- if (weight.shape.length === 0) {
- count += 1;
- }
- else {
- count += weight.shape.reduce((a, b) => a * b);
- }
- }
- return count;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
- /**
- * A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a
- * dtype and shape, but its value is mutable. The value is itself represented
- * as a`tf.Tensor`, and can be read with the `read()` method and updated with
- * the `write()` method.
- */
- class LayerVariable {
- /**
- * Construct Variable from a `tf.Tensor`.
- *
- * If not explicitly named, the Variable will be given a name with the
- * prefix 'Variable'. Variable names are unique. In the case of name
- * collision, suffixies '_' will be added to the name.
- *
- * @param val Initial value of the Variable.
- * @param name Name of the variable. If `null` or `undefined` is provided, it
- * will default a name with the prefix 'Variable'.
- * @param constraint Optional, projection function to be applied to the
- * variable after optimize updates
- * @throws ValueError if `name` is `null` or `undefined`.
- */
- constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) {
- this.dtype = dtype == null ? 'float32' : dtype;
- this.shape = val.shape;
- this.id = getNextUniqueTensorId();
- name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
- this.originalName = getScopedTensorName(name);
- this.name = getUniqueTensorName(this.originalName);
- this.trainable_ = trainable;
- this.constraint = constraint;
- this.val = variable(val, this.trainable_, this.name, this.dtype);
- }
- /**
- * Get a snapshot of the Variable's value.
- *
- * The returned value is a snapshot of the Variable's value at the time of
- * the invocation. Future mutations in the value of the tensor will only
- * be reflected by future calls to this method.
- */
- read() {
- this.assertNotDisposed();
- return this.val;
- }
- /**
- * Update the value of the Variable.
- *
- * @param newVal: The new value to update to. Must be consistent with the
- * dtype and shape of the Variable.
- * @return This Variable.
- */
- write(newVal) {
- // TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match.
- this.assertNotDisposed();
- checkShapesMatch(this.val, newVal);
- // Skip updating if this is the exact same tensor.
- if (this.val.id !== newVal.id) {
- this.val.assign(newVal);
- if (this.constraint != null) {
- this.val.assign(this.constraint.apply(this.val));
- }
- }
- return this;
- }
- /**
- * Dispose this LayersVariable instance from memory.
- */
- dispose() {
- this.assertNotDisposed();
- this.val.dispose();
- }
- assertNotDisposed() {
- if (this.val.isDisposed) {
- throw new Error(`LayersVariable ${this.name} is already disposed.`);
- }
- }
- get trainable() {
- return this.trainable_;
- }
- set trainable(trainable) {
- this.trainable_ = trainable;
- this.val.trainable = trainable;
- }
- }
- function checkShapesMatch(x, y) {
- if (x.shape.toString() !== y.shape.toString()) {
- throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' +
- JSON.stringify(y.shape));
- }
- }
- /**
- * Create a Variable.
- * @param x The initial value of the `Variable`.
- * @param dtype optional, the type of the variable.
- * @param name optional, the name of the variable, default provided by
- * Variable.
- * @param constraint optional, a constraint to be applied after every update.
- * @return The newly instantiated `Variable`.
- */
- function variable$1(x, dtype, name, constraint) {
- return new LayerVariable(x, dtype, name, true, constraint);
- }
- /**
- * Instantiates an all-zeros Variable and returns it.
- *
- * @param shape Shape of the tensor.
- * @param dtype DType of the tensor.
- * @param name Name of the tensor.
- * @return An all-zero Variable.
- */
- function zerosVariable(shape, dtype, name) {
- // TODO(cais): Implement logic for dtype.
- return new LayerVariable(zeros(shape), dtype, name);
- }
- /**
- * Instantiates an all-zeros tensor of the same shape as another tensor.
- *
- * @param x The other tensor.
- * @param dtype DType of the tensor.
- * @param name Name of the tensor.
- * @return A newly instantiated Variable.
- */
- function zerosLike$1(x, dtype, name) {
- return new LayerVariable(zerosLike(x), dtype, name);
- }
- /**
- * Instantiates an all-ones tensor and returns it.
- *
- * @param shape Shape of the tensor.
- * @param dtype DType of the tensor.
- * @param name Name of the tensor.
- * @return An all-ones Variable.
- */
- function onesVariable(shape, dtype, name) {
- // TODO(cais): Implement logic for dtype.
- const allocated = ones$1(shape);
- return new LayerVariable(allocated, dtype, name);
- }
- /**
- * Instantiates an all-ones tensor of the same shape as another tensor.
- *
- * @param x The other tensor.
- * @param dtype DType of the tensor.
- * @param name Name of the tensor.
- * @return A newly instantiated Variable.
- */
- function onesLike$1(x, dtype, name) {
- const allocated = onesLike(x);
- return new LayerVariable(allocated, dtype, name);
- }
- /**
- * Instantiate an identity matrix and returns it, as a Variable
- *
- * @param size Number of rows/columns.
- * @param dtype Data type of returned Variable.
- * @param name Name of returned Variable.
- * @return A Variable, an identity matrix.
- */
- function eyeVariable(size, dtype, name) {
- return new LayerVariable(eye(size), dtype, name);
- }
- /**
- * Get a Variable with uniform distribution of values.
- * @param shape Shape of the tensor.
- * @param minval Lower bound of the uniform distribution.
- * @param maxval Upper bound of the uniform distribution.
- * @param dtype
- * @param seed
- * @param name Optional name.
- * @return The uniform-random Variable.
- */
- function randomUniformVariable(shape, minval, maxval, dtype, seed, name = 'randomUniform') {
- return new LayerVariable(randomUniform(shape, minval, maxval, dtype), dtype, name);
- }
- /**
- * Get a Variable with truncated-normal distribution of values.
- * @param shape Shape of the tensor.
- * @param mean mean value of the normal distribution.
- * @param stddev standard deviation of the normal distribution.
- * @param dtype
- * @param seed
- * @param name Optional name.
- * @return The truncated-normal-random Variable.
- */
- function truncatedNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'truncatedNormal') {
- // TODO(cais): Implement logic for dtype and seed once they are supported
- // by deeplearn.js.
- dtype = dtype || 'float32';
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
- }
- return new LayerVariable(truncatedNormal(shape, mean, stddev, dtype, seed), dtype, name);
- }
- /**
- * Get a Variable with normal distribution of values.
- * @param shape Shape of the tensor.
- * @param mean mean value of the normal distribution.
- * @param stddev standard deviation of the normal distribution.
- * @param dtype
- * @param seed
- * @param name Optional name.
- * @return The truncated-normal-random Variable.
- */
- function randomNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'randomNormal') {
- dtype = dtype || 'float32';
- if (dtype !== 'float32' && dtype !== 'int32') {
- throw new NotImplementedError(`randomNormalVariable does not support dType ${dtype}.`);
- }
- return new LayerVariable(randomNormal(shape, mean, stddev, dtype, seed), dtype, name);
- }
- /**
- * Update the value of a Variable.
- * @param x The Variable to be updated.
- * @param xNew The new value to update to.
- * @return The Variable updated.
- */
- function update(x, xNew) {
- return x.write(xNew);
- }
- /**
- * Update the value of a Variable by adding an increment.
- * @param x The Variable to be updated.
- * @param increment The incrment to add to `x`.
- * @return The Variable updated.
- */
- function updateAdd(x, increment) {
- return x.write(add$1(x.read(), increment));
- }
- /**
- * Update the value of a Variable by subtracting a decrement.
- * @param x The Variable to be updated.
- * @param decrement The decrement to subtract from `x`.
- * @return The Variable updated.
- */
- function updateSub(x, decrement) {
- return x.write(sub(x.read(), decrement));
- }
- /**
- * Get the values of an array of Variables.
- *
- * @param tensors An `Array` of `Variable`s to get the values of.
- * @return The values of the inputs, as an `Array` of`tf.Tensor`s.
- */
- function batchGetValue(xs) {
- return xs.map(x => x.read());
- }
- /**
- * Update the value of multiple Variables at once.
- *
- * @param variablesAndValues An `Array`, each element is of type
- * [Variable, Tensor]. The first item is the
- * `Variable` of which the value is to be updated. The second item
- * carries the new value.
- */
- function batchSetValue(variablesAndValues) {
- variablesAndValues.forEach(variableAndValue => {
- const variable = variableAndValue[0];
- variable.write(variableAndValue[1]);
- });
- }
- /**
- * Returns the gradients of `variables` w.r.t. the return value of `lossFn`.
- * @param lossFn A function which returns a Scalar to be used as the function
- * value (i.e., numerator) for differentiation.
- * @param variables List of variables to be used as the independent variables
- * (i.e., denominator) for differentiation.
- * @returns An Array of gradients tensors.
- */
- function gradients(lossFn, variables) {
- // TODO(cais): The return type signature can be simplified if deeplearn makes
- // the corresponding type public.
- const variableList = variables.map(variable => variable.read());
- const valudAndGrads = variableGrads(lossFn, variableList);
- return variables.map(variable => valudAndGrads.grads[variable.name]);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Specifies the ndim, dtype and shape of every input to a layer.
- *
- * Every layer should expose (if appropriate) an `inputSpec` attribute:
- * a list of instances of InputSpec (one per input tensor).
- *
- * A null entry in a shape is compatible with any dimension,
- * a null shape is compatible with any shape.
- */
- class InputSpec {
- constructor(args) {
- this.dtype = args.dtype;
- this.shape = args.shape;
- /*
- TODO(michaelterry): Could throw error if ndim and shape are both defined
- (then backport).
- */
- if (args.shape != null) {
- this.ndim = args.shape.length;
- }
- else {
- this.ndim = args.ndim;
- }
- this.maxNDim = args.maxNDim;
- this.minNDim = args.minNDim;
- this.axes = args.axes || {};
- }
- }
- /**
- * `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value.
- *
- * They are most often encountered when building a graph of `Layer`s for a
- * a `tf.LayersModel` and the input data's shape, but not values are known.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- class SymbolicTensor {
- /**
- *
- * @param dtype
- * @param shape
- * @param sourceLayer The Layer that produced this symbolic tensor.
- * @param inputs The inputs passed to sourceLayer's __call__() method.
- * @param nodeIndex
- * @param tensorIndex
- * @param callArgs The keyword arguments passed to the __call__() method.
- * @param name
- * @param outputTensorIndex The index of this tensor in the list of outputs
- * returned by apply().
- */
- constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
- this.dtype = dtype;
- this.shape = shape;
- this.sourceLayer = sourceLayer;
- this.inputs = inputs;
- this.callArgs = callArgs;
- this.outputTensorIndex = outputTensorIndex;
- this.id = getNextUniqueTensorId();
- if (name != null) {
- this.originalName = getScopedTensorName(name);
- this.name = getUniqueTensorName(this.originalName);
- }
- this.rank = shape.length;
- }
- }
- let _nextNodeID = 0;
- /**
- * A `Node` describes the connectivity between two layers.
- *
- * Each time a layer is connected to some new input,
- * a node is added to `layer.inboundNodes`.
- *
- * Each time the output of a layer is used by another layer,
- * a node is added to `layer.outboundNodes`.
- *
- * `nodeIndices` and `tensorIndices` are basically fine-grained coordinates
- * describing the origin of the `inputTensors`, verifying the following:
- *
- * `inputTensors[i] ==
- * inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[
- * tensorIndices[i]]`
- *
- * A node from layer A to layer B is added to:
- * A.outboundNodes
- * B.inboundNodes
- */
- class Node {
- constructor(args,
- // TODO(michaelterry): Define actual type for this.
- callArgs) {
- this.callArgs = callArgs;
- this.id = _nextNodeID++;
- /*
- Layer instance (NOT a list).
- this is the layer that takes a list of input tensors
- and turns them into a list of output tensors.
- the current node will be added to
- the inboundNodes of outboundLayer.
- */
- this.outboundLayer = args.outboundLayer;
- /*
- The following 3 properties describe where
- the input tensors come from: which layers,
- and for each layer, which node and which
- tensor output of each node.
- */
- // List of layer instances.
- this.inboundLayers = args.inboundLayers;
- // List of integers, 1:1 mapping with inboundLayers.
- this.nodeIndices = args.nodeIndices;
- // List of integers, 1:1 mapping with inboundLayers.
- this.tensorIndices = args.tensorIndices;
- /*
- Following 2 properties:
- tensor inputs and outputs of outboundLayer.
- */
- // List of tensors. 1:1 mapping with inboundLayers.
- this.inputTensors = args.inputTensors;
- // List of tensors, created by outboundLayer.call().
- this.outputTensors = args.outputTensors;
- /*
- Following 2 properties: input and output masks.
- List of tensors, 1:1 mapping with inputTensor.
- */
- this.inputMasks = args.inputMasks;
- // List of tensors, created by outboundLayer.computeMask().
- this.outputMasks = args.outputMasks;
- // Following 2 properties: input and output shapes.
- // List of shape tuples, shapes of inputTensors.
- this.inputShapes = args.inputShapes;
- // List of shape tuples, shapes of outputTensors.
- this.outputShapes = args.outputShapes;
- // Add nodes to all layers involved.
- for (const layer of args.inboundLayers) {
- if (layer != null) {
- layer.outboundNodes.push(this);
- }
- }
- args.outboundLayer.inboundNodes.push(this);
- }
- getConfig() {
- const inboundNames = [];
- for (const layer of this.inboundLayers) {
- if (layer != null) {
- inboundNames.push(layer.name);
- }
- else {
- inboundNames.push(null);
- }
- }
- return {
- outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
- inboundLayers: inboundNames,
- nodeIndices: this.nodeIndices,
- tensorIndices: this.tensorIndices
- };
- }
- }
- let _nextLayerID = 0;
- /**
- * A layer is a grouping of operations and weights that can be composed to
- * create a `tf.LayersModel`.
- *
- * Layers are constructed by using the functions under the
- * [tf.layers](#Layers-Basic) namespace.
- *
- * @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'}
- */
- class Layer extends Serializable {
- constructor(args = {}) {
- super();
- this._callHook = null;
- this._addedWeightNames = [];
- // Porting Notes: PyKeras does not have this property in this base Layer
- // class. Instead lets Layer subclass set it dynamically and checks the
- // value with `hasattr`. In tfjs-layers, we let this be a member of this
- // base class.
- this._stateful = false;
- this.id = _nextLayerID++;
- this.activityRegularizer = null;
- this.inputSpec = null;
- this.supportsMasking = false;
- // These properties will be set upon call of this.build()
- this._trainableWeights = [];
- this._nonTrainableWeights = [];
- this._losses = [];
- this._updates = [];
- this._built = false;
- /*
- These lists will be filled via successive calls
- to this.addInboundNode().
- */
- this.inboundNodes = [];
- this.outboundNodes = [];
- let name = args.name;
- if (!name) {
- const prefix = this.getClassName();
- name = toSnakeCase(prefix) + '_' + getUid(prefix);
- }
- this.name = name;
- this.trainable_ = args.trainable == null ? true : args.trainable;
- if (args.inputShape != null || args.batchInputShape != null) {
- /*
- In this case we will later create an input layer
- to insert before the current layer
- */
- let batchInputShape;
- if (args.batchInputShape != null) {
- batchInputShape = args.batchInputShape;
- }
- else if (args.inputShape != null) {
- let batchSize = null;
- if (args.batchSize != null) {
- batchSize = args.batchSize;
- }
- batchInputShape = [batchSize].concat(args.inputShape);
- }
- this.batchInputShape = batchInputShape;
- // Set dtype.
- let dtype = args.dtype;
- if (dtype == null) {
- dtype = args.inputDType;
- }
- if (dtype == null) {
- dtype = 'float32';
- }
- this.dtype = dtype;
- }
- if (args.weights != null) {
- this.initialWeights = args.weights;
- }
- else {
- this.initialWeights = null;
- }
- // The value of `_refCount` is initialized to null. When the layer is used
- // in a symbolic way for the first time, it will be set to 1.
- this._refCount = null;
- this.fastWeightInitDuringBuild = false;
- }
- /**
- * Converts a layer and its index to a unique (immutable type) name.
- * This function is used internally with `this.containerNodes`.
- * @param layer The layer.
- * @param nodeIndex The layer's position (e.g. via enumerate) in a list of
- * nodes.
- *
- * @returns The unique name.
- */
- static nodeKey(layer, nodeIndex) {
- return layer.name + '_ib-' + nodeIndex.toString();
- }
- /**
- * Returns this.inboundNode at index nodeIndex.
- *
- * Porting note: This is a replacement for _get_node_attribute_at_index()
- * @param nodeIndex
- * @param attrName The name of the attribute related to request for this node.
- */
- getNodeAtIndex(nodeIndex, attrName) {
- if (this.inboundNodes.length === 0) {
- throw new RuntimeError('The layer has never been called ' +
- `and thus has no defined ${attrName}.`);
- }
- if (this.inboundNodes.length <= nodeIndex) {
- throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` +
- `but the layer has only ${this.inboundNodes.length} inbound nodes.`);
- }
- return this.inboundNodes[nodeIndex];
- }
- /**
- * Retrieves the input tensor(s) of a layer at a given node.
- *
- * @param nodeIndex Integer, index of the node from which to retrieve the
- * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
- * was called.
- *
- * @return A tensor (or list of tensors if the layer has multiple inputs).
- */
- getInputAt(nodeIndex) {
- return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
- }
- /**
- * Retrieves the output tensor(s) of a layer at a given node.
- *
- * @param nodeIndex Integer, index of the node from which to retrieve the
- * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
- * was called.
- *
- * @return A tensor (or list of tensors if the layer has multiple outputs).
- */
- getOutputAt(nodeIndex) {
- return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
- }
- // Properties
- /**
- * Retrieves the input tensor(s) of a layer.
- *
- * Only applicable if the layer has exactly one inbound node,
- * i.e. if it is connected to one incoming layer.
- *
- * @return Input tensor or list of input tensors.
- *
- * @exception AttributeError if the layer is connected to more than one
- * incoming layers.
- */
- get input() {
- if (this.inboundNodes.length > 1) {
- throw new AttributeError(`Layer ${this.name}` +
- ' has multiple inbound nodes, ' +
- 'hence the notion of "layer input" ' +
- 'is ill-defined. ' +
- 'Use `getInputAt(nodeIndex)` instead.');
- }
- else if (this.inboundNodes.length === 0) {
- throw new AttributeError(`Layer ${this.name}` +
- ' is not connected, no input to return.');
- }
- return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
- }
- /**
- * Retrieves the output tensor(s) of a layer.
- *
- * Only applicable if the layer has exactly one inbound node,
- * i.e. if it is connected to one incoming layer.
- *
- * @return Output tensor or list of output tensors.
- *
- * @exception AttributeError if the layer is connected to more than one
- * incoming layers.
- */
- get output() {
- if (this.inboundNodes.length === 0) {
- throw new AttributeError(`Layer ${this.name}` +
- ' has no inbound nodes.');
- }
- if (this.inboundNodes.length > 1) {
- throw new AttributeError(`Layer ${this.name}` +
- ' has multiple inbound nodes, ' +
- 'hence the notion of "layer output" ' +
- 'is ill-defined. ' +
- 'Use `getOutputAt(nodeIndex)` instead.');
- }
- return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
- }
- get losses() {
- return this._losses;
- }
- /**
- * Retrieves the Layer's current loss values.
- *
- * Used for regularizers during training.
- */
- calculateLosses() {
- // Porting Node: This is an augmentation to Layer.loss in PyKeras.
- // In PyKeras, Layer.loss returns symbolic tensors. Here a concrete
- // Tensor (specifically Scalar) values are returned. This is due to the
- // imperative backend.
- return this.losses.map(lossFn => lossFn());
- }
- get updates() {
- return this._updates;
- }
- get built() {
- return this._built;
- }
- set built(built) {
- this._built = built;
- }
- get trainable() {
- return this.trainable_;
- }
- set trainable(trainable) {
- this._trainableWeights.forEach(w => w.trainable = trainable);
- this.trainable_ = trainable;
- }
- get trainableWeights() {
- if (this.trainable_) {
- return this._trainableWeights.filter(w => w.trainable);
- }
- else {
- return [];
- }
- }
- set trainableWeights(weights) {
- this._trainableWeights = weights;
- }
- get nonTrainableWeights() {
- if (this.trainable) {
- return this._trainableWeights.filter(w => !w.trainable)
- .concat(this._nonTrainableWeights);
- }
- else {
- return this._trainableWeights.concat(this._nonTrainableWeights);
- }
- }
- set nonTrainableWeights(weights) {
- this._nonTrainableWeights = weights;
- }
- /**
- * The concatenation of the lists trainableWeights and nonTrainableWeights
- * (in this order).
- */
- get weights() {
- return this.trainableWeights.concat(this.nonTrainableWeights);
- }
- get stateful() {
- return this._stateful;
- }
- /**
- * Reset the states of the layer.
- *
- * This method of the base Layer class is essentially a no-op.
- * Subclasses that are stateful (e.g., stateful RNNs) should override this
- * method.
- */
- resetStates() {
- if (!this.stateful) {
- throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' +
- 'object.');
- }
- }
- /**
- * Checks compatibility between the layer and provided inputs.
- *
- * This checks that the tensor(s) `input`
- * verify the input assumptions of the layer
- * (if any). If not, exceptions are raised.
- *
- * @param inputs Input tensor or list of input tensors.
- *
- * @exception ValueError in case of mismatch between
- * the provided inputs and the expectations of the layer.
- */
- assertInputCompatibility(inputs) {
- inputs = toList(inputs);
- if (this.inputSpec == null || this.inputSpec.length === 0) {
- return;
- }
- const inputSpec = toList(this.inputSpec);
- if (inputs.length !== inputSpec.length) {
- throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
- `but it received ${inputs.length} input tensors. ` +
- `Input received: ${inputs}`);
- }
- for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
- const x = inputs[inputIndex];
- const spec = inputSpec[inputIndex];
- if (spec == null) {
- continue;
- }
- // Check ndim.
- const ndim = x.rank;
- if (spec.ndim != null) {
- if (ndim !== spec.ndim) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` +
- `expected ndim=${spec.ndim}, found ndim=${ndim}`);
- }
- }
- if (spec.maxNDim != null) {
- if (ndim > spec.maxNDim) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
- `: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
- }
- }
- if (spec.minNDim != null) {
- if (ndim < spec.minNDim) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
- `: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
- }
- }
- // Check dtype.
- if (spec.dtype != null) {
- if (x.dtype !== spec.dtype) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` +
- `: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
- }
- }
- // Check specific shape axes.
- if (spec.axes) {
- const xShape = x.shape;
- for (const key in spec.axes) {
- const axis = Number(key);
- const value = spec.axes[key];
- // Perform Python-style slicing in case axis < 0;
- // TODO(cais): Use https://github.com/alvivi/typescript-underscore to
- // ensure type safety through Underscore calls.
- const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
- if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
- `${this.name}: expected axis ${axis} of input shape to ` +
- `have value ${value} but got shape ${xShape}.`);
- }
- }
- }
- // Check shape.
- if (spec.shape != null) {
- for (let i = 0; i < spec.shape.length; ++i) {
- const specDim = spec.shape[i];
- const dim = x.shape[i];
- if (specDim != null && dim != null) {
- if (specDim !== dim) {
- throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
- `${this.name}: expected shape=${spec.shape}, ` +
- `found shape=${x.shape}.`);
- }
- }
- }
- }
- }
- }
- /**
- * This is where the layer's logic lives.
- *
- * @param inputs Input tensor, or list/tuple of input tensors.
- * @param kwargs Additional keyword arguments.
- *
- * @return A tensor or list/tuple of tensors.
- */
- call(inputs, kwargs) {
- return inputs;
- }
- invokeCallHook(inputs, kwargs) {
- if (this._callHook != null) {
- this._callHook(inputs, kwargs);
- }
- }
- /**
- * Set call hook.
- * This is currently used for testing only.
- * @param callHook
- */
- setCallHook(callHook) {
- this._callHook = callHook;
- }
- /**
- * Clear call hook.
- * This is currently used for testing only.
- */
- clearCallHook() {
- this._callHook = null;
- }
- /**
- * Builds or executes a `Layer's logic.
- *
- * When called with `tf.Tensor`(s), execute the `Layer`s computation and
- * return Tensor(s). For example:
- *
- * ```js
- * const denseLayer = tf.layers.dense({
- * units: 1,
- * kernelInitializer: 'zeros',
- * useBias: false
- * });
- *
- * // Invoke the layer's apply() method with a `tf.Tensor` (with concrete
- * // numeric values).
- * const input = tf.ones([2, 2]);
- * const output = denseLayer.apply(input);
- *
- * // The output's value is expected to be [[0], [0]], due to the fact that
- * // the dense layer has a kernel initialized to all-zeros and does not have
- * // a bias.
- * output.print();
- * ```
- *
- * When called with `tf.SymbolicTensor`(s), this will prepare the layer for
- * future execution. This entails internal book-keeping on shapes of
- * expected Tensors, wiring layers together, and initializing weights.
- *
- * Calling `apply` with `tf.SymbolicTensor`s are typically used during the
- * building of non-`tf.Sequential` models. For example:
- *
- * ```js
- * const flattenLayer = tf.layers.flatten();
- * const denseLayer = tf.layers.dense({units: 1});
- *
- * // Use tf.layers.input() to obtain a SymbolicTensor as input to apply().
- * const input = tf.input({shape: [2, 2]});
- * const output1 = flattenLayer.apply(input);
- *
- * // output1.shape is [null, 4]. The first dimension is the undetermined
- * // batch size. The second dimension comes from flattening the [2, 2]
- * // shape.
- * console.log(JSON.stringify(output1.shape));
- *
- * // The output SymbolicTensor of the flatten layer can be used to call
- * // the apply() of the dense layer:
- * const output2 = denseLayer.apply(output1);
- *
- * // output2.shape is [null, 1]. The first dimension is the undetermined
- * // batch size. The second dimension matches the number of units of the
- * // dense layer.
- * console.log(JSON.stringify(output2.shape));
- *
- * // The input and output and be used to construct a model that consists
- * // of the flatten and dense layers.
- * const model = tf.model({inputs: input, outputs: output2});
- * ```
- *
- * @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them.
- * @param kwargs Additional keyword arguments to be passed to `call()`.
- *
- * @return Output of the layer's `call` method.
- *
- * @exception ValueError error in case the layer is missing shape information
- * for its `build` call.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- // Porting Note: This is a replacement for __call__() in Python.
- apply(inputs, kwargs) {
- kwargs = kwargs || {};
- this.assertNotDisposed();
- // Ensure inputs are all the same type.
- const inputsList = toList(inputs);
- let allAreSymbolic = true;
- for (const input of inputsList) {
- if (!(input instanceof SymbolicTensor)) {
- allAreSymbolic = false;
- break;
- }
- }
- let noneAreSymbolic = true;
- for (const input of inputsList) {
- if (input instanceof SymbolicTensor) {
- noneAreSymbolic = false;
- break;
- }
- }
- if (allAreSymbolic === noneAreSymbolic) {
- throw new ValueError('Arguments to apply() must be all ' +
- 'SymbolicTensors or all Tensors');
- }
- // TODO(michaelterry): nameScope() may not be necessary.
- return nameScope(this.name, () => {
- // Handle laying building (weight creating, input spec locking).
- if (!this.built) {
- /*
- Throw exceptions in case the input is not compatible
- with the inputSpec specified in the layer constructor.
- */
- this.assertInputCompatibility(inputs);
- // Collect input shapes to build layer.
- const inputShapes = [];
- for (const xElem of toList(inputs)) {
- inputShapes.push(xElem.shape);
- }
- this.build(singletonOrArray(inputShapes));
- this.built = true;
- // Load weights that were specified at layer instantiation.
- if (this.initialWeights) {
- this.setWeights(this.initialWeights);
- }
- if (this._refCount === null && noneAreSymbolic) {
- // The first use of this layer is a non-symbolic call, set ref count
- // to 1 so the Layer can be properly disposed if its dispose() method
- // is called.
- this._refCount = 1;
- }
- }
- /*
- Throw exceptions in case the input is not compatible
- with the inputSpec set at build time.
- */
- this.assertInputCompatibility(inputs);
- // Handle mask propagation.
- // TODO(michaelterry): Mask propagation not currently implemented.
- // Actually call the layer, collecting output(s), mask(s), and shape(s).
- if (noneAreSymbolic) {
- let output = this.call(inputs, kwargs);
- // TODO(michaelterry): Compute the outputMask
- // If the layer returns tensors from its inputs, unmodified,
- // we copy them to avoid loss of tensor metadata.
- const outputList = toList(output);
- const outputListCopy = [];
- // TODO(michaelterry): This copying may not be necessary given our eager
- // backend.
- for (let x of outputList) {
- if (inputsList.indexOf(x) !== -1) {
- x = x.clone();
- }
- outputListCopy.push(x);
- }
- output = singletonOrArray(outputListCopy);
- if (this.activityRegularizer != null) {
- throw new NotImplementedError('Layer invocation in the presence of activity ' +
- 'regularizer(s) is not supported yet.');
- }
- // TODO(michaelterry): Call addInboundNode()?
- return output;
- }
- else {
- const inputShape = collectInputShape(inputs);
- const outputShape = this.computeOutputShape(inputShape);
- let output;
- const outputDType = guessOutputDType(inputs);
- this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] :
- inputShape);
- if (outputShape != null && outputShape.length > 0 &&
- Array.isArray(outputShape[0])) {
- // We have multiple output shapes. Create multiple output tensors.
- output = outputShape
- .map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index));
- }
- else {
- output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name);
- }
- /*
- Add an inbound node to the layer, so that it keeps track
- of the call and of all new variables created during the call.
- This also updates the layer history of the output tensor(s).
- If the input tensor(s) had no previous history,
- this does nothing.
- */
- this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
- this._refCount++;
- if (this.activityRegularizer != null) {
- throw new NotImplementedError('Layer invocation in the presence of activity ' +
- 'regularizer(s) is not supported yet.');
- }
- return output;
- }
- });
- }
- /**
- * Check compatibility between input shape and this layer's batchInputShape.
- *
- * Print warning if any incompatibility is found.
- *
- * @param inputShape Input shape to be checked.
- */
- warnOnIncompatibleInputShape(inputShape) {
- if (this.batchInputShape == null) {
- return;
- }
- else if (inputShape.length !== this.batchInputShape.length) {
- console.warn(`The rank of the input tensor provided (shape: ` +
- `${JSON.stringify(inputShape)}) does not match that of the ` +
- `batchInputShape (${JSON.stringify(this.batchInputShape)}) ` +
- `of the layer ${this.name}`);
- }
- else {
- let dimMismatch = false;
- this.batchInputShape.forEach((dimension, i) => {
- if (dimension != null && inputShape[i] != null &&
- inputShape[i] !== dimension) {
- dimMismatch = true;
- }
- });
- if (dimMismatch) {
- console.warn(`The shape of the input tensor ` +
- `(${JSON.stringify(inputShape)}) does not ` +
- `match the expectation of layer ${this.name}: ` +
- `${JSON.stringify(this.batchInputShape)}`);
- }
- }
- }
- /**
- * Retrieves the output shape(s) of a layer.
- *
- * Only applicable if the layer has only one inbound node, or if all inbound
- * nodes have the same output shape.
- *
- * @returns Output shape or shapes.
- * @throws AttributeError: if the layer is connected to more than one incoming
- * nodes.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- get outputShape() {
- if (this.inboundNodes == null || this.inboundNodes.length === 0) {
- throw new AttributeError(`The layer ${this.name} has never been called and thus has no ` +
- `defined output shape.`);
- }
- const allOutputShapes = [];
- for (const node of this.inboundNodes) {
- const shapeString = JSON.stringify(node.outputShapes);
- if (allOutputShapes.indexOf(shapeString) === -1) {
- allOutputShapes.push(shapeString);
- }
- }
- if (allOutputShapes.length === 1) {
- const outputShapes = this.inboundNodes[0].outputShapes;
- if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) &&
- outputShapes.length === 1) {
- return outputShapes[0];
- }
- else {
- return outputShapes;
- }
- }
- else {
- throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` +
- `output shapes. Hence the notion of "output shape" is ill-defined ` +
- `for the layer.`);
- // TODO(cais): Implement getOutputShapeAt().
- }
- }
- /**
- * Counts the total number of numbers (e.g., float32, int32) in the
- * weights.
- *
- * @returns An integer count.
- * @throws RuntimeError: If the layer is not built yet (in which case its
- * weights are not defined yet.)
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- countParams() {
- if (!this.built) {
- throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` +
- `but the layer is not built yet. Build it first by calling ` +
- `build(batchInputShape).`);
- }
- return countParamsInWeights(this.weights);
- }
- /**
- * Creates the layer weights.
- *
- * Must be implemented on all layers that have weights.
- *
- * Called when apply() is called to construct the weights.
- *
- * @param inputShape A `Shape` or array of `Shape` (unused).
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- build(inputShape) {
- this.built = true;
- }
- /**
- * Returns the current values of the weights of the layer.
- *
- * @param trainableOnly Whether to get the values of only trainable weights.
- * @returns Weight values as an `Array` of `tf.Tensor`s.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- getWeights(trainableOnly = false) {
- return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
- }
- /**
- * Sets the weights of the layer, from Tensors.
- *
- * @param weights a list of Tensors. The number of arrays and their shape
- * must match number of the dimensions of the weights of the layer (i.e.
- * it should match the output of `getWeights`).
- *
- * @exception ValueError If the provided weights list does not match the
- * layer's specifications.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- setWeights(weights) {
- tidy(() => {
- const params = this.weights;
- if (params.length !== weights.length) {
- // TODO(cais): Restore the following and use `providedWeights`, instead
- // of `weights` in the error message, once the deeplearn.js bug is
- // fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const
- // providedWeights = JSON.stringify(weights).substr(0, 50);
- throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` +
- `with a weight list of length ${weights.length}, ` +
- `but the layer was expecting ${params.length} weights. ` +
- `Provided weights: ${weights}...`);
- }
- if (params.length === 0) {
- return;
- }
- const weightValueTuples = [];
- const paramValues = batchGetValue(params);
- for (let i = 0; i < paramValues.length; ++i) {
- const pv = paramValues[i];
- const p = params[i];
- const w = weights[i];
- if (!arraysEqual(pv.shape, w.shape)) {
- throw new ValueError(`Layer weight shape ${pv.shape} ` +
- `not compatible with provided weight shape ${w.shape}`);
- }
- weightValueTuples.push([p, w]);
- }
- batchSetValue(weightValueTuples);
- });
- }
- /**
- * Adds a weight variable to the layer.
- *
- * @param name Name of the new weight variable.
- * @param shape The shape of the weight.
- * @param dtype The dtype of the weight.
- * @param initializer An initializer instance.
- * @param regularizer A regularizer instance.
- * @param trainable Whether the weight should be trained via backprop or not
- * (assuming that the layer itself is also trainable).
- * @param constraint An optional trainable.
- * @return The created weight variable.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint) {
- // Reject duplicate weight names.
- if (this._addedWeightNames.indexOf(name) !== -1) {
- throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`);
- }
- this._addedWeightNames.push(name);
- if (dtype == null) {
- dtype = 'float32';
- }
- if (this.fastWeightInitDuringBuild) {
- initializer = getInitializer('zeros');
- }
- const initValue = initializer.apply(shape, dtype);
- const weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
- initValue.dispose();
- // Request backend not to dispose the weights of the model on scope() exit.
- if (regularizer != null) {
- this.addLoss(() => regularizer.apply(weight.read()));
- }
- if (trainable == null) {
- trainable = true;
- }
- if (trainable) {
- this._trainableWeights.push(weight);
- }
- else {
- this._nonTrainableWeights.push(weight);
- }
- return weight;
- }
- /**
- * Set the fast-weight-initialization flag.
- *
- * In cases where the initialized weight values will be immediately
- * overwritten by loaded weight values during model loading, setting
- * the flag to `true` saves unnecessary calls to potentially expensive
- * initializers and speeds up the loading process.
- *
- * @param value Target value of the flag.
- */
- setFastWeightInitDuringBuild(value) {
- this.fastWeightInitDuringBuild = value;
- }
- /**
- * Add losses to the layer.
- *
- * The loss may potentionally be conditional on some inputs tensors,
- * for instance activity losses are conditional on the layer's inputs.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- addLoss(losses) {
- if (losses == null || Array.isArray(losses) && losses.length === 0) {
- return;
- }
- // Update this.losses
- losses = toList(losses);
- if (this._losses !== undefined && this._losses !== null) {
- this.losses.push(...losses);
- }
- }
- /**
- * Computes the output shape of the layer.
- *
- * Assumes that the layer will be built to match that input shape provided.
- *
- * @param inputShape A shape (tuple of integers) or a list of shape tuples
- * (one per output tensor of the layer). Shape tuples can include null for
- * free dimensions, instead of an integer.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- computeOutputShape(inputShape) {
- return inputShape;
- }
- /**
- * Computes an output mask tensor.
- *
- * @param inputs Tensor or list of tensors.
- * @param mask Tensor or list of tensors.
- *
- * @return null or a tensor (or list of tensors, one per output tensor of the
- * layer).
- */
- computeMask(inputs, mask) {
- if (!this.supportsMasking) {
- if (mask != null) {
- if (Array.isArray(mask)) {
- mask.forEach(maskElement => {
- if (maskElement != null) {
- throw new TypeError(`Layer ${this.name} does not support masking, ` +
- 'but was passed an inputMask.');
- }
- });
- }
- else {
- throw new TypeError(`Layer ${this.name} does not support masking, ` +
- 'but was passed an inputMask.');
- }
- }
- // masking not explicitly supported: return null as mask
- return null;
- }
- // if masking is explictly supported, by default
- // carry over the input mask
- return mask;
- }
- /**
- * Internal method to create an inbound node for the layer.
- *
- * @param inputTensors List of input tensors.
- * @param outputTensors List of output tensors.
- * @param inputMasks List of input masks (a mask can be a tensor, or null).
- * @param outputMasks List of output masks (a mask can be a tensor, or null).
- * @param inputShapes List of input shape tuples.
- * @param outputShapes List of output shape tuples.
- * @param kwargs Dictionary of keyword arguments that were passed to the
- * `call` method of the layer at the call that created the node.
- */
- addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) {
- const inputTensorList = toList(inputTensors);
- outputTensors = toList(outputTensors);
- inputMasks = toList(inputMasks);
- outputMasks = toList(outputMasks);
- inputShapes = normalizeShapeList(inputShapes);
- outputShapes = normalizeShapeList(outputShapes);
- // Collect input tensor(s) coordinates.
- const inboundLayers = [];
- const nodeIndices = [];
- const tensorIndices = [];
- for (const x of inputTensorList) {
- /*
- * TODO(michaelterry): Keras adds this value to tensors; it's not
- * clear whether we'll use this or not.
- */
- inboundLayers.push(x.sourceLayer);
- nodeIndices.push(x.nodeIndex);
- tensorIndices.push(x.tensorIndex);
- }
- // Create node, add it to inbound nodes.
- // (This call has side effects.)
- // tslint:disable-next-line:no-unused-expression
- new Node({
- outboundLayer: this,
- inboundLayers,
- nodeIndices,
- tensorIndices,
- inputTensors: inputTensorList,
- outputTensors,
- inputMasks,
- outputMasks,
- inputShapes,
- outputShapes
- }, kwargs);
- // Update tensor history
- for (let i = 0; i < outputTensors.length; i++) {
- // TODO(michaelterry: _uses_learning_phase not tracked.
- outputTensors[i].sourceLayer = this;
- outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
- outputTensors[i].tensorIndex = i;
- }
- }
- /**
- * Returns the config of the layer.
- *
- * A layer config is a TS dictionary (serializable)
- * containing the configuration of a layer.
- * The same layer can be reinstantiated later
- * (without its trained weights) from this configuration.
- *
- * The config of a layer does not include connectivity
- * information, nor the layer class name. These are handled
- * by 'Container' (one layer of abstraction above).
- *
- * Porting Note: The TS dictionary follows TS naming standrds for
- * keys, and uses tfjs-layers type-safe Enums. Serialization methods
- * should use a helper function to convert to the pythonic storage
- * standard. (see serialization_utils.convertTsToPythonic)
- *
- * @returns TS dictionary of configuration.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- getConfig() {
- const config = { name: this.name, trainable: this.trainable };
- if (this.batchInputShape != null) {
- config['batchInputShape'] = this.batchInputShape;
- }
- if (this.dtype != null) {
- config['dtype'] = this.dtype;
- }
- return config;
- }
- /**
- * Dispose the weight variables that this Layer instance holds.
- *
- * @returns {number} Number of disposed variables.
- */
- disposeWeights() {
- this.weights.forEach(weight => weight.dispose());
- return this.weights.length;
- }
- assertNotDisposed() {
- if (this._refCount === 0) {
- throw new Error(`Layer '${this.name}' is already disposed.`);
- }
- }
- /**
- * Attempt to dispose layer's weights.
- *
- * This method decrease the reference count of the Layer object by 1.
- *
- * A Layer is reference-counted. Its reference count is incremented by 1
- * the first item its `apply()` method is called and when it becomes a part
- * of a new `Node` (through calling the `apply()`) method on a
- * `tf.SymbolicTensor`).
- *
- * If the reference count of a Layer becomes 0, all the weights will be
- * disposed and the underlying memory (e.g., the textures allocated in WebGL)
- * will be freed.
- *
- * Note: If the reference count is greater than 0 after the decrement, the
- * weights of the Layer will *not* be disposed.
- *
- * After a Layer is disposed, it cannot be used in calls such as `apply()`,
- * `getWeights()` or `setWeights()` anymore.
- *
- * @returns A DisposeResult Object with the following fields:
- * - refCountAfterDispose: The reference count of the Container after this
- * `dispose()` call.
- * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
- * during this `dispose()` call.
- * @throws {Error} If the layer is not built yet, or if the layer has already
- * been disposed.
- *
- * @doc {heading: 'Models', 'subheading': 'Classes'}
- */
- dispose() {
- if (!this.built) {
- throw new Error(`Cannot dispose Layer ${this.name} because it has not been ` +
- `built yet.`);
- }
- if (this._refCount === null) {
- throw new Error(`Cannot dispose Layer ${this.name} because it has not been used ` +
- `yet.`);
- }
- this.assertNotDisposed();
- let numDisposedVariables = 0;
- if (--this._refCount === 0) {
- numDisposedVariables = this.disposeWeights();
- }
- return { refCountAfterDispose: this._refCount, numDisposedVariables };
- }
- }
- /**
- * Collects the input shape(s) of a list of `tf.Tensor`s or
- * `tf.SymbolicTensor`s.
- *
- * TODO(michaelterry): Update PyKeras docs (backport).
- *
- * @param inputTensors List of input tensors (or single input tensor).
- *
- * @return List of shape tuples (or single tuple), one tuple per input.
- */
- function collectInputShape(inputTensors) {
- inputTensors =
- toList(inputTensors);
- const shapes = [];
- for (const x of inputTensors) {
- shapes.push(x.shape);
- }
- return singletonOrArray(shapes);
- }
- /**
- * Guesses output dtype based on inputs.
- *
- * At present, just returns 'float32' for any input.
- *
- * @param inputTensors List of input tensors (or single input tensor).
- *
- * @return The guessed DType. At present, always returns 'float32'.
- */
- function guessOutputDType(inputTensors) {
- return 'float32';
- }
- /**
- * Returns the list of input tensors necessary to compute `tensor`.
- *
- * Output will always be a list of tensors (potentially with 1 element).
- *
- * @param tensor The tensor to start from.
- * @param layer Origin layer of the tensor.
- * @param nodeIndex Origin node index of the tensor.
- *
- * @return Array of input tensors.
- */
- function getSourceInputs(tensor, layer, nodeIndex) {
- if (layer == null || (nodeIndex != null && nodeIndex > 0)) {
- layer = tensor.sourceLayer;
- nodeIndex = tensor.nodeIndex;
- }
- if (layer.inboundNodes.length === 0) {
- return [tensor];
- }
- else {
- const node = layer.inboundNodes[nodeIndex];
- if (node.inboundLayers.length === 0) {
- return node.inputTensors;
- }
- else {
- const sourceTensors = [];
- for (let i = 0; i < node.inboundLayers.length; i++) {
- const x = node.inputTensors[i];
- const layer = node.inboundLayers[i];
- const nodeIndex = node.nodeIndices[i];
- const previousSources = getSourceInputs(x, layer, nodeIndex);
- // Avoid input redundancy.
- for (const x of previousSources) {
- if (sourceTensors.indexOf(x) === -1) {
- sourceTensors.push(x);
- }
- }
- }
- return sourceTensors;
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class InputLayer extends Layer {
- constructor(args) {
- super({
- dtype: args.dtype,
- name: args.name != null ? args.name : getUid('input').toString()
- });
- // Normalize config.batchSize and config.sparse
- if (args.batchSize == null) {
- args.batchSize = null;
- }
- if (args.sparse == null) {
- args.sparse = false;
- }
- this.trainable = false;
- this.built = true;
- this.sparse = args.sparse;
- if (args.inputShape != null && args.batchInputShape != null) {
- throw new ValueError('Only provide the inputShape OR ' +
- 'batchInputShape argument to inputLayer, not both at the same time.');
- }
- let batchInputShape = args.batchInputShape;
- if (batchInputShape == null) {
- if (args.inputShape == null) {
- throw new ValueError('An InputLayer should be passed either a ' +
- '`batchInputShape` or an `inputShape`.');
- }
- else {
- batchInputShape = [args.batchSize].concat(args.inputShape);
- }
- }
- else {
- // TODO(michaelterry): Backport to PyKeras
- if (args.batchSize != null) {
- throw new ValueError('Cannot specify batchSize if batchInputShape is ' +
- 'specified when creating an InputLayer.');
- }
- }
- const dtype = args.dtype || 'float32';
- this.batchInputShape = batchInputShape;
- this.dtype = dtype;
- // TODO(michaelterry): Backport this to PyKeras?
- this.inputSpec = [{ shape: batchInputShape }];
- const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name);
- inputTensor.nodeIndex = 0;
- inputTensor.tensorIndex = 0;
- // Create an input node to add to this.outboundNode.
- // (This call has side effects.)
- // tslint:disable-next-line:no-unused-expression
- new Node({
- outboundLayer: this,
- inboundLayers: [],
- nodeIndices: [],
- tensorIndices: [],
- inputTensors: [inputTensor],
- outputTensors: [inputTensor],
- inputMasks: [null],
- outputMasks: [null],
- inputShapes: [batchInputShape],
- outputShapes: [batchInputShape]
- });
- }
- apply(inputs, kwargs) {
- throw new ValueError('Cannot pass any input to an ' +
- `InputLayer's apply() method. InputLayer name: ${this.name}`);
- }
- dispose() {
- // dispose() for InputLayer is overridden as no-op.
- return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 };
- }
- getConfig() {
- return {
- batchInputShape: this.batchInputShape,
- dtype: this.dtype,
- sparse: this.sparse,
- name: this.name
- };
- }
- }
- /** @nocollapse */
- InputLayer.className = 'InputLayer';
- registerClass(InputLayer);
- function Input(config) {
- if (config.batchShape == null && config.shape == null) {
- throw new Error('Please provide to Input either a `shape`' +
- ' or a `batchShape` argument. Note that ' +
- '`shape` does not include the batch ' +
- 'dimension.');
- }
- if (config.batchShape != null && config.shape != null) {
- // TODO(michaelterry): Backport to PyKeras.
- throw new ValueError('Please provide either a `shape` or `batchShape` ' +
- 'argument to Input, but not both.');
- }
- let batchShape = config.batchShape;
- if (config.shape != null && batchShape == null) {
- batchShape = [null].concat(config.shape);
- }
- let dtype = config.dtype;
- if (dtype == null) {
- dtype = 'float32';
- }
- const inputLayer = new InputLayer({
- batchInputShape: batchShape,
- name: config.name,
- dtype,
- sparse: config.sparse
- });
- const outputs = inputLayer.inboundNodes[0].outputTensors;
- return outputs[0];
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Turn any Scalar values in a Logs object into actual number values.
- *
- * @param logs The `Logs` object to be resolved in place.
- */
- async function resolveScalarsInLogs(logs) {
- if (logs == null) {
- return;
- }
- const promises = [];
- const keys = [];
- const scalarsToDispose = [];
- for (const key in logs) {
- const value = logs[key];
- if (typeof value !== 'number') {
- const valueScalar = value;
- promises.push(valueScalar.data());
- keys.push(key);
- scalarsToDispose.push(valueScalar);
- }
- }
- if (promises.length > 0) {
- const values = await Promise.all(promises);
- for (let i = 0; i < values.length; ++i) {
- logs[keys[i]] = values[i][0];
- }
- // Dispose the original scalar tensors.
- dispose(scalarsToDispose);
- }
- }
- /**
- * Dispose all Tensors in an UnresolvedLogs object.
- *
- * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in
- * places where the values can be `tf.Tensor` or `number`.
- */
- function disposeTensorsInLogs(logs) {
- if (logs == null) {
- return;
- }
- for (const key in logs) {
- const value = logs[key];
- if (typeof value !== 'number') {
- value.dispose();
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /** Verbosity logging level when fitting a model. */
- var ModelLoggingVerbosity;
- (function (ModelLoggingVerbosity) {
- ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
- ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
- })(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
- /** How often to yield to the main thread when training (in ms). */
- const DEFAULT_YIELD_EVERY_MS = 125;
- /**
- * Abstract base class used to build new callbacks.
- *
- * The `logs` dictionary that callback methods take as argument will contain
- * keys for quantities relevant to the current batch or epoch.
- *
- * Currently, the `.fit()` method of the `Sequential` model class
- * will include the following quantities in the `logs` that
- * it passes to its callbacks:
- *
- * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
- * (if validation is enabled in `fit`), and `valAcc` (if validation and
- * accuracy monitoring are enabled).
- * onBatchBegin: Logs include `size`, the number of samples in the current
- * batch.
- * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
- * is enabled).
- */
- class BaseCallback {
- constructor() {
- // TODO(michaelterry): This type is a best guess.
- this.validationData = null;
- }
- setParams(params) {
- this.params = params;
- }
- async onEpochBegin(epoch, logs) { }
- async onEpochEnd(epoch, logs) { }
- async onBatchBegin(batch, logs) { }
- async onBatchEnd(batch, logs) { }
- async onTrainBegin(logs) { }
- async onTrainEnd(logs) { }
- // LayersModel needs to call Callback.setModel(), but cannot actually depend
- // on Callback because that creates a cyclic dependency. Providing this no-op
- // method on BaseCallback breaks the cycle: this way LayersModel can depend on
- // BaseCallback but not on Callback. The argument is typed as `Container`
- // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback
- // overrides this method and enforces that the argument is really a
- // LayersModel.
- setModel(model) {
- // Do nothing. Use Callback instead of BaseCallback to track the model.
- }
- }
- /**
- * Container abstracting a list of callbacks.
- */
- class CallbackList {
- // TODO(cais): When the need arises, uncomment the following lines and
- // implement the queue for time values.
- // private deltaTBatch: number;
- // private deltaTsBatchBegin: Array;
- // private deltaTsBatchEnd: Array;
- /**
- * Constructor of CallbackList.
- * @param callbacks Array of `Callback` instances.
- * @param queueLength Queue length for keeping running statistics over
- * callback execution time.
- */
- constructor(callbacks, queueLength = 10) {
- // TODO(cais): Make use of queueLength when implementing the queue for time
- // values.
- if (callbacks == null) {
- callbacks = [];
- }
- this.callbacks = callbacks;
- this.queueLength = queueLength;
- }
- append(callback) {
- this.callbacks.push(callback);
- }
- setParams(params) {
- for (const callback of this.callbacks) {
- callback.setParams(params);
- }
- }
- setModel(model) {
- for (const callback of this.callbacks) {
- callback.setModel(model);
- }
- }
- /**
- * Called at the start of an epoch.
- * @param epoch Index of epoch.
- * @param logs Dictionary of logs.
- */
- async onEpochBegin(epoch, logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onEpochBegin(epoch, logs);
- }
- }
- /**
- * Called at the end of an epoch.
- * @param epoch Index of epoch.
- * @param logs Dictionary of logs.
- */
- async onEpochEnd(epoch, logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onEpochEnd(epoch, logs);
- }
- }
- /**
- * Called right before processing a batch.
- * @param batch Index of batch within the current epoch.
- * @param logs Dictionary of logs.
- */
- async onBatchBegin(batch, logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onBatchBegin(batch, logs);
- }
- }
- /**
- * Called at the end of a batch.
- * @param batch Index of batch within the current epoch.
- * @param logs Dictionary of logs.
- */
- async onBatchEnd(batch, logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onBatchEnd(batch, logs);
- }
- }
- /**
- * Called at the beginning of training.
- * @param logs Dictionary of logs.
- */
- async onTrainBegin(logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onTrainBegin(logs);
- }
- }
- /**
- * Called at the end of training.
- * @param logs Dictionary of logs.
- */
- async onTrainEnd(logs) {
- if (logs == null) {
- logs = {};
- }
- for (const callback of this.callbacks) {
- await callback.onTrainEnd(logs);
- }
- }
- }
- /**
- * Callback that accumulates epoch averages of metrics.
- *
- * This callback is automatically applied to every LayersModel.
- */
- class BaseLogger extends BaseCallback {
- constructor() {
- super();
- }
- async onEpochBegin(epoch) {
- this.seen = 0;
- this.totals = {};
- }
- async onBatchEnd(batch, logs) {
- if (logs == null) {
- logs = {};
- }
- const batchSize = logs['size'] == null ? 0 : logs['size'];
- this.seen += batchSize;
- for (const key in logs) {
- const value = logs[key];
- if (typeof value === 'number') {
- if (!this.totals.hasOwnProperty(key)) {
- this.totals[key] = 0;
- }
- this.totals[key] = this.totals[key] + value * batchSize;
- }
- else {
- let oldTotalsToDispose;
- if (key in this.totals) {
- oldTotalsToDispose = this.totals[key];
- }
- else {
- this.totals[key] = 0;
- }
- const total = tidy(() => add$1((this.totals[key]), mul(value, batchSize)));
- this.totals[key] = total;
- if (oldTotalsToDispose != null) {
- oldTotalsToDispose.dispose();
- }
- }
- }
- }
- async onEpochEnd(epoch, logs) {
- if (logs != null) {
- for (const key of this.params['metrics']) {
- if (this.totals[key] == null) {
- continue;
- }
- if (typeof this.totals[key] === 'number') {
- logs[key] = this.totals[key] / this.seen;
- }
- else {
- tidy(() => {
- const log = mul(div(1, this.seen), this.totals[key]);
- logs[key] = log;
- this.totals[key].dispose();
- keep(logs[key]);
- });
- }
- }
- }
- }
- }
- /**
- * Callback that records events into a `History` object. This callback is
- * automatically applied to every TF.js Layers model. The `History` object
- * gets returned by the `fit` method of models.
- */
- class History extends BaseCallback {
- async onTrainBegin(logs) {
- this.epoch = [];
- this.history = {};
- }
- async onEpochEnd(epoch, logs) {
- if (logs == null) {
- logs = {};
- }
- this.epoch.push(epoch);
- for (const key in logs) {
- if (this.history[key] == null) {
- this.history[key] = [];
- }
- this.history[key].push(logs[key]);
- }
- }
- /**
- * Await the values of all losses and metrics.
- */
- async syncData() {
- const promises = [];
- const keys = [];
- const indices = [];
- for (const key in this.history) {
- const valueArray = this.history[key];
- for (let i = 0; i < valueArray.length; ++i) {
- if (typeof valueArray[i] !== 'number') {
- const valueScalar = valueArray[i];
- promises.push(valueScalar.data());
- keys.push(key);
- indices.push(i);
- }
- }
- }
- const values = await Promise.all(promises);
- for (let n = 0; n < values.length; ++n) {
- const tensorToDispose = this.history[keys[n]][indices[n]];
- tensorToDispose.dispose();
- this.history[keys[n]][indices[n]] = values[n][0];
- }
- }
- }
- /**
- * Custom callback for training.
- */
- class CustomCallback extends BaseCallback {
- constructor(args, yieldEvery) {
- super();
- this.currentEpoch = 0;
- this.yieldEvery = yieldEvery || 'auto';
- if (this.yieldEvery === 'auto') {
- this.yieldEvery = DEFAULT_YIELD_EVERY_MS;
- }
- if (this.yieldEvery === 'never' && args.onYield != null) {
- throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' +
- 'Either change `yieldEvery` or remove the callback');
- }
- if (isNumber(this.yieldEvery)) {
- // Decorate `maybeWait` so it will be called at most once every
- // `yieldEvery` ms.
- this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery);
- }
- this.trainBegin = args.onTrainBegin;
- this.trainEnd = args.onTrainEnd;
- this.epochBegin = args.onEpochBegin;
- this.epochEnd = args.onEpochEnd;
- this.batchBegin = args.onBatchBegin;
- this.batchEnd = args.onBatchEnd;
- this.yield = args.onYield;
- }
- async maybeWait(epoch, batch, logs) {
- const ps = [];
- if (this.yield != null) {
- await resolveScalarsInLogs(logs);
- ps.push(this.yield(epoch, batch, logs));
- }
- ps.push(nextFrame());
- await Promise.all(ps);
- }
- async onEpochBegin(epoch, logs) {
- this.currentEpoch = epoch;
- if (this.epochBegin != null) {
- await resolveScalarsInLogs(logs);
- await this.epochBegin(epoch, logs);
- }
- }
- async onEpochEnd(epoch, logs) {
- const ps = [];
- if (this.epochEnd != null) {
- await resolveScalarsInLogs(logs);
- ps.push(this.epochEnd(epoch, logs));
- }
- if (this.yieldEvery === 'epoch') {
- ps.push(nextFrame());
- }
- await Promise.all(ps);
- }
- async onBatchBegin(batch, logs) {
- if (this.batchBegin != null) {
- await resolveScalarsInLogs(logs);
- await this.batchBegin(batch, logs);
- }
- }
- async onBatchEnd(batch, logs) {
- const ps = [];
- if (this.batchEnd != null) {
- await resolveScalarsInLogs(logs);
- ps.push(this.batchEnd(batch, logs));
- }
- if (this.yieldEvery === 'batch') {
- ps.push(nextFrame());
- }
- else if (isNumber(this.yieldEvery)) {
- ps.push(this.maybeWait(this.currentEpoch, batch, logs));
- }
- await Promise.all(ps);
- }
- async onTrainBegin(logs) {
- if (this.trainBegin != null) {
- await resolveScalarsInLogs(logs);
- await this.trainBegin(logs);
- }
- }
- async onTrainEnd(logs) {
- if (this.trainEnd != null) {
- await resolveScalarsInLogs(logs);
- await this.trainEnd(logs);
- }
- }
- }
- /**
- * Standardize callbacks or configurations of them to an Array of callbacks.
- */
- function standardizeCallbacks(callbacks, yieldEvery) {
- if (callbacks == null) {
- callbacks = {};
- }
- if (callbacks instanceof BaseCallback) {
- return [callbacks];
- }
- if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
- return callbacks;
- }
- // Convert custom callback configs to custom callback objects.
- const callbackConfigs = toList(callbacks);
- return callbackConfigs.map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
- }
- /**
- * A global registry for callback constructors to be used during
- * LayersModel.fit().
- */
- class CallbackConstructorRegistry {
- /**
- * Blocks public access to constructor.
- */
- constructor() { }
- /**
- * Register a tf.LayersModel.fit() callback constructor.
- *
- * The registered callback constructor will be used to instantiate
- * callbacks for every tf.LayersModel.fit() call afterwards.
- *
- * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
- * is to be reigstered.
- * @param callbackConstructor A no-arg constructor for `tf.Callback`.
- * @throws Error, if the same callbackConstructor has been registered before,
- * either at the same or a different `verbosityLevel`.
- */
- static registerCallbackConstructor(verbosityLevel, callbackConstructor) {
- assert(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` +
- `but got ${verbosityLevel}`);
- CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
- if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {
- CallbackConstructorRegistry.constructors[verbosityLevel] = [];
- }
- CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor);
- }
- static checkForDuplicate(callbackConstructor) {
- for (const levelName in CallbackConstructorRegistry.constructors) {
- const constructors = CallbackConstructorRegistry.constructors[+levelName];
- constructors.forEach(ctor => {
- if (ctor === callbackConstructor) {
- throw new ValueError('Duplicate callback constructor.');
- }
- });
- }
- }
- /**
- * Clear all registered callback constructors.
- */
- static clear() {
- CallbackConstructorRegistry.constructors = {};
- }
- /**
- * Create callbacks using the registered callback constructors.
- *
- * Given `verbosityLevel`, all constructors registered at that level or above
- * will be called and the instantiated callbacks will be used.
- *
- * @param verbosityLevel: Level of verbosity.
- */
- static createCallbacks(verbosityLevel) {
- const constructors = [];
- for (const levelName in CallbackConstructorRegistry.constructors) {
- const level = +levelName;
- if (verbosityLevel >= level) {
- constructors.push(...CallbackConstructorRegistry.constructors[level]);
- }
- }
- return constructors.map(ctor => new ctor());
- }
- }
- CallbackConstructorRegistry.constructors = {};
- function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
- const history = new History();
- const actualCallbacks = [
- new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)
- ];
- if (callbacks != null) {
- actualCallbacks.push(...callbacks);
- }
- actualCallbacks.push(history);
- const callbackList = new CallbackList(actualCallbacks);
- // TODO(cais): Figure out when this LayersModel instance can have a
- // dynamically
- // set property called 'callback_model' as in PyKeras.
- callbackList.setParams({
- epochs,
- initialEpoch,
- samples: numTrainSamples,
- steps: stepsPerEpoch,
- batchSize,
- verbose,
- doValidation,
- metrics: callbackMetrics,
- });
- return { callbackList, history };
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Instantiate a layer from a config dictionary.
- * @param config dict of the form {class_name: str, config: dict}
- * @param customObjects dict mapping class names (or function names)
- * of custom (non-Keras) objects to class/functions
- * @param fastWeightInit Optional flag to use fast weight initialization
- * during deserialization. This is applicable to cases in which
- * the initialization will be immediately overwritten by loaded weight
- * values. Default: `false`.
- * @returns Layer instance (may be LayersModel, Sequential, Layer...)
- */
- function deserialize(config, customObjects = {}, fastWeightInit = false) {
- return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Normalizes a tensor wrt the L2 norm alongside the specified axis.
- * @param x
- * @param axis Axis along which to perform normalization.
- */
- function l2Normalize(x, axis) {
- return tidy(() => {
- if (x.dtype !== 'float32') {
- x = x.asType('float32');
- }
- const squareSum = sum$1(square$1(x), axis, true);
- const epsilonTensor = fill(squareSum.shape, epsilon());
- const norm = sqrt(maximum(squareSum, epsilonTensor));
- return div(x, norm);
- });
- }
- function meanSquaredError$1(yTrue, yPred) {
- return tidy(() => mean(square$1(sub(yPred, yTrue)), -1));
- }
- function meanAbsoluteError(yTrue, yPred) {
- return tidy(() => mean(abs(sub(yPred, yTrue)), -1));
- }
- function meanAbsolutePercentageError(yTrue, yPred) {
- return tidy(() => {
- const diff = sub(yTrue, yPred);
- const clippedTrue = clipByValue(abs(yTrue), epsilon(), Number.MAX_VALUE);
- const absResult = abs(div(diff, clippedTrue));
- return mul(100, mean(absResult, -1));
- });
- }
- function meanSquaredLogarithmicError(yTrue, yPred) {
- return tidy(() => {
- const clippedPred = clipByValue(yPred, epsilon(), Number.MAX_VALUE);
- const firstLog = log(add$1(1, clippedPred));
- const clippedTrue = clipByValue(yTrue, epsilon(), Number.MAX_VALUE);
- const secondLog = log(add$1(1, clippedTrue));
- return mean(square$1(sub(firstLog, secondLog)), -1);
- });
- }
- function squaredHinge(yTrue, yPred) {
- return tidy(() => {
- const maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
- return mean(square$1(maxResult), -1);
- });
- }
- function hinge(yTrue, yPred) {
- return tidy(() => {
- const maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
- return mean(maxResult, -1);
- });
- }
- function categoricalHinge(yTrue, yPred) {
- return tidy(() => {
- const pos = sum$1(mul(yTrue, yPred), -1);
- const neg = max(mul(sub(1, yTrue), yPred), -1);
- return maximum(0, add$1(1, sub(neg, pos)));
- });
- }
- /**
- * Logarithm of the hyperbolic cosine of the prediction error.
- *
- * `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
- * to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
- * like the mean squared error, but will not be so strongly affected by the
- * occasional wildly incorrect prediction.
- */
- function logcosh(yTrue, yPred) {
- return tidy(() => {
- const log2 = Math.log(2);
- const predictionDiff = sub(yPred, yTrue);
- const logcoshResult = sub(add$1(predictionDiff, softplus(mul(-2, predictionDiff))), log2);
- return mean(logcoshResult, -1);
- });
- }
- function categoricalCrossentropy(target, output, fromLogits = false) {
- return tidy(() => {
- if (fromLogits) {
- output = softmax(output);
- }
- else {
- // scale preds so that the class probabilities of each sample sum to 1.
- const outputSum = sum$1(output, output.shape.length - 1, true);
- output = div(output, outputSum);
- }
- output = clipByValue(output, epsilon(), 1 - epsilon());
- return neg(sum$1(mul(target.toFloat(), log(output)), output.shape.length - 1));
- });
- }
- /**
- * Categorical crossentropy with integer targets.
- *
- * @param target An integer tensor.
- * @param output A tensor resulting from a softmax (unless `fromLogits` is
- * `true`, in which case `output` is expected to be the logits).
- * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
- * a tensor of logits.
- */
- function sparseCategoricalCrossentropy(target, output, fromLogits = false) {
- return tidy(() => {
- const flatTarget = floor(flatten$1(target)).toInt();
- output = clipByValue(output, epsilon(), 1 - epsilon());
- const outputShape = output.shape;
- const oneHotTarget = oneHot(flatTarget, outputShape[outputShape.length - 1])
- .reshape(outputShape);
- return categoricalCrossentropy(oneHotTarget, output, fromLogits);
- });
- }
- /**
- * From TensorFlow's implementation in nn_impl.py:
- *
- * For brevity, let `x = logits`, `z = labels`. The logistic loss is
- * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
- * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
- * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
- * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
- * = (1 - z) * x + log(1 + exp(-x))
- * = x - x * z + log(1 + exp(-x))
- * For x < 0, to avoid overflow in exp(-x), we reformulate the above
- * x - x * z + log(1 + exp(-x))
- * = log(exp(x)) - x * z + log(1 + exp(-x))
- * = - x * z + log(1 + exp(x))
- * Hence, to ensure stability and avoid overflow, the implementation uses this
- * equivalent formulation
- * max(x, 0) - x * z + log(1 + exp(-abs(x)))
- *
- * @param labels The labels.
- * @param logits The logits.
- */
- function sigmoidCrossEntropyWithLogits(labels, logits) {
- if (!arraysEqual(labels.shape, logits.shape)) {
- throw new ValueError(`logits and labels must have the same shape, but got shapes ` +
- `${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`);
- }
- return tidy(() => {
- // The logistic loss formula from above is
- // x - x * z + log(1 + exp(-x))
- // For x < 0, a more numerically stable formula is
- // -x * z + log(1 + exp(x))
- // Note that these two expressions can be combined into the following:
- // max(x, 0) - x * z + log(1 + exp(-abs(x)))
- const reluLogits = logits.relu();
- const negAbsLogits = logits.abs().neg();
- return reluLogits.sub(logits.mul(labels)).add(negAbsLogits.exp().log1p());
- });
- }
- function binaryCrossentropy(yTrue, yPred) {
- return tidy(() => {
- let y;
- y = clipByValue(yPred, epsilon(), 1 - epsilon());
- y = log(div(y, sub(1, y)));
- return mean(sigmoidCrossEntropyWithLogits(yTrue, y), -1);
- });
- }
- function kullbackLeiblerDivergence(yTrue, yPred) {
- return tidy(() => {
- const clippedTrue = clipByValue(yTrue, epsilon(), 1);
- const clippedPred = clipByValue(yPred, epsilon(), 1);
- return sum$1(mul(yTrue, log(div(clippedTrue, clippedPred))), -1);
- });
- }
- function poisson(yTrue, yPred) {
- return tidy(() => {
- const logPred = log(add$1(epsilon(), yPred));
- return mean(sub(yPred, mul(yTrue, logPred)), -1);
- });
- }
- function cosineProximity(yTrue, yPred) {
- return tidy(() => {
- const trueNormalized = l2Normalize(yTrue, -1);
- const predNormalized = l2Normalize(yPred, -1);
- const trueXPred = mul(trueNormalized, predNormalized);
- return neg(sum$1(trueXPred, -1));
- });
- }
- const mse = meanSquaredError$1;
- const MSE = meanSquaredError$1;
- const mae = meanAbsoluteError;
- const MAE = meanAbsoluteError;
- const mape = meanAbsolutePercentageError;
- const MAPE = meanAbsolutePercentageError;
- const msle = meanSquaredLogarithmicError;
- const MSLE = meanSquaredLogarithmicError;
- const kld = kullbackLeiblerDivergence;
- const KLD = kullbackLeiblerDivergence;
- const cosine = cosineProximity;
- // TODO(michaelterry): Add deserialize() function.
- const lossesMap = {
- meanSquaredError: meanSquaredError$1,
- meanAbsoluteError,
- meanAbsolutePercentageError,
- meanSquaredLogarithmicError,
- squaredHinge,
- hinge,
- categoricalHinge,
- logcosh,
- categoricalCrossentropy,
- sparseCategoricalCrossentropy,
- binaryCrossentropy,
- kullbackLeiblerDivergence,
- poisson,
- cosineProximity
- };
- // Porting note: This diverges from the PyKeras implementation and may need to
- // change based on (de)serialization requirements.
- function get(identifierOrFn) {
- if (typeof identifierOrFn === 'string') {
- if (identifierOrFn in lossesMap) {
- return lossesMap[identifierOrFn];
- }
- let errMsg = `Unknown loss ${identifierOrFn}`;
- if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
- errMsg = `Unknown loss ${identifierOrFn}. ` +
- 'Use "categoricalCrossentropy" as the string name for ' +
- 'tf.losses.softmaxCrossEntropy';
- }
- throw new ValueError(errMsg);
- }
- else {
- return identifierOrFn;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- function binaryAccuracy(yTrue, yPred) {
- return tidy(() => {
- const threshold = mul(.5, onesLike(yPred));
- const yPredThresholded = cast$1(greater(yPred, threshold), yTrue.dtype);
- return mean(equal(yTrue, yPredThresholded), -1);
- });
- }
- function categoricalAccuracy(yTrue, yPred) {
- return tidy(() => cast$1(equal(argMax(yTrue, -1), argMax(yPred, -1)), 'float32'));
- }
- function truePositives(yTrue, yPred) {
- return tidy(() => {
- return logicalAnd(yTrue.equal(1), yPred.equal(1)).sum().cast('float32');
- });
- }
- function falseNegatives(yTrue, yPred) {
- return tidy(() => {
- return logicalAnd(yTrue.equal(1), yPred.equal(0)).sum().cast('float32');
- });
- }
- function falsePositives(yTrue, yPred) {
- return tidy(() => {
- return logicalAnd(yTrue.equal(0), yPred.equal(1)).sum().cast('float32');
- });
- }
- function precision(yTrue, yPred) {
- return tidy(() => {
- const tp = truePositives(yTrue, yPred);
- const fp = falsePositives(yTrue, yPred);
- const denominator = tp.add(fp);
- return where(greater(denominator, 0), tp.div(denominator), 0)
- .cast('float32');
- });
- }
- function recall(yTrue, yPred) {
- return tidy(() => {
- const tp = truePositives(yTrue, yPred);
- const fn = falseNegatives(yTrue, yPred);
- const denominator = tp.add(fn);
- return where(greater(denominator, 0), tp.div(denominator), 0)
- .cast('float32');
- });
- }
- function binaryCrossentropy$1(yTrue, yPred) {
- return binaryCrossentropy(yTrue, yPred);
- }
- function sparseCategoricalAccuracy(yTrue, yPred) {
- if (yTrue.rank === yPred.rank) {
- yTrue = yTrue.squeeze([yTrue.rank - 1]);
- }
- yPred = yPred.argMax(-1);
- if (yPred.dtype !== yTrue.dtype) {
- yPred = yPred.asType(yTrue.dtype);
- }
- return equal(yTrue, yPred).asType('float32');
- }
- function topKCategoricalAccuracy(yTrue, yPred) {
- throw new NotImplementedError();
- }
- function sparseTopKCategoricalAccuracy(yTrue, yPred) {
- throw new NotImplementedError();
- }
- // Aliases.
- const mse$1 = meanSquaredError$1;
- const MSE$1 = meanSquaredError$1;
- const mae$1 = meanAbsoluteError;
- const MAE$1 = meanAbsoluteError;
- const mape$1 = meanAbsolutePercentageError;
- const MAPE$1 = meanAbsolutePercentageError;
- const categoricalCrossentropy$1 = categoricalCrossentropy;
- const cosine$1 = cosineProximity;
- const sparseCategoricalCrossentropy$1 = sparseCategoricalCrossentropy;
- // TODO(cais, nielsene): Add serialize().
- const metricsMap = {
- binaryAccuracy,
- categoricalAccuracy,
- precision,
- categoricalCrossentropy: categoricalCrossentropy$1,
- sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
- mse: mse$1,
- MSE: MSE$1,
- mae: mae$1,
- MAE: MAE$1,
- mape: mape$1,
- MAPE: MAPE$1,
- cosine: cosine$1
- };
- function get$1(identifier) {
- if (typeof identifier === 'string' && identifier in metricsMap) {
- return metricsMap[identifier];
- }
- else if (typeof identifier !== 'string' && identifier != null) {
- return identifier;
- }
- else {
- throw new ValueError(`Unknown metric ${identifier}`);
- }
- }
- /**
- * Get the shortcut function name.
- *
- * If the fn name is a string,
- * directly return the string name.
- * If the function is included in metricsMap or lossesMap,
- * return key of the map.
- * - If the function relative to multiple keys,
- * return the first found key as the function name.
- * - If the function exists in both lossesMap and metricsMap,
- * search lossesMap first.
- * If the function is not included in metricsMap or lossesMap,
- * return the function name.
- *
- * @param fn loss function, metric function, or short cut name.
- * @returns Loss or Metric name in string.
- */
- function getLossOrMetricName(fn) {
- assert$1(fn !== null, `Unknown LossOrMetricFn ${fn}`);
- if (typeof fn === 'string') {
- return fn;
- }
- else {
- let fnName;
- for (const key of Object.keys(lossesMap)) {
- if (lossesMap[key] === fn) {
- fnName = key;
- break;
- }
- }
- if (fnName !== undefined) {
- return fnName;
- }
- for (const key of Object.keys(metricsMap)) {
- if (metricsMap[key] === fn) {
- fnName = key;
- break;
- }
- }
- if (fnName !== undefined) {
- return fnName;
- }
- return fn.name;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // Add (de)serialize()
- // Porting note: This diverges from the PyKeras implementation and may need to
- // change based on (de)serialization requirements.
- function getOptimizer(identifier) {
- const optimizerMap = {
- 'Adagrad': () => train.adagrad(0.01),
- 'Adadelta': () => train.adadelta(1, 0.95, epsilon()),
- 'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()),
- 'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0),
- 'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()),
- 'SGD': () => train.sgd(0.01)
- };
- optimizerMap['adagrad'] = optimizerMap['Adagrad'];
- optimizerMap['adadelta'] = optimizerMap['Adadelta'];
- optimizerMap['adam'] = optimizerMap['Adam'];
- optimizerMap['adamax'] = optimizerMap['Adamax'];
- optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
- optimizerMap['sgd'] = optimizerMap['SGD'];
- if (identifier in optimizerMap) {
- return optimizerMap[identifier]();
- }
- throw new ValueError(`Unknown Optimizer ${identifier}`);
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /** Utility functions related to user-defined metadata. */
- // Maximum recommended serialized size for user-defined metadata.
- // Beyond this limit, a warning message will be printed during model loading and
- // saving.
- const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;
- /**
- * Check validity of user-defined metadata.
- *
- * @param userDefinedMetadata
- * @param modelName Name of the model that the user-defined metadata belongs to.
- * Used during construction of error messages.
- * @param checkSize Whether to check the size of the metadata is under
- * recommended limit. Default: `false`. If `true`, will try stringify the
- * JSON object and print a console warning if the serialzied size is above the
- * limit.
- * @throws Error if `userDefinedMetadata` is not a plain JSON object.
- */
- function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) {
- if (userDefinedMetadata == null ||
- typeof userDefinedMetadata !== 'object' ||
- Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype ||
- !plainObjectCheck(userDefinedMetadata)) {
- throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
- }
- if (checkSize) {
- const out = JSON.stringify(userDefinedMetadata);
- if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
- console.warn(`User-defined metadata of model "${modelName}" is too large in ` +
- `size (length=${out.length} when serialized). It is not ` +
- `recommended to store such large objects in user-defined metadata. ` +
- `Please make sure its serialized length is <= ` +
- `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`);
- }
- }
- }
- /**
- * Check if an input is plain JSON object or any valid subfield of it.
- *
- * @param x The input to be checked.
- * @param assertObject Whether to assert `x` is a JSON object, i.e., reject
- * cases of arrays and primitives.
- * @return Returns `true` if and only if `x` is a plain JSON object,
- * a JSON-valid primitive including string, number, boolean and null,
- * or an array of the said types.
- */
- // tslint:disable-next-line:no-any
- function plainObjectCheck(x) {
- if (x === null) {
- // Note: typeof `null` is 'object', and `null` is valid in JSON.
- return true;
- }
- else if (typeof x === 'object') {
- if (Object.getPrototypeOf(x) === Object.prototype) {
- // `x` is a JavaScript object and its prototype is Object.
- const keys = Object.keys(x);
- for (const key of keys) {
- if (typeof key !== 'string') {
- // JSON keys must be strings.
- return false;
- }
- if (!plainObjectCheck(x[key])) { // Recursive call.
- return false;
- }
- }
- return true;
- }
- else {
- // `x` is a JavaScript object but its prototype is not Object.
- if (Array.isArray(x)) {
- // `x` is a JavaScript array.
- for (const item of x) {
- if (!plainObjectCheck(item)) { // Recursive call.
- return false;
- }
- }
- return true;
- }
- else {
- // `x` is a JavaScript object and its prototype is not Object,
- // and it's not an Array. I.e., it's a complex object such as
- // `Error` and `Date`.
- return false;
- }
- }
- }
- else {
- // `x` is not a JavaScript object or `null`.
- const xType = typeof x;
- return xType === 'string' || xType === 'number' || xType === 'boolean';
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Print the summary of a LayersModel object.
- *
- * @param model tf.LayersModel instance.
- * @param lineLength Total length of printed lines. Set this to adapt to the
- * display to different terminal or console sizes.
- * @param positions Relative or absolute positions of log elements in each
- * line. Each number corresponds to right-most (i.e., ending) position of a
- * column.
- * If not provided, defaults to `[0.45, 0.85, 1]` for sequential-like
- * models and `[0.33, 0.55, 0.67, 1]` for non-sequential like models.
- * @param printFn Print function to use.
- * It will be called on each line of the summary. You can provide a custom
- * function in order to capture the string summary. Defaults to `console.log`.
- */
- function printSummary(model, lineLength, positions,
- // tslint:disable-next-line:no-any
- printFn = console.log) {
- const sequentialLike = isModelSequentialLike(model);
- // Header names for different log elements.
- const toDisplay = ['Layer (type)', 'Output shape', 'Param #'];
- if (sequentialLike) {
- lineLength = lineLength || 65;
- positions = positions || [0.45, 0.85, 1];
- }
- else {
- lineLength = lineLength || 98;
- positions = positions || [0.33, 0.55, 0.67, 1];
- // Header names for different log elements.
- }
- if (positions[positions.length - 1] <= 1) {
- // `positions` is relative. Convert it to absolute positioning.
- positions = positions.map(p => Math.floor(lineLength * p));
- }
- let relevantNodes;
- if (!sequentialLike) {
- toDisplay.push('Receives inputs');
- relevantNodes = [];
- for (const depth in model.nodesByDepth) {
- relevantNodes.push(...model.nodesByDepth[depth]);
- }
- }
- printFn('_'.repeat(lineLength));
- printRow(toDisplay, positions, printFn);
- printFn('='.repeat(lineLength));
- const layers = model.layers;
- for (let i = 0; i < layers.length; ++i) {
- if (sequentialLike) {
- printLayerSummary(layers[i], positions, printFn);
- }
- else {
- printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
- }
- printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
- }
- // tslint:disable-next-line:no-any
- model.checkTrainableWeightsConsistency();
- const trainableCount = countTrainableParams(model);
- const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
- printFn(`Total params: ${trainableCount + nonTrainableCount}`);
- printFn(`Trainable params: ${trainableCount}`);
- printFn(`Non-trainable params: ${nonTrainableCount}`);
- printFn('_'.repeat(lineLength));
- }
- function countTrainableParams(model) {
- let trainableCount;
- // tslint:disable:no-any
- if (model.collectedTrainableWeights != null) {
- trainableCount =
- countParamsInWeights(model.collectedTrainableWeights);
- }
- else {
- trainableCount = countParamsInWeights(model.trainableWeights);
- }
- // tslint:enable:no-any
- return trainableCount;
- }
- function isModelSequentialLike(model) {
- let sequentialLike = true;
- const nodesByDepth = [];
- const nodes = [];
- for (const depth in model.nodesByDepth) {
- nodesByDepth.push(model.nodesByDepth[depth]);
- }
- for (const depthNodes of nodesByDepth) {
- if (depthNodes.length > 1 ||
- depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) {
- sequentialLike = false;
- break;
- }
- nodes.push(...depthNodes);
- }
- if (sequentialLike) {
- // Search for shared layers.
- for (const layer of model.layers) {
- let flag = false;
- for (const node of layer.inboundNodes) {
- if (nodes.indexOf(node) !== -1) {
- if (flag) {
- sequentialLike = false;
- break;
- }
- else {
- flag = true;
- }
- }
- }
- if (!sequentialLike) {
- break;
- }
- }
- }
- return sequentialLike;
- }
- function printRow(fields, positions,
- // tslint:disable-next-line:no-any
- printFn = console.log) {
- let line = '';
- for (let i = 0; i < fields.length; ++i) {
- if (i > 0) {
- line = line.slice(0, line.length - 1) + ' ';
- }
- line += fields[i];
- line = line.slice(0, positions[i]);
- line += ' '.repeat(positions[i] - line.length);
- }
- printFn(line);
- }
- /**
- * Prints a summary for a single Layer, without connectivity information.
- *
- * @param layer: Layer instance to print.
- */
- function printLayerSummary(layer, positions,
- // tslint:disable-next-line:no-any
- printFn) {
- let outputShape;
- try {
- outputShape = JSON.stringify(layer.outputShape);
- }
- catch (err) {
- outputShape = 'multiple';
- }
- const name = layer.name;
- const className = layer.getClassName();
- const fields = [`${name} (${className})`, outputShape, layer.countParams().toString()];
- printRow(fields, positions, printFn);
- }
- /**
- * Prints a summary for a single Layer, with connectivity information.
- */
- function printLayerSummaryWithConnections(layer, positions, relevantNodes,
- // tslint:disable-next-line:no-any
- printFn) {
- let outputShape;
- try {
- outputShape = JSON.stringify(layer.outputShape);
- }
- catch (err) {
- outputShape = 'multiple';
- }
- const connections = [];
- for (const node of layer.inboundNodes) {
- if (relevantNodes != null && relevantNodes.length > 0 &&
- relevantNodes.indexOf(node) === -1) {
- continue;
- }
- for (let i = 0; i < node.inboundLayers.length; ++i) {
- const inboundLayer = node.inboundLayers[i].name;
- const inboundLayerIndex = node.nodeIndices[i];
- const inboundTensorIndex = node.tensorIndices[i];
- connections.push(`${inboundLayer}[${inboundLayerIndex}][${inboundTensorIndex}]`);
- }
- }
- const name = layer.name;
- const className = layer.getClassName();
- const firstConnection = connections.length === 0 ? '' : connections[0];
- const fields = [
- `${name} (${className})`, outputShape, layer.countParams().toString(),
- firstConnection
- ];
- printRow(fields, positions, printFn);
- for (let i = 1; i < connections.length; ++i) {
- printRow(['', '', '', connections[i]], positions, printFn);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // tslint:enable
- /**
- * Test whether a value in an array is the name of a LayersModel or Layer.
- * @param key The key name that the value is found under. Note that the key
- * may not be at the level immediately above the value, if the value is in a
- * nested array.
- * @param index Index of the value in the Array that it is found in.
- * @param value The value object.
- * @returns A boolean indicating whether value is a name.
- */
- function isArrayItemInputOrOutputName(key, index, value) {
- return (key === 'inboundNodes' || key === 'outputLayers' ||
- key === 'inputLayers') &&
- index === 0 && typeof value === 'string';
- }
- /**
- * Convert a Pythonic config object to TypeScript config object.
- * @param pythonicConfig The config object to convert.
- * @param key Optional key name of the object being converted.
- * @returns Result of the conversion.
- */
- function convertPythonicToTs(pythonicConfig, key) {
- if (pythonicConfig === null) {
- return null;
- }
- else if (typeof pythonicConfig === 'string') {
- return toCamelCase(pythonicConfig);
- }
- else if ((typeof pythonicConfig === 'number') ||
- (typeof pythonicConfig === 'boolean')) {
- return pythonicConfig;
- }
- else if (pythonicConfig instanceof Array) {
- const tsArray = [];
- const arrayLength = pythonicConfig.length;
- for (let i = 0; i < arrayLength; ++i) {
- const item = pythonicConfig[i];
- if (isArrayItemInputOrOutputName(key, i, item)) {
- tsArray.push(item);
- }
- else {
- tsArray.push(convertPythonicToTs(item, key));
- }
- }
- return tsArray;
- }
- else {
- const tsDict = {};
- for (const pythonicKey of Object.keys(pythonicConfig)) {
- const pythonicValue = pythonicConfig[pythonicKey];
- if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
- // Special case the 'name' key with a string value. Name values, such as
- // the names of LayersModel and Layer instances, should not undergo the
- // camel-case conversion.
- tsDict[pythonicKey] = pythonicValue;
- }
- else {
- const tsKey = toCamelCase(pythonicKey);
- tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
- }
- }
- return tsDict;
- }
- }
- /**
- * Convert a TypeScript config object to Python config object.
- * @param tsConfig The config object to convert.
- * @param key Optional key name of the object being converted.
- * @returns Result of the conversion.
- */
- function convertTsToPythonic(tsConfig, key) {
- if (tsConfig === null || tsConfig === undefined) {
- return null;
- }
- else if (typeof tsConfig === 'string') {
- return toSnakeCase(tsConfig);
- }
- else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) {
- return tsConfig;
- }
- else if (tsConfig instanceof Array) {
- const pyArray = [];
- const arrayLength = tsConfig.length;
- for (let i = 0; i < arrayLength; ++i) {
- const item = tsConfig[i];
- if (isArrayItemInputOrOutputName(key, i, item)) {
- pyArray.push(item);
- }
- else {
- pyArray.push(convertTsToPythonic(item, key));
- }
- }
- return pyArray;
- }
- else {
- const pyDict = {};
- for (const tsKey of Object.keys(tsConfig)) {
- const tsValue = tsConfig[tsKey];
- const pyKey = toSnakeCase(tsKey);
- if ((tsKey === 'name' || tsKey === 'className') &&
- typeof tsValue === 'string') {
- // Special case the 'name' key with a string value. Name values, such as
- // the names of LayersModel and Layer instances, should not undergo the
- // snake-case conversion.
- pyDict[pyKey] = tsValue;
- }
- else {
- pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
- }
- }
- return pyDict;
- }
- }
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$1 = '0.0.0';
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Helper function to check the dtype and shape compatibility of a feed value.
- */
- function assertFeedCompatibility(key, val) {
- // Check dtype compatibility.
- if (key.dtype == null || key.dtype === val.dtype) {
- // a. If types match, return val tensor as is.
- return val;
- }
- try {
- // b. Attempt to convert to expected type.
- return cast(val, key.dtype);
- }
- catch (err) {
- // c. If conversion fails, return helpful error.
- throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
- `of the key '${key.name}' (${key.dtype}).`);
- }
- }
- /**
- * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
- * A feed value is a concrete value represented as an `Tensor`.
- */
- class FeedDict {
- /**
- * Constructor, optionally does copy-construction.
- * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
- * copy-construction will be performed.
- */
- constructor(feeds) {
- this.id2Value = {};
- this.id2Mask = {};
- this.name2Id = {};
- if (feeds instanceof FeedDict) {
- for (const id in feeds.id2Value) {
- this.id2Value[id] = feeds.id2Value[id];
- if (id in feeds.id2Mask) {
- this.id2Mask[id] = feeds.id2Mask[id];
- }
- }
- }
- else {
- if (feeds == null) {
- return;
- }
- for (const feed of feeds) {
- this.add(feed.key, feed.value);
- }
- }
- }
- /**
- * Add a key-value pair to the FeedDict.
- *
- * @param key The key of the feed.
- * @param value The value of the tensor feed.
- * @param mask The value of the mask feed (optional).
- * @returns This `FeedDict`.
- * @throws ValueError: If the key `SymbolicTensor` already exists in the
- * `FeedDict`.
- */
- add(key, value, mask) {
- if (this.id2Value[key.id] == null) {
- this.id2Value[key.id] = assertFeedCompatibility(key, value);
- this.name2Id[key.name] = key.id;
- if (mask != null) {
- this.id2Mask[key.id] = mask;
- }
- }
- else {
- throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
- }
- return this;
- }
- /**
- * Add a Feed to the FeedDict.
- * @param feed The new `Feed` to add.
- * @returns This `FeedDict`.
- */
- addFeed(feed) {
- this.add(feed.key, feed.value);
- }
- /**
- * Probe whether a key already exists in the FeedDict.
- * @param key
- */
- hasKey(key) {
- return this.id2Value[key.id] != null;
- }
- /**
- * Get all the SymbolicTensor available in this FeedDict.
- */
- names() {
- return Object.keys(this.name2Id);
- }
- /**
- * Get the feed value for given key.
- * @param key The SymbolicTensor, or its name (as a string), of which the
- * value is sought.
- * @returns If `key` exists, the corresponding feed value.
- * @throws ValueError: If `key` does not exist in this `FeedDict`.
- */
- getValue(key) {
- if (key instanceof SymbolicTensor) {
- if (this.id2Value[key.id] == null) {
- throw new ValueError(`Nonexistent key: ${key.name}`);
- }
- else {
- return this.id2Value[key.id];
- }
- }
- else {
- const id = this.name2Id[key];
- if (id == null) {
- throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
- }
- return this.id2Value[id];
- }
- }
- /**
- * Get the feed mask for given key.
- * @param key The SymbolicTensor, or its name (as a string), of which the
- * value is sought.
- * @returns If `key` exists, the corresponding feed mask.
- * @throws ValueError: If `key` does not exist in this `FeedDict`.
- */
- getMask(key) {
- if (key instanceof SymbolicTensor) {
- if (this.id2Value[key.id] == null) {
- throw new ValueError(`Nonexistent key: ${key.name}`);
- }
- else {
- return this.id2Mask[key.id];
- }
- }
- else {
- const id = this.name2Id[key];
- if (id == null) {
- throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
- }
- return this.id2Mask[id];
- }
- }
- /** Dispose all mask Tensors held by this object. */
- disposeMasks() {
- if (this.id2Mask != null) {
- dispose(this.id2Mask);
- }
- }
- }
- // Cache for topologically sorted SymbolicTensors for given execution
- // targets (i.e., fetches).
- const cachedSorted = {};
- // Cache for recipient count maps for given execution targets (i.e., fetches).
- const cachedRecipientCounts = {};
- /**
- * Execute a SymbolicTensor by using concrete feed values.
- *
- * A `SymbolicTensor` object is a node in a computation graph of TF.js
- * Layers. The object is backed by a source layer and input
- * `SymbolicTensor`s to the source layer. This method evaluates
- * the `call()` method of the source layer, using concrete values of the
- * inputs obtained from either
- * * `feedDict`, if the input key exists in `feedDict`, or else,
- * * a recursive call to `execute()` itself.
- *
- * @param x: The `SymbolicTensor` to execute.
- * @param feedDict: The feed values, as base condition of the recursion.
- * execution.
- * @param kwargs: Optional keyword arguments.
- * @param probe: A probe object (of interface `ExecutionProbe`) used for
- * testing memory footprint of `execute` calls.
- * @returns Result of the execution.
- * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
- * encountered during the execution lacks a feed value in `feedDict`.
- */
- function execute(fetches, feedDict, kwargs, probe) {
- const training = kwargs == null ? false : kwargs['training'];
- const arrayFetches = Array.isArray(fetches);
- const fetchArray = arrayFetches ? fetches : [fetches];
- const outputNames = fetchArray.map(t => t.name);
- const finalOutputs = [];
- const feedNames = feedDict.names();
- for (const outputName of outputNames) {
- if (feedNames.indexOf(outputName) !== -1) {
- finalOutputs.push(feedDict.getValue(outputName));
- }
- else {
- finalOutputs.push(null);
- }
- }
- if (probe != null) {
- // For optional probing of memory footprint during execution.
- probe.maxNumTensors = -Infinity;
- probe.minNumTensors = Infinity;
- }
- // Check cache.
- const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().join(',');
- let sorted;
- let recipientCounts;
- if (cachedSorted[fetchAndFeedKey] == null) {
- // Cache doesn't contain the desired combination of fetches. Compute
- // topological sort for the combination for the first time.
- const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
- sorted = out.sorted;
- recipientCounts = out.recipientCounts;
- // Store results in cache for future use.
- cachedSorted[fetchAndFeedKey] = sorted;
- cachedRecipientCounts[fetchAndFeedKey] = recipientCounts;
- }
- sorted = cachedSorted[fetchAndFeedKey];
- recipientCounts = {};
- if (!training) {
- Object.assign(recipientCounts, cachedRecipientCounts[fetchAndFeedKey]);
- }
- const internalFeedDict = new FeedDict(feedDict);
- // Start iterative execution on the topologically-sorted SymbolicTensors.
- for (let i = 0; i < sorted.length; ++i) {
- if (probe != null) {
- // For optional probing of memory usage during execution.
- const numTensors = memory().numTensors;
- if (numTensors > probe.maxNumTensors) {
- probe.maxNumTensors = numTensors;
- }
- if (numTensors < probe.minNumTensors) {
- probe.minNumTensors = numTensors;
- }
- }
- const symbolic = sorted[i];
- const srcLayer = symbolic.sourceLayer;
- if (srcLayer instanceof InputLayer) {
- continue;
- }
- const inputValues = [];
- const inputMasks = [];
- const tensorsToDispose = [];
- let maskExists = false;
- for (const input of symbolic.inputs) {
- const value = internalFeedDict.getValue(input);
- const mask = internalFeedDict.getMask(input);
- inputValues.push(value);
- inputMasks.push(mask);
- if (mask != null) {
- maskExists = true;
- }
- if (!training) {
- recipientCounts[input.name]--;
- if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
- outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
- input.sourceLayer.stateful !== true) {
- tensorsToDispose.push(value);
- }
- }
- }
- if (maskExists) {
- kwargs = kwargs || {};
- kwargs['mask'] = inputMasks[0];
- }
- const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
- let outputMask = null;
- if (srcLayer.supportsMasking) {
- outputMask = srcLayer.computeMask(inputValues, inputMasks);
- }
- const layerOutputs = getNodeOutputs(symbolic);
- const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
- for (let i = 0; i < outputSymbolicTensors.length; ++i) {
- if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
- internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
- }
- const index = outputNames.indexOf(outputSymbolicTensors[i].name);
- if (index !== -1) {
- finalOutputs[index] = outputTensors[i];
- }
- }
- if (!training) {
- // Clean up Tensors that are no longer needed.
- dispose(tensorsToDispose);
- }
- }
- // NOTE(cais): Unlike intermediate tensors, we don't discard mask
- // tensors as we go, because these tensors are sometimes passed over a
- // series of mutliple layers, i.e., not obeying the immediate input
- // relations in the graph. If this becomes a memory-usage concern,
- // we can improve this in the future.
- internalFeedDict.disposeMasks();
- return arrayFetches ? finalOutputs : finalOutputs[0];
- }
- /**
- * Sort the `SymbolicTensor`s topologically, for an array of fetches.
- *
- * This function calls getTopologicalSortAndRecipientCountsForOneFetch and
- * merges their results.
- *
- * @param fetch The array of fetches requested. Must be a non-empty array.
- * @param feedDict The dictionary of fed values.
- * @returns sorted: Topologically-sorted array of SymbolicTensors.
- * recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.
- */
- function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
- assert(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
- let finalSorted = [];
- let finalRecipientMap = {};
- if (fetches.length === 1) {
- // Special-casing 1 fetch for efficiency.
- const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
- finalSorted = out.sorted;
- finalRecipientMap = out.recipientMap;
- }
- else {
- const visited = new Set();
- for (const fetch of fetches) {
- const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
- // Merge sorted SymbolicTensor Arrays.
- for (const symbolicTensor of sorted) {
- if (!visited.has(symbolicTensor.name)) {
- finalSorted.push(symbolicTensor);
- visited.add(symbolicTensor.name);
- }
- }
- // Merge recipient maps.
- for (const name in recipientMap) {
- if (finalRecipientMap[name] == null) {
- finalRecipientMap[name] = new Set();
- }
- recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
- }
- }
- }
- return {
- sorted: finalSorted,
- recipientCounts: recipientMap2Counts(finalRecipientMap)
- };
- }
- function recipientMap2Counts(recipientMap) {
- const recipientCounts = {};
- for (const name in recipientMap) {
- recipientCounts[name] = recipientMap[name].size;
- }
- return recipientCounts;
- }
- /**
- * Sort the `SymbolicTensor`s topologically, for a single fetch.
- *
- * This helper function processes the upstream SymbolicTensors of a single
- * fetch.
- *
- * @param fetch The single fetch requested.
- * @param feedDict The dictionary of fed values.
- * @returns sorted: Topologically-sorted array of SymbolicTensors.
- * recipientMap: Recipient names for all SymbolicTensors in `sorted`.
- */
- function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
- const visited = new Set();
- const sorted = [];
- const recipientMap = {};
- // Put keys of the feedDict into visited first, so they don't have to be
- // walked. This is needed in case where there are feeds for intermediate
- // SymbolicTensors of the graph.
- for (const key of feedDict.names()) {
- visited.add(key);
- }
- const stack = [];
- const marks = [];
- // Initial population of stack and marks.
- stack.push(fetch);
- while (stack.length > 0) {
- const top = stack[stack.length - 1];
- if (visited.has(top.name)) {
- stack.pop();
- continue;
- }
- const topIsMarked = marks[marks.length - 1] === stack.length - 1;
- if (top.inputs.length === 0 || topIsMarked) {
- // Input SymbolicTensor or all children have been visited.
- stack.pop();
- sorted.push(top);
- visited.add(top.name);
- if (topIsMarked) {
- marks.pop();
- }
- }
- else {
- // A non-input SymbolicTensor whose upstream SymbolicTensors haven't
- // been visited yet. Push them onto the stack.
- marks.push(stack.length - 1);
- for (const input of top.inputs) {
- // Increment the recipient count. Note that this needs to happen
- // regardless of whether the SymbolicTensor has been visited before.
- if (recipientMap[input.name] == null) {
- recipientMap[input.name] = new Set();
- }
- recipientMap[input.name].add(top.name);
- if (visited.has(input.name)) {
- continue; // Avoid repeated visits to the same SymbolicTensor.
- }
- stack.push(input);
- }
- }
- }
- return { sorted, recipientMap };
- }
- /**
- * Get the symbolic output tensors of the node to which a given fetch belongs.
- * @param fetch The fetched symbolic tensor.
- * @returns The Array of symbolic tensors output by the node to which `fetch`
- * belongs.
- */
- function getNodeOutputs(fetch) {
- let layerOutputs;
- if (fetch.sourceLayer.inboundNodes.length === 1) {
- layerOutputs = fetch.sourceLayer.output;
- }
- else {
- let nodeIndex = null;
- for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
- for (const outputTensor of fetch.sourceLayer.inboundNodes[i]
- .outputTensors) {
- if (outputTensor.id === fetch.id) {
- nodeIndex = i;
- break;
- }
- }
- }
- layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
- }
- return layerOutputs;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * A Container is a directed acyclic graph of layers.
- *
- * It is the topological form of a "model". A LayersModel
- * is simply a Container with added training routines.
- *
- */
- class Container extends Layer {
- constructor(args) {
- // No args passed to super's constructor.
- super({});
- this.containerNodes = new Set();
- this.name = args.name;
- if (this.name == null) {
- const prefix = this.getClassName().toLowerCase();
- this.name = getUid(prefix);
- }
- this.supportsMasking = false;
- this.trainable_ = true;
- // TODO(michaelterry): Initialize perInputLosses/Updates here.
- // Container-specific properties.
- if (Array.isArray(args.inputs)) {
- this.inputs = args.inputs.slice();
- }
- else {
- this.inputs = [args.inputs];
- }
- if (Array.isArray(args.outputs)) {
- this.outputs = args.outputs.slice();
- }
- else {
- this.outputs = [args.outputs];
- }
- // Check for redundancy in inputs.
- if (unique$1(this.inputs).length !== this.inputs.length) {
- throw new ValueError('The list of inputs passed to the model is ' +
- 'redundant. All inputs should only appear once. Found: ' +
- `${this.inputs.map(x => x.name)}`);
- }
- // Check for redundancy in outputs.
- if (unique$1(this.outputs).length !== this.outputs.length) {
- console.warn('The list of outputs passed to the model is redundant. ' +
- 'All outputs should only appear once. Found: ' +
- `${this.outputs.map(x => x.name)}`);
- }
- /*
- List of initial layers (1 to 1 mapping with this.inputs, hence the same
- layer might appear twice)
- */
- this.inputLayers = [];
- this.inputLayersNodeIndices = [];
- this.inputLayersTensorIndices = [];
- /*
- List of layers (1 to 1 mapping with this.outputs, hence the same layer
- might appear twice)
- */
- this.outputLayers = [];
- this.outputLayersNodeIndices = [];
- this.outputLayersTensorIndices = [];
- /*
- All layers in order of horizontal graph traversal. Entries are unique.
- Includes input and output layers.
- */
- this.layers = [];
- /*
- References to container layers that were constructed internally. We need
- these to properly dispose of tensors from nested containers.
- */
- this.internalContainerRefs = [];
- // TODO(michaelterry): Determine if caching still needed with eager
- // backend.
- /*
- This is for performance optimization when calling the Container on new
- inputs. Every time the Container is called on a set on input tensors,
- we compute the output tensors, output masks and output shapes in one pass,
- then cache them here. When one of these outputs is queried later,
- we retrieve it from there instead of recomputing it.
- */
- // this.outputTensorCache = {};
- // this.outputShapeCache = {};
- // Build this.outputLayers:
- for (const x of this.outputs) {
- const layer = x.sourceLayer;
- const nodeIndex = x.nodeIndex;
- const tensorIndex = x.tensorIndex;
- this.outputLayers.push(layer);
- this.outputLayersNodeIndices.push(nodeIndex);
- this.outputLayersTensorIndices.push(tensorIndex);
- }
- // TODO(michaelterry): Add output mask cache code.
- // Build this.inputLayers:
- for (const x of this.inputs) {
- const layer = x.sourceLayer;
- const nodeIndex = x.nodeIndex;
- const tensorIndex = x.tensorIndex;
- /*
- It's supposed to be an input layer, so only one node
- and one tensor output.
- */
- assert$1(nodeIndex === 0, 'input layer has >1 nodes');
- assert$1(tensorIndex === 0, 'input layer has >1 tensors');
- this.inputLayers.push(layer);
- this.inputLayersNodeIndices.push(nodeIndex);
- this.inputLayersTensorIndices.push(tensorIndex);
- }
- // Build this.inputNames and this.outputNames.
- this.inputNames = [];
- this.outputNames = [];
- this.feedInputShapes = [];
- this.feedInputNames = [];
- this.feedOutputNames = [];
- for (let i = 0; i < this.inputLayers.length; i++) {
- const layer = this.inputLayers[i];
- // Check that layer is an InputLayer.
- if (!(layer instanceof InputLayer)) {
- throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' +
- `Received inputs: ${args.inputs}. ` +
- `Input ${i} (0-based) originates ` +
- `from layer type ${layer.getClassName()}.`);
- }
- this.inputNames.push(layer.name);
- this.feedInputShapes.push(layer.batchInputShape);
- this.feedInputNames.push(layer.name);
- }
- for (const layer of this.outputLayers) {
- this.outputNames.push(layer.name);
- }
- this.internalInputShapes = this.inputs.map(x => x.shape);
- this.internalOutputShapes = this.outputs.map(x => x.shape);
- /*
- Container_nodes: set of nodes included in the graph (not all nodes
- included in the layers are relevant to the current graph).
- */
- // ids of all nodes relevant to the Container:
- const nodesDepths = {};
- // To recover nodes from their ID.
- const nodeIDToNode = {};
- const layersDepths = {};
- // To layers from their ID.
- const layerIDToLayer = {};
- const layerIndices = {};
- const nodesInDecreasingDepth = [];
- /**
- * Builds a map of the graph of layers.
- *
- * This recursively updates the map `layerIndices`,
- * the list `nodesInDecreasingDepth` and the set `containerNodes`.
- *
- * @param tensor Some tensor in a graph.
- * @param finishedNodes Set of nodes whose subgraphs have been traversed
- * completely. Useful to prevent duplicated work.
- * @param nodesInProgress Set of nodes that are currently active on the
- * recursion stack. Useful to detect cycles.
- * @param layer Layer from which `tensor` comes from. If not provided,
- * will be obtained from tensor.sourceLayer.
- * @param nodeIndex Node index from which `tensor` comes from.
- * @param tensorIndex TensorIndex from which `tensor` comes from.
- *
- * @exception RuntimeError if a cycle is detected.
- */
- const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
- if (layer == null || nodeIndex == null || tensorIndex == null) {
- layer = tensor.sourceLayer;
- nodeIndex = tensor.nodeIndex;
- tensorIndex = tensor.tensorIndex;
- }
- const node = layer.inboundNodes[nodeIndex];
- // Prevent cycles.
- if (nodesInProgress.indexOf(node) !== -1) {
- throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` +
- 'is part of a cycle.');
- }
- // Don't repeat work for shared subgraphs
- if (finishedNodes.indexOf(node) !== -1) {
- return;
- }
- // Update containerNodes.
- this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
- // Store the traversal order for layer sorting.
- if (!(layer.id in layerIndices)) {
- layerIndices[layer.id] = Object.keys(layerIndices).length;
- }
- if (nodesInProgress.indexOf(node) === -1) {
- nodesInProgress.push(node);
- }
- // Propagate to all previous tensors connected to this node.
- const numInboundLayers = node.inboundLayers.length;
- for (let i = 0; i < numInboundLayers; i++) {
- const x = node.inputTensors[i];
- const layer = node.inboundLayers[i];
- const nodeIndex = node.nodeIndices[i];
- const tensorIndex = node.tensorIndices[i];
- buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
- }
- finishedNodes.push(node);
- while (nodesInProgress.indexOf(node) >= 0) {
- nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
- }
- nodesInDecreasingDepth.push(node);
- };
- const finishedNodes = [];
- const nodesInProgress = [];
- for (const x of this.outputs) {
- buildMapOfGraph(x, finishedNodes, nodesInProgress);
- }
- const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
- for (const node of reversedNodesInDecreasingDepth) {
- nodeIDToNode[node.id] = node;
- // If the depth is not set, the node has no outbound nodes (depth 0).
- if (!(node.id in nodesDepths)) {
- nodesDepths[node.id] = 0;
- }
- let depth = nodesDepths[node.id];
- // Update the depth of the corresponding layer
- const previousDepth = (layersDepths[node.outboundLayer.id] == null ?
- 0 :
- layersDepths[node.outboundLayer.id]);
- /*
- If we've seen this layer before at a higher depth, we should use that
- depth instead of the node depth. This is necessary for shared layers
- that have inputs at different depth levels in the graph.
- */
- depth = Math.max(depth, previousDepth);
- layersDepths[node.outboundLayer.id] = depth;
- layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
- nodesDepths[node.id] = depth;
- // Update the depth of inbound nodes.
- for (let i = 0; i < node.inboundLayers.length; i++) {
- const inboundLayer = node.inboundLayers[i];
- const nodeIndex = node.nodeIndices[i];
- const inboundNode = inboundLayer.inboundNodes[nodeIndex];
- const previousDepth = (nodesDepths[inboundNode.id] == null ? 0 :
- nodesDepths[inboundNode.id]);
- nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
- nodeIDToNode[inboundNode.id] = inboundNode;
- }
- }
- // Build a dict {depth: list of nodes with this depth}
- const nodesByDepth = {};
- for (const nodeID in nodesDepths) {
- const depth = nodesDepths[nodeID];
- if (!(depth in nodesByDepth)) {
- nodesByDepth[depth] = [];
- }
- nodesByDepth[depth].push(nodeIDToNode[nodeID]);
- }
- // Build a dict {depth: list of layers with this depth}
- const layersByDepth = {};
- for (const layerID in layersDepths) {
- const depth = layersDepths[layerID];
- if (!(depth in layersByDepth)) {
- layersByDepth[depth] = [];
- }
- layersByDepth[depth].push(layerIDToLayer[layerID]);
- }
- // Get sorted list of layer depths.
- let depthKeys = Object.keys(layersByDepth)
- .map(x => parseInt(x, 10))
- .sort(reverseNumberCompare);
- // Set this.layers and this.layersByDepth.
- this.layers = [];
- for (const depth of depthKeys) {
- const layersForDepth = layersByDepth[depth];
- // Container.layers needs to have a deterministic order:
- // here we order them by traversal order.
- layersForDepth.sort((a, b) => {
- const aIndex = layerIndices[a.id];
- const bIndex = layerIndices[b.id];
- if (aIndex < bIndex) {
- return -1;
- }
- if (aIndex > bIndex) {
- return 1;
- }
- return 0;
- });
- for (const layer of layersForDepth) {
- if (layer instanceof Container) {
- this.internalContainerRefs.push(layer);
- }
- this.layers.push(layer);
- }
- }
- this.layersByDepth = layersByDepth;
- // Get sorted list of node depths;
- depthKeys = Object.keys(nodesByDepth)
- .map(x => parseInt(x, 10))
- .sort(reverseNumberCompare);
- // Check that all tensors required are computable.
- // computable_tensors: all tensors in the graph
- // that can be computed from the inputs provided.
- const computableTensors = this.inputs.slice();
- // To provide a better error msg.
- const layersWithCompleteInput = [];
- for (const depth of depthKeys) {
- for (const node of nodesByDepth[depth]) {
- const layer = node.outboundLayer;
- if (layer != null) {
- for (const x of node.inputTensors) {
- if (computableTensors.indexOf(x) === -1) {
- throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` +
- ` at layer "${layer.name}". ` +
- 'The following previous layers were accessed without ' +
- `issue: ${layersWithCompleteInput}`);
- }
- }
- for (const x of node.outputTensors) {
- computableTensors.push(x);
- }
- layersWithCompleteInput.push(layer.name);
- }
- }
- }
- // Set this.containerNodes and this.nodesByDepth.
- this.nodesByDepth = nodesByDepth;
- // Ensure name unicity, which will be crucial for serialization
- // (since serialized nodes refer to layers by their name).
- const allNames = this.layers.map(x => x.name);
- for (const name of allNames) {
- const numOccurrences = allNames.filter(x => x === name).length;
- if (numOccurrences !== 1) {
- throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` +
- 'in the model. All layer names should be unique. Layer names: ' +
- JSON.stringify(allNames));
- }
- }
- // Layer parameters.
- // The new container starts with a single inbound node
- // for its inputs, and no outbound nodes.
- // Will be appended to by future calls to apply().
- this.outboundNodes = [];
- // Will be appended to below, and by future calls to apply().
- this.inboundNodes = [];
- // Create the node linking internal inputs to internal outputs.
- // (This call has side effects.)
- // tslint:disable-next-line:no-unused-expression
- new Node({
- outboundLayer: this,
- inboundLayers: [],
- nodeIndices: [],
- tensorIndices: [],
- inputTensors: this.inputs,
- outputTensors: this.outputs,
- inputMasks: this.inputs.map(x => null),
- outputMasks: this.outputs.map(x => null),
- inputShapes: this.inputs.map(x => x.shape),
- outputShapes: this.outputs.map(x => x.shape)
- });
- this.built = true;
- this._refCount = 1; // The ref count of a container always start at 1.
- }
- assertNotDisposed() {
- if (this._refCount === 0) {
- throw new Error(`Container '${this.name}' is already disposed.`);
- }
- }
- /**
- * Attempt to dispose a LayersModel's weights.
- *
- * This method decrease the reference count of the LayersModel object by 1.
- *
- * A LayersModel is reference-counted. Its reference count is incremented by 1
- * when it is first constructed and when it is used as a Layer of another
- * LayersModel.
- *
- * If the reference count of a LayersModel becomes 0, the `dispose` method of
- * all its constituent `Layer`s will be called.
- *
- * Note: If the reference count is greater than 0 after the decrement, the
- * `dispose` method of its constituent `Layer`s will *not* be called.
- *
- * After a LayersModel is disposed, it cannot be used in calls such as
- * 'predict`, `evaluate` or `fit` anymore.
- *
- * @returns A DisposeResult Object with the following fields:
- * - refCountAfterDispose: The reference count of the LayersModel after this
- * `dispose()` call.
- * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
- * during this `dispose()` call.
- * @throws {Error} If the layer is not built yet, or if the LayersModel has
- * already been disposed.
- */
- dispose() {
- this.assertNotDisposed();
- const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
- if (--this._refCount === 0) {
- for (const layer of this.layers) {
- result.numDisposedVariables += layer.dispose().numDisposedVariables;
- }
- // Call dispose on each internally created container layer again to ensure
- // their refCounts hit zero and their tensors are subsequently deleted.
- for (const container of this.internalContainerRefs) {
- result.numDisposedVariables += container.dispose().numDisposedVariables;
- }
- }
- result.refCountAfterDispose = this._refCount;
- return result;
- }
- get trainable() {
- return this.trainable_;
- }
- set trainable(trainable) {
- this.layers.forEach(layer => {
- // tslint:disable-next-line:no-any
- layer._trainableWeights
- .forEach(w => w.trainable = trainable);
- });
- this.trainable_ = trainable;
- }
- get trainableWeights() {
- // Porting Note: This check below is to prevent errors where the
- // _trainableWeights inherited from the parent class (Layer) gets
- // inadvertently used.
- if (this._trainableWeights.length > 0) {
- throw new ValueError('Container instance unexpectedly contains _trainableWeights.' +
- 'The trainable weights of a Container are a union of the ' +
- 'trainable weights of its consituent Layers. Its own ' +
- '_trainableWeights must remain an empty Array.');
- }
- if (!this.trainable) {
- return [];
- }
- let weights = [];
- for (const layer of this.layers) {
- weights = weights.concat(layer.trainableWeights);
- }
- return weights;
- }
- get nonTrainableWeights() {
- const weights = [];
- for (const layer of this.layers) {
- weights.push(...layer.nonTrainableWeights);
- }
- if (!this.trainable) {
- const trainableWeights = [];
- for (const layer of this.layers) {
- trainableWeights.push(...layer.trainableWeights);
- }
- return trainableWeights.concat(weights);
- }
- return weights;
- }
- get weights() {
- return this.trainableWeights.concat(this.nonTrainableWeights);
- }
- /**
- * Loads all layer weights from a JSON object.
- *
- * Porting Note: HDF5 weight files cannot be directly loaded in JavaScript /
- * TypeScript. The utility script at `scripts/pykeras.py` offers means
- * to convert them into JSON strings compatible with this method.
- * Porting Note: TensorFlow.js Layers supports only loading by name currently.
- *
- * @param weights A JSON mapping weight names to weight values as nested
- * arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight
- * names to `tf.Tensor` objects.
- * @param strict Require that the provided weights exactly match those
- * required by the container. Default: `true`. Passing `false` means that
- * extra weights and missing weights will be silently ignored.
- */
- loadWeights(weights, strict = true) {
- const nameToWeight = {};
- let totalWeightsCount = 0;
- for (const layer of this.layers) {
- for (const weight of layer.weights) {
- if (nameToWeight[weight.originalName] != null) {
- throw new ValueError(`Duplicate weight name: ${weight.originalName}`);
- }
- nameToWeight[weight.originalName] = weight;
- totalWeightsCount++;
- }
- }
- const weightValueTuples = [];
- for (const name in weights) {
- // TF 2.2.0 added cell name to the weight name in the format of
- // layer_name/cell_name/weight_name, we need to remove
- // the inner cell name.
- let validatedName = name;
- if (nameToWeight[name] == null) {
- const tokens = name.split('/');
- const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
- validatedName = shortenNameArray.join('/');
- }
- if (nameToWeight[validatedName] != null) {
- weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
- }
- else if (strict) {
- throw new ValueError(`Provided weight data has no target variable: ${name}`);
- }
- delete nameToWeight[validatedName];
- }
- if (strict) {
- // Check that all weights are set.
- const unsetNames = [];
- for (const name in nameToWeight) {
- unsetNames.push(name);
- }
- if (unsetNames.length > 0) {
- throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` +
- `${unsetNames}`);
- }
- }
- batchSetValue(weightValueTuples);
- }
- /**
- * Util shared between different serialization methods.
- * @returns LayersModel config with Keras version information added.
- */
- updatedConfig() {
- const theConfig = this.getConfig();
- const modelConfig = {};
- modelConfig['className'] = this.getClassName();
- modelConfig['config'] = theConfig;
- modelConfig['kerasVersion'] = `tfjs-layers ${version$1}`;
- // TODO(nielsene): Replace something like K.backend() once
- // possible.
- modelConfig['backend'] = 'TensorFlow.js';
- return modelConfig;
- }
- /**
- * Returns a JSON string containing the network configuration.
- *
- * To load a network from a JSON save file, use
- * models.modelFromJSON(jsonString);
- * @param extraJsonArgs Unused in tfjs-layers, maintained for PyKeras
- * @param returnString Whether the return value should be stringified
- * (default: `true`).
- * @returns a JSON string if `returnString` (default), or a JSON object if
- * `!returnString`.
- */
- // tslint:disable-next-line:no-any
- toJSON(unused, returnString = true) {
- const modelConfig = convertTsToPythonic(this.updatedConfig());
- return returnString ? JSON.stringify(modelConfig) : modelConfig;
- }
- /**
- * Call the model on new inputs.
- *
- * In this case `call` just reapplies all ops in the graph to the new inputs
- * (e.g. build a new computational graph from the provided inputs).
- *
- * @param inputs A tensor or list of tensors.
- * @param mask A mask or list of masks. A mask can be either a tensor or null
- * (no mask).
- *
- * @return A tensor if there is a single output, or a list of tensors if there
- * are more than one outputs.
- */
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = toList(inputs);
- const feedDict = new FeedDict();
- for (let i = 0; i < this.inputs.length; ++i) {
- feedDict.add(this.inputs[i], inputs[i]);
- }
- return execute(this.outputs, feedDict, kwargs);
- });
- }
- /**
- * Computes an output mask tensor.
- *
- * @param inputs Tensor or list of tensors.
- * @param mask Tensor or list of tensors.
- *
- * @return null or a tensor (or list of tensors, one per output tensor of the
- * layer).
- */
- computeMask(inputs, mask) {
- return tidy(() => {
- inputs = toList(inputs);
- let masks;
- if (mask == null) {
- masks = pyListRepeat(null, inputs.length);
- }
- else {
- masks = toList(mask);
- }
- // TODO(michaelterry): Add support for mask caching.
- return this.runInternalGraph(inputs, masks)[1];
- });
- }
- /**
- * Computes the output shape of the layer.
- *
- * Assumes that the layer will be built to match that input shape provided.
- *
- * @param inputShape A shape (tuple of integers) or a list of shape tuples
- * (one per output tensor of the layer). Shape tuples can include null for
- * free dimensions, instead of an integer.
- */
- computeOutputShape(inputShape) {
- const inputShapes = normalizeShapeList(inputShape);
- if (inputShapes.length !== this.inputLayers.length) {
- throw new ValueError(`Invalid inputShape argument ${inputShape}: ` +
- `model has ${this.inputLayers.length} tensor inputs.`);
- }
- // TODO(michaelterry): Add caching
- const layersToOutputShapes = {};
- for (let i = 0; i < inputShapes.length; i++) {
- const layer = this.inputLayers[i];
- const inputShape = inputShapes[i];
- // It's an input layer: computeOutputShape is identity,
- // and there is only one node and one tensor output.
- const shapeKey = layer.name + '_0_0';
- layersToOutputShapes[shapeKey] = inputShape;
- }
- const depthKeys = Object.keys(this.nodesByDepth)
- .map(x => parseInt(x, 10))
- .sort(reverseNumberCompare);
- // Iterate over nodes, by depth level.
- if (depthKeys.length > 1) {
- for (const depth of depthKeys) {
- const nodes = this.nodesByDepth[depth];
- for (const node of nodes) {
- // This is always a single layer, never a list.
- const layer = node.outboundLayer;
- if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
- // We've already covered the input layers a few lines above.
- continue;
- }
- // Potentially redundant list, same size of node.inputTensors.
- const inputShapes = [];
- for (let j = 0; j < node.inboundLayers.length; j++) {
- const inboundLayer = node.inboundLayers[j];
- const nodeIndex = node.nodeIndices[j];
- const tensorIndex = node.tensorIndices[j];
- const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
- const inputShape = layersToOutputShapes[shapeKey];
- inputShapes.push(inputShape);
- }
- const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
- const outputShapes = normalizeShapeList(outputShape);
- const nodeIndex = layer.inboundNodes.indexOf(node);
- for (let j = 0; j < outputShapes.length; j++) {
- const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
- layersToOutputShapes[shapeKey] = outputShapes[j];
- }
- }
- }
- }
- // Read final output shapes from layersToOutputShapes.
- const outputShapes = [];
- const outputShapeKeys = [];
- for (let i = 0; i < this.outputLayers.length; i++) {
- const layer = this.outputLayers[i];
- const nodeIndex = this.outputLayersNodeIndices[i];
- const tensorIndex = this.outputLayersTensorIndices[i];
- const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
- outputShapeKeys.push(shapeKey);
- }
- for (let i = 0; i < outputShapeKeys.length; i++) {
- const key = outputShapeKeys[i];
- assert$1(key in layersToOutputShapes);
- outputShapes.push(layersToOutputShapes[key]);
- }
- // TODO(michaelterry): Update cache
- return singletonOrArray(outputShapes);
- }
- /**
- * Computes output tensors for new inputs.
- *
- * Note:
- * - Expects `inputs` to be a list (potentially with 1 element).
- *
- * @param inputs List of tensors
- * @param masks List of masks (tensors or null).
- * @return Three lists: outputTensors, outputMasks, outputShapes
- */
- runInternalGraph(inputs, masks) {
- if (masks == null) {
- masks = pyListRepeat(null, inputs.length);
- }
- // Dictionary mapping reference tensors to tuples
- // (computed tensor, compute mask)
- // we assume a 1:1 mapping from tensor to mask
- // TODO: raise exception when a `.computeMask()` call
- // does not return a list the same size as `call`
- const tensorMap = {};
- for (let i = 0; i < this.inputs.length; ++i) {
- const x = this.inputs[i];
- const y = inputs[i];
- const mask = masks[i];
- tensorMap[x.id] = [y, mask];
- }
- const depthKeys = Object.keys(this.nodesByDepth)
- .map(x => parseInt(x, 10))
- .sort(reverseNumberCompare);
- for (const depth of depthKeys) {
- const nodes = this.nodesByDepth[depth];
- for (const node of nodes) {
- // This is always a single layer, never a list.
- const layer = node.outboundLayer;
- const referenceInputTensors = node.inputTensors;
- const referenceOutputTensors = node.outputTensors;
- // If all previous input tensors are available in tensorMap,
- // then call node.inboundLayer on them.
- // List of tuples [input, mask]:
- const computedData = new Array();
- for (const x of referenceInputTensors) {
- if (x.id in tensorMap) {
- computedData.push(tensorMap[x.id]);
- }
- }
- if (computedData.length === referenceInputTensors.length) {
- // TODO(michaelterry): Add K.name_scope here, if we need it.
- let kwargs = {};
- let computedTensors;
- let computedMasks;
- let outputTensors;
- let outputMasks;
- // call layer
- if (node.callArgs != null) {
- kwargs = node.callArgs;
- }
- if (computedData.length === 1) {
- const [computedTensor, computedMask] = computedData[0];
- if (kwargs['mask'] == null) {
- kwargs['mask'] = computedMask;
- }
- outputTensors =
- toList(layer.call(computedTensor, kwargs));
- outputMasks = toList(layer.computeMask(computedTensor, computedMask));
- computedTensors = [computedTensor];
- computedMasks = [computedMask];
- }
- else {
- computedTensors = computedData.map(x => x[0]);
- computedMasks = computedData.map(x => x[1]);
- if (kwargs['mask'] == null) {
- kwargs['mask'] = computedMasks;
- }
- outputTensors =
- toList(layer.call(computedTensors, kwargs));
- outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
- }
- if (layer.activityRegularizer) {
- throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' +
- 'presence of activity regularizer(s) is not supported yet.');
- }
- // TODO(michaelterry): Add model updates and losses
- // Update tensor map.
- for (let i = 0; i < referenceOutputTensors.length; ++i) {
- const x = referenceOutputTensors[i];
- const y = outputTensors[i];
- const mask = outputMasks[i];
- tensorMap[x.id] = [y, mask];
- }
- }
- }
- }
- const outputTensors = [];
- const outputMasks = [];
- const outputShapes = [];
- for (const x of this.outputs) {
- assert$1(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
- const [tensor, mask] = tensorMap[x.id];
- outputShapes.push(tensor.shape);
- outputTensors.push(tensor);
- outputMasks.push(mask);
- }
- // TODO(michaelterry): Add support for caches.
- return [outputTensors, outputMasks, outputShapes];
- }
- /**
- * Builds a map of internal node keys to node ordering.
- * Used in serializaion a node orderings may change as unused nodes are
- * dropped. Porting Note: This helper method was pulled out of getConfig to
- * improve readability.
- * @param layers An array of Layers in the model.
- * @returns Map of Node Keys to index order within the layer.
- */
- buildNodeConversionMap(layers) {
- const nodeConversionMap = {};
- let keptNodes;
- for (const layer of this.layers) {
- keptNodes = layer instanceof Container ? 1 : 0;
- for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
- const nodeKey = Container.nodeKey(layer, originalNodeIndex);
- if (this.containerNodes.has(nodeKey)) {
- // i.e. we mark it to be saved
- nodeConversionMap[nodeKey] = keptNodes;
- keptNodes += 1;
- }
- }
- }
- return nodeConversionMap;
- }
- /**
- * Retrieves a layer based on either its name (unique) or index.
- *
- * Indices are based on order of horizontal graph traversal (bottom-up).
- *
- * If both `name` and `index` are specified, `index` takes precedence.
- *
- * @param name Name of layer.
- * @param index Index of layer.
- * @returns A Layer instance.
- * @throws ValueError: In case of invalid layer name or index.
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Classes',
- * namespace: 'layers',
- * subclasses: ['LayersModel']
- * }
- */
- getLayer(name, index) {
- if (index != null) {
- if (this.layers.length <= index) {
- throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` +
- `has ${this.layers.length} layer(s).`);
- }
- else {
- return this.layers[index];
- }
- }
- else {
- if (name == null) {
- throw new ValueError('Provide either a layer name or layer index');
- }
- }
- for (const layer of this.layers) {
- if (layer.name === name) {
- return layer;
- }
- }
- throw new ValueError(`No such layer: ${name}`);
- }
- /**
- * Retrieves the Container's current loss values.
- *
- * Used for regularizers during training.
- */
- calculateLosses() {
- // Porting Node: This is an augmentation to Container.loss in PyKeras.
- // In PyKeras, Container.loss returns symbolic tensors. Here a concrete
- // Tensor (specifically Scalar) values are returned. This is due to the
- // imperative backend.
- return tidy(() => {
- const losses = [];
- for (const layer of this.layers) {
- for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
- const nodeKey = Container.nodeKey(layer, nodeIndex);
- if (this.containerNodes.has(nodeKey)) {
- losses.push(...layer.calculateLosses());
- }
- }
- }
- // TODO(cais): Add any unconditional model-level losses?
- return losses;
- });
- }
- getConfig() {
- const config = { name: this.name };
- // Build a map from layer unique name (self._node_key)
- // to the index of the nodes that are saved in the config.
- // Only nodes in container_nodes are saved.
- const nodeConversionMap = this.buildNodeConversionMap(this.layers);
- // Serialize and save the layers in layerConfigs
- const layerConfigs = [];
- for (const layer of this.layers) {
- const layerClassName = layer.getClassName();
- const layerConfig = layer.getConfig();
- const filteredInboundNodes = [];
- for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
- const node = layer.inboundNodes[originalNodeIndex];
- const nodeKey = Container.nodeKey(layer, originalNodeIndex);
- let kwargs = {};
- if (this.containerNodes.has(nodeKey)) {
- // The node is relevant to the model:
- // add to filteredInboundNodes.
- if (node.callArgs) {
- try {
- JSON.stringify(node.callArgs);
- kwargs = node.callArgs;
- }
- catch (err) {
- console.warn(`Layer ${layer.name} was passed ` +
- `non-serializable keyword arguments: ` +
- `${node.callArgs}. They will not be included ` +
- `in the serialized model (and thus will be ` +
- `missing at deserialization time).`);
- kwargs = {};
- }
- }
- if (node.inboundLayers.length > 0) {
- const nodeData = [];
- for (let i = 0; i < node.inboundLayers.length; i++) {
- const inboundLayer = node.inboundLayers[i];
- const nodeIndex = node.nodeIndices[i];
- const tensorIndex = node.tensorIndices[i];
- const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
- let newNodeIndex = nodeConversionMap[nodeKey];
- if (newNodeIndex == null) {
- newNodeIndex = 0;
- }
- nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
- }
- filteredInboundNodes.push(nodeData);
- }
- }
- }
- const dict = {};
- dict['name'] = layer.name;
- dict['className'] = layerClassName;
- dict['config'] = layerConfig;
- dict['inboundNodes'] = filteredInboundNodes;
- layerConfigs.push(dict);
- }
- config['layers'] = layerConfigs;
- // Gather info about inputs and outputs
- const modelInputs = [];
- for (let i = 0; i < this.inputLayers.length; i++) {
- const layer = this.inputLayers[i];
- const nodeIndex = this.inputLayersNodeIndices[i];
- const nodeKey = Container.nodeKey(layer, nodeIndex);
- if (!this.containerNodes.has(nodeKey)) {
- continue;
- }
- let newNodeIndex = nodeConversionMap[nodeKey];
- if (newNodeIndex === null || newNodeIndex === undefined) {
- newNodeIndex = 0;
- }
- const tensorIndex = this.inputLayersTensorIndices[i];
- modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
- }
- config['inputLayers'] = modelInputs;
- const modelOutputs = [];
- for (let i = 0; i < this.outputLayers.length; i++) {
- const layer = this.outputLayers[i];
- const nodeIndex = this.outputLayersNodeIndices[i];
- const nodeKey = Container.nodeKey(layer, nodeIndex);
- if (!this.containerNodes.has(nodeKey)) {
- continue;
- }
- let newNodeIndex = nodeConversionMap[nodeKey];
- if (newNodeIndex === null || newNodeIndex === undefined) {
- newNodeIndex = 0;
- }
- const tensorIndex = this.outputLayersTensorIndices[i];
- modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
- }
- config['outputLayers'] = modelOutputs;
- return config;
- }
- /**
- * Instantiates a LayersModel from its config (output of `get_config()`).
- * @param cls the class to create
- * @param config LayersModel config dictionary.
- * @param customObjects An optional dictionary of custom objects.
- * @param fastWeightInit Optional flag to use fast weight initialization
- * during deserialization. This is applicable to cases in which
- * the initialization will be immediately overwritten by loaded weight
- * values. Default: `false`.
- * @returns A LayersModel instance.
- * @throws ValueError: In case of improperly formatted config dict.
- */
- /** @nocollapse */
- static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
- // Layer instances created during
- // the graph reconstruction process
- const createdLayers = {};
- // Dictionary mapping layer instances to
- // node data that specifies a layer call.
- // It acts as a queue that maintains any unprocessed
- // layer call until it becomes possible to process it
- // (i.e. until the input tensors to the call all exist).
- const unprocessedNodes = {};
- function addUnprocessedNode(layer, nodeData) {
- if (!(layer.name in unprocessedNodes)) {
- unprocessedNodes[layer.name] = [nodeData];
- }
- else {
- unprocessedNodes[layer.name].push(nodeData);
- }
- }
- function processNode(layer, nodeData) {
- const inputTensors = [];
- let kwargs;
- for (const inputData of nodeData) {
- const inboundLayerName = inputData[0];
- const inboundNodeIndex = inputData[1];
- const inboundTensorIndex = inputData[2];
- kwargs = inputData[3] == null ?
- {} :
- inputData[3];
- if (!(inboundLayerName in createdLayers)) {
- addUnprocessedNode(layer, nodeData);
- return;
- }
- const inboundLayer = createdLayers[inboundLayerName];
- if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
- addUnprocessedNode(layer, nodeData);
- return;
- }
- const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
- inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
- }
- // Call layer on its inputs, thus creating the node
- // and building the layer if needed.
- // Note: This has Eager vs Graph Implications.
- if (inputTensors.length > 0) {
- layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs
- }
- }
- /**
- * Deserialize a layer, then call it on appropriate inputs.
- * @param layerData: layer config dict.
- * @throws ValueError: In case of improperly formatted `layer_data`
- * dict.
- */
- function processLayer(layerData) {
- const layerName = layerData['name'];
- // Instantiate layer.
- const layer = deserialize(layerData, config['customObjects'] != null ?
- config['customObjects'] :
- {});
- layer.setFastWeightInitDuringBuild(fastWeightInit);
- createdLayers[layerName] = layer;
- // Gather layer inputs.
- const inboundNodesData = layerData['inboundNodes'];
- inboundNodesData.forEach(nodeData => {
- if (!(nodeData instanceof Array)) {
- throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
- }
- // We don't process nodes (i.e. make layer calls)
- // on the fly because the inbound node may not yet exist,
- // in case of layer shared at different topological depths
- // (e.g.a model such as A(B(A(B(x)))))
- addUnprocessedNode(layer, nodeData);
- });
- }
- // First, we create all layers and enqueue nodes to be processed.
- const name = config['name'];
- const layersFromConfig = config['layers'];
- for (const layerData of layersFromConfig) {
- processLayer(layerData);
- }
- // Then we process nodes in order of layer depth.
- // Nodes that cannot yet be processed(if the inbound node
- // does not yet exist) are re - enqueued, and the process
- // is repeated until all nodes are processed.
- while (!isObjectEmpty(unprocessedNodes)) {
- for (const layerData of layersFromConfig) {
- const layer = createdLayers[layerData['name']];
- if (layer.name in unprocessedNodes) {
- const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
- delete unprocessedNodes[layer.name];
- for (const nodeData of currentUnprocessedNodesForLayer) {
- processNode(layer, nodeData);
- }
- }
- }
- }
- const inputTensors = [];
- const outputTensors = [];
- const inputLayersFromConfig = config['inputLayers'];
- for (const layerData of inputLayersFromConfig) {
- const layerName = layerData[0];
- const nodeIndex = layerData[1];
- const tensorIndex = layerData[2];
- assert$1(layerName in createdLayers);
- const layer = createdLayers[layerName];
- const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
- inputTensors.push(layerOutputTensors[tensorIndex]);
- }
- const outputLayersFromConfig = config['outputLayers'];
- for (const layerData of outputLayersFromConfig) {
- const layerName = layerData[0];
- const nodeIndex = layerData[1];
- const tensorIndex = layerData[2];
- assert$1(layerName in createdLayers);
- const layer = createdLayers[layerName];
- const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
- outputTensors.push(layerOutputTensors[tensorIndex]);
- }
- return new cls({ inputs: inputTensors, outputs: outputTensors, name });
- }
- /**
- * Determine whether the container is stateful.
- *
- * Porting Note: this is the equivalent of the stateful @property of
- * the Container class in PyKeras.
- */
- get stateful() {
- // Porting Note: This check is to prevent inadvertent setting of the
- // _stateful property of the Container instance.
- if (this._stateful) {
- throw new ValueError('Container instance unexpectedly has _stateful = true. The ' +
- 'statefulness of a Container is determined by the Layers it ' +
- 'contains. Its _stateful property must remain the default false.');
- }
- for (const layer of this.layers) {
- if (layer.stateful) {
- return true;
- }
- }
- return false;
- }
- /**
- * Reset the state of all stateful constituent layers (if any).
- *
- * Examples of stateful layers include RNN layers whose `stateful` property
- * is set as `true`.
- */
- resetStates() {
- tidy(() => {
- this.layers.forEach(layer => {
- // tslint:disable:no-any
- if (layer.stateful) {
- layer.resetStates();
- }
- // tslint:enable:no-any
- });
- });
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
- const numOutputs = outputNames.length;
- if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
- return outputNames.map(name => null);
- }
- if (numOutputs === 1) {
- if (Array.isArray(xWeight) && xWeight.length === 1) {
- return xWeight;
- }
- else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
- return [xWeight[outputNames[0]]];
- }
- else {
- return [xWeight];
- }
- }
- if (Array.isArray(xWeight)) {
- if (xWeight.length !== numOutputs) {
- throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
- `element(s), but the model has ${numOutputs} outputs. ` +
- `Make sure a set of weights is provided for each model output.`);
- }
- return xWeight;
- }
- else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&
- typeof xWeight[Object.keys(xWeight)[0]] ===
- 'object') {
- const output = [];
- outputNames.forEach(outputName => {
- if (outputName in xWeight) {
- output.push(xWeight[outputName]);
- }
- else {
- output.push(null);
- }
- });
- return output;
- }
- else {
- throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
- `so ${weightType} must be either an array with ` +
- `${numOutputs} elements or an object with ${outputNames} keys. ` +
- `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
- }
- }
- /**
- * Standardize class weighting objects.
- *
- * This function takes a single class-weighting object, an array of them,
- * or a map from output name to class-weighting object. It compares it to the
- * output name(s) of the model, base on which it outputs an array of
- * class-weighting objects of which the length matches the number of outputs.
- *
- * @param classWeight Input class-weighting object(s).
- * @param outputNames All output name(s) of the model.
- * @return An array of class-weighting objects. The length of the array matches
- * the model's number of outputs.
- */
- function standardizeClassWeights(classWeight, outputNames) {
- return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
- }
- function standardizeSampleWeights(classWeight, outputNames) {
- return standardizeSampleOrClassWeights(classWeight, outputNames, 'sampleWeight');
- }
- /**
- * Standardize by-sample and/or by-class weights for training.
- *
- * Note that this function operates on one model output at a time. For a model
- * with multiple outputs, you must call this function multiple times.
- *
- * @param y The target tensor that the by-sample and/or by-class weight is for.
- * The values of y are assumed to encode the classes, either directly
- * as an integer index, or as one-hot encoding.
- * @param sampleWeight By-sample weights.
- * @param classWeight By-class weights: an object mapping class indices
- * (integers) to a weight (float) to apply to the model's loss for the
- * samples from this class during training. This can be useful to tell the
- * model to "pay more attention" to samples from an under-represented class.
- * @param sampleWeightMode The mode for the sample weights.
- * @return A Promise of weight tensor, of which the size of the first dimension
- * matches that of `y`.
- */
- async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
- if (sampleWeight != null || sampleWeightMode != null) {
- // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
- // string.
- throw new Error('Support sampleWeight is not implemented yet');
- }
- if (classWeight != null) {
- // Apply class weights per sample.
- const yClasses = tidy(() => {
- if (y.shape.length === 1) {
- // Assume class indices.
- return y.clone();
- }
- else if (y.shape.length === 2) {
- if (y.shape[1] > 1) {
- // Assume one-hot encoding of classes.
- const axis = 1;
- return y.argMax(axis);
- }
- else if (y.shape[1] === 1) {
- // Class index.
- return y.reshape([y.shape[0]]);
- }
- else {
- throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
- `during handling of class weights. The size is expected to be ` +
- `>= 1.`);
- }
- }
- else {
- throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
- `handling of class weights. The rank is expected to be 1 or 2.`);
- }
- });
- const yClassIndices = Array.from(await yClasses.data());
- dispose(yClasses);
- const classSampleWeight = [];
- yClassIndices.forEach(classIndex => {
- if (classWeight[classIndex] == null) {
- throw new Error(`classWeight must contain all classes in the training data. ` +
- `The class ${classIndex} exists in the data but not in ` +
- `classWeight`);
- }
- else {
- classSampleWeight.push(classWeight[classIndex]);
- }
- });
- return tensor1d(classSampleWeight, 'float32');
- }
- else {
- return null;
- }
- }
- /**
- * Apply per-sample weights on the loss values from a number of samples.
- *
- * @param losses Loss tensor of shape `[batchSize]`.
- * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
- * @returns Tensor of the same shape as`losses`.
- */
- function computeWeightedLoss$1(losses, sampleWeights) {
- return mul(losses, sampleWeights);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // Default batch size used during tensor-based validation.
- const DEFAULT_VALIDATION_BATCH_SIZE = 32;
- /**
- * Standardize the output of a dataset iterator for use by
- * LayersModel.fitDataset().
- *
- * @param model: A `tf.LayersModel` object.
- * @param iteratorOut The output of a dataset iterator. It is required to be
- * an object of the form `{xs: TensorOrArrayOrMap, ys:
- * TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single `tf.Tensor`,
- * a `tf.Tensor[]`, or a flat map from string names to `tf.Tensor`s.
- * @returns A flat array of `tf.Tensor` objects: the input `tf.Tensor`s
- * followed by the target `tf.Tensor`s. When `tf.Tensor`s are provided
- * as a map, the order in the resulting array is taken from the `inputNames`
- * and `outputNames` of the model.
- */
- function standardizeDataIteratorOutput(
- // Type `model` as `any` here to avoid circular dependency w/
- // training.ts.
- // tslint:disable-next-line:no-any
- model, iteratorOut) {
- let xs;
- let ys;
- const iteratorOutObj = iteratorOut;
- xs = iteratorOutObj['xs'];
- ys = iteratorOutObj['ys'];
- assert(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' +
- 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
- 'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
- 'string to Tensor. The provided Dataset instead generates ' +
- `${iteratorOut}`);
- const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
- const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
- const batchSize = flattenedXs[0].shape[0];
- assert(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
- `provides ${flattenedXs.length} inputs. (Expected input keys: ` +
- `${JSON.stringify(model.inputNames)})`);
- assert(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
- `provides ${flattenedYs.length} outputs. (Expected output keys: ` +
- `${JSON.stringify(model.outputNames)})`);
- for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
- assert(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` +
- `${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` +
- `expected ${batchSize} based on input ${model.inputNames[0]}.`);
- }
- for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
- assert(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` +
- `${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` +
- `expected ${batchSize} based on input ${model.inputNames[0]}.`);
- }
- return { xs: flattenedXs, ys: flattenedYs };
- }
- function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
- if (values instanceof Tensor) {
- return [values];
- }
- else if (Array.isArray(values)) {
- assert(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
- return values;
- }
- else {
- const result = [];
- // Check that all the required keys are available.
- for (const name of names) {
- if (values[name] == null) {
- throw new ValueError(`The feature data generated by the dataset lacks the required ` +
- `${inputOrOutput} key '${name}'.`);
- }
- result.push(values[name]);
- }
- return result;
- }
- }
- function standardizeTensorValidationData(data) {
- if (data.length === 3) {
- throw new NotImplementedError('Validation with sample weights is not implemented yet.');
- }
- return { xs: data[0], ys: data[1] };
- }
- async function fitDataset(
- // Type `model` as `any` here to avoid circular dependency w/
- // training.ts.
- // tslint:disable-next-line:no-any
- model, dataset, args) {
- const hasBatchesPerEpoch = args.batchesPerEpoch != null;
- assert(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' +
- 'LayersModel.compile(modelCompileConfig).');
- assert(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` +
- `but it is not provided in this call.`);
- assert(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` +
- `integer, but got ${args.epochs}`);
- assert(!hasBatchesPerEpoch ||
- (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
- `positive integer if specified, but got ${args.batchesPerEpoch}`);
- assert(
- // tslint:disable-next-line:no-any
- args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. ' +
- 'Use validationData instead.');
- if (model.isTraining) {
- throw new Error('Cannot start training because another fit() call is ongoing.');
- }
- model.isTraining = true;
- try {
- const doValidation = args.validationData != null;
- let valXs;
- let valYs;
- if (doValidation) {
- if (isDatasetObject(args.validationData)) {
- assert(args.validationBatches == null ||
- (args.validationBatches > 0 &&
- Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` +
- `config.validationBatches is expected not to be provided, ` +
- `or to be a positive integer, ` +
- `but got ${args.validationBatches}`);
- }
- else {
- const validationData = standardizeTensorValidationData(args.validationData);
- valXs = validationData.xs;
- valYs = validationData.ys;
- }
- }
- const trainFunction = model.makeTrainFunction();
- const outLabels = model.getDedupedMetricsNames();
- let callbackMetrics;
- if (doValidation) {
- callbackMetrics =
- outLabels.slice().concat(outLabels.map(n => 'val_' + n));
- }
- else {
- callbackMetrics = outLabels.slice();
- }
- const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
- const verbose = args.verbose == null ? 1 : args.verbose;
- const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, // Batch size determined by the dataset itself.
- doValidation, callbackMetrics);
- callbackList.setModel(model);
- model.history = history;
- await callbackList.onTrainBegin();
- model.stopTraining_ = false;
- let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
- let dataIterator = await dataset.iterator();
- while (epoch < args.epochs) {
- const epochLogs = {};
- await callbackList.onEpochBegin(epoch);
- let stepsDone = 0;
- let batchIndex = 0;
- if (!hasBatchesPerEpoch) {
- dataIterator = await dataset.iterator();
- }
- while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
- const iteratorOut = await dataIterator.next();
- // If `batchesPerEpoch` is specified, the dataset should not be
- // exhausted until all epoches are done.
- if (hasBatchesPerEpoch && iteratorOut.done) {
- console.warn('You provided `batchesPerEpoch` as ' +
- `${args.batchesPerEpoch}, ` +
- 'but your dataset iterator ran out of data after ' +
- `${stepsDone} batches; ` +
- 'interrupting training. Make sure that your ' +
- 'dataset can generate at least `batchesPerEpoch * epochs` ' +
- 'batches (in this case, ' +
- `${args.batchesPerEpoch * args.epochs} batches). ` +
- 'You may need to use the repeat() function when building ' +
- 'your dataset.');
- break;
- }
- if (iteratorOut.value != null) {
- const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
- const batchLogs = {};
- batchLogs['batch'] = batchIndex;
- batchLogs['size'] = xs[0].shape[0];
- await callbackList.onBatchBegin(batchIndex, batchLogs);
- const sampleWeights = [];
- if (args.classWeight != null) {
- const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
- for (let i = 0; i < standardClassWeights.length; ++i) {
- sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i]));
- }
- }
- // Train on batch.
- const ins = xs.concat(ys).concat(sampleWeights);
- const outs = trainFunction(ins);
- dispose(ins);
- for (let i = 0; i < outLabels.length; ++i) {
- const label = outLabels[i];
- const out = outs[i];
- batchLogs[label] = out;
- keep(out);
- }
- await callbackList.onBatchEnd(batchIndex, batchLogs);
- disposeTensorsInLogs(batchLogs);
- batchIndex++;
- stepsDone++;
- }
- if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
- iteratorOut.done) {
- // Epoch finished. Perform validation.
- if (doValidation) {
- let valOuts;
- if (isDatasetObject(args.validationData)) {
- valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches }));
- }
- else {
- valOuts = toList(model.evaluate(valXs, valYs, {
- batchSize: args.validationBatchSize == null ?
- DEFAULT_VALIDATION_BATCH_SIZE :
- args.validationBatchSize,
- verbose: 0
- }));
- }
- for (let i = 0; i < model.metricsNames.length; ++i) {
- epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
- }
- }
- // Call `break` to exit one epoch lopp after validation is done. If
- // config.batchesPerEpoch is specified, an epoch while loop will
- // stop when `stepsDone >= config.batchesPerEpoch`. When
- // config.batchesPerEpoch is not provided, the following `break` is
- // required to exit the while lopp after dataset is exhausted.
- break;
- }
- if (model.stopTraining_) {
- break;
- }
- }
- await callbackList.onEpochEnd(epoch, epochLogs);
- epoch++;
- if (model.stopTraining_) {
- break;
- }
- }
- await callbackList.onTrainEnd();
- await model.history.syncData();
- return model.history;
- }
- finally {
- model.isTraining = false;
- }
- }
- /** Helper function that determines number of steps (batches) per epoch. */
- function getStepsPerEpoch(dataset, args) {
- // Attempt to determine # of batches in an epoch.
- let stepsPerEpoch = null;
- if (args.batchesPerEpoch != null) {
- stepsPerEpoch = args.batchesPerEpoch;
- }
- else if (Number.isFinite(dataset.size)) {
- stepsPerEpoch = dataset.size;
- }
- return stepsPerEpoch;
- }
- // Check if provided object is a Dataset object by checking its .iterator
- // element.
- function isDatasetObject(dataset) {
- return (typeof dataset.iterator === 'function');
- }
- // Check if provided object is a LazyIterator object by checking it's .next
- // element.
- function isLazyIteratorObject(iterator) {
- return (typeof iterator.next === 'function');
- }
- async function evaluateDataset(
- // Type `model` as `any` here to avoid circular dependency w/
- // training.ts.
- // tslint:disable-next-line:no-any
- model, dataset, args) {
- args = args || {};
- const hasBatches = args.batches != null;
- const f = model.testFunction;
- let outs = [];
- if (args.verbose > 0) {
- throw new NotImplementedError('Verbose mode is not implemented yet.');
- }
- assert(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' +
- `received ${JSON.stringify(args.batches)}`);
- const dataIterator = isLazyIteratorObject(dataset) ?
- dataset :
- await dataset.iterator();
- // Keeps track of number of examples used in this evaluation.
- let numExamples = 0;
- let batch = 0;
- while (hasBatches ? batch < args.batches : true) {
- const iteratorOut = await dataIterator.next();
- outs = tidy(() => {
- if (iteratorOut.value) {
- // TODO(cais): Once real dataset is available, use
- // `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
- const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
- const xsAndYs = xs.concat(ys);
- const batchOuts = tidy(() => f(xsAndYs));
- dispose(xsAndYs);
- if (batch === 0) {
- for (let i = 0; i < batchOuts.length; ++i) {
- outs.push(scalar(0));
- }
- }
- const batchSize = xsAndYs[0].shape[0];
- for (let i = 0; i < batchOuts.length; ++i) {
- const batchOut = batchOuts[i];
- const oldScalar = outs[i];
- outs[i] =
- tidy(() => add$1(outs[i], mul(batchSize, batchOut)));
- if (batch > 0) {
- dispose(oldScalar);
- }
- }
- dispose(batchOuts);
- numExamples += batchSize;
- ++batch;
- }
- return outs;
- });
- if (iteratorOut.done) {
- if (hasBatches) {
- console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' +
- 'Interrupting evalution. Make sure that your ' +
- 'dataset can generate at least `batches` ' +
- `batches (in this case, ${args.batches} batches). ` +
- 'You may need to use the repeat() function when building ' +
- 'your dataset.');
- }
- break;
- }
- }
- for (let i = 0; i < outs.length; ++i) {
- const oldScalar = outs[i];
- outs[i] = div(outs[i], numExamples);
- dispose(oldScalar);
- }
- return singletonOrArray(outs);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- function checkBatchSize(batchSize) {
- assert(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`);
- }
- /**
- * Slice a Tensor or an Array of Tensors, by start and stop indices.
- *
- * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
- * function and `sliceArraysByIndices()` together.
- *
- * @param arrays: the input.
- * @param start: the starting index (inclusive).
- * @param stop: the stopping index (exclusive).
- * @returns The result of the slicing. If `arrays` is an `Array` of
- * `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
- * in the same way.
- */
- function sliceArrays(arrays, start, stop) {
- if (arrays == null) {
- return [null];
- }
- else if (Array.isArray(arrays)) {
- return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));
- }
- else { // Tensor.
- return sliceAlongFirstAxis(arrays, start, stop - start);
- }
- }
- /**
- * Slice a Tensor or an Array of Tensors, by random-order indices.
- *
- * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
- * function and `sliceArrays()` together.
- *
- * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
- * If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
- * same fashion.
- * @param indices The indices to use for slicing along the first (batch)
- * dimension.
- * @returns Result(s) of the slicing.
- */
- function sliceArraysByIndices(arrays, indices) {
- return tidy(() => {
- if (arrays == null) {
- return null;
- }
- else if (Array.isArray(arrays)) {
- return arrays.map(array => sliceArraysByIndices(array, indices));
- }
- else {
- // TODO(cais): indices should be a pre-constructed Tensor1D to avoid
- // tensor1d() calls.
- return gather$1(arrays, indices.dtype === 'int32' ? indices : indices.toInt());
- }
- });
- }
- /**
- * Returns a list of batch indices (tuples of indices).
- * @param size: Integer, total size of the data to slice into batches.
- * @param batchSize: Integer, batch size.
- * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
- * inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
- * that satisfy batchStart <= x < batchEnd.
- */
- function makeBatches(size, batchSize) {
- const output = [];
- let batchStart = 0;
- let batchEnd = null;
- while (batchStart < size) {
- batchEnd = batchStart + batchSize;
- if (batchEnd >= size) {
- batchEnd = size;
- }
- output.push([batchStart, batchEnd]);
- batchStart = batchEnd;
- }
- return output;
- }
- /**
- * Abstract fit function for `f(ins)`.
- * @param f A Function returning a list of tensors. For training, this
- * function is expected to perform the updates to the variables.
- * @param ins List of tensors to be fed to `f`.
- * @param outLabels List of strings, display names of the outputs of `f`.
- * @param batchSize Integer batch size or `== null` if unknown. Default : 32.
- * @param epochs Number of times to iterate over the data. Default : 1.
- * @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
- * @param callbacks List of callbacks to be called during training.
- * @param valF Function to call for validation.
- * @param valIns List of tensors to be fed to `valF`.
- * @param shuffle Whether to shuffle the data at the beginning of every
- * epoch. Default : true.
- * @param callbackMetrics List of strings, the display names of the metrics
- * passed to the callbacks. They should be the concatenation of the
- * display names of the outputs of `f` and the list of display names
- * of the outputs of `valF`.
- * @param initialEpoch Epoch at which to start training (useful for
- * resuming a previous training run). Default : 0.
- * @param stepsPerEpoch Total number of steps (batches on samples) before
- * declaring one epoch finished and starting the next epoch. Ignored with
- * the default value of `undefined` or `null`.
- * @param validationSteps Number of steps to run validation for (only if
- * doing validation from data tensors). Not applicable for tfjs-layers.
- * @returns A `History` object.
- */
- async function fitLoop(
- // Type `model` as `any` here to avoid circular dependency w/ training.ts.
- // tslint:disable-next-line:no-any
- model, f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
- if (batchSize == null) {
- batchSize = 32;
- }
- if (epochs == null) {
- epochs = 1;
- }
- if (shuffle$1 == null) {
- shuffle$1 = true;
- }
- if (initialEpoch == null) {
- initialEpoch = 0;
- }
- // TODO(cais): Change const to let below when implementing validation.
- let doValidation = false;
- if (valF != null && valIns != null) {
- doValidation = true;
- // TODO(cais): verbose message.
- }
- if (validationSteps != null) {
- doValidation = true;
- if (stepsPerEpoch == null) {
- throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' +
- 'i.e., `stepsPerEpoch` must be set.');
- }
- }
- const numTrainSamples = model.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
- let indexArray;
- if (numTrainSamples != null) {
- indexArray = range$1(0, numTrainSamples);
- }
- if (verbose == null) {
- verbose = 1;
- }
- const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics);
- callbackList.setModel(model);
- model.history = history;
- await callbackList.onTrainBegin();
- model.stopTraining_ = false;
- // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
- // TODO(cais): Pre-convert feeds for performance as in PyKeras.
- for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
- await callbackList.onEpochBegin(epoch);
- const epochLogs = {};
- if (stepsPerEpoch != null) {
- throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
- }
- else {
- if (shuffle$1 === 'batch') {
- throw new NotImplementedError('batch shuffling is not implemneted yet');
- }
- else if (shuffle$1) {
- shuffle(indexArray);
- }
- // Convert the potentially shuffled indices to Tensor1D, to avoid the
- // cost of repeated creation of Array1Ds later on.
- const epochIndexArray1D = tensor1d(indexArray);
- const batches = makeBatches(numTrainSamples, batchSize);
- for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
- const batchLogs = {};
- await callbackList.onBatchBegin(batchIndex, batchLogs);
- tidy(() => {
- const batchStart = batches[batchIndex][0];
- const batchEnd = batches[batchIndex][1];
- const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
- batchLogs['batch'] = batchIndex;
- batchLogs['size'] = batchEnd - batchStart;
- // TODO(cais): In ins, train flag can be a number, instead of an
- // Tensor? Do we need to handle this in tfjs-layers?
- const insBatch = sliceArraysByIndices(ins, batchIds);
- const outs = f(insBatch);
- for (let i = 0; i < outLabels.length; ++i) {
- const label = outLabels[i];
- const out = outs[i];
- batchLogs[label] = out;
- keep(out);
- // TODO(cais): Use scope() to avoid ownership.
- }
- if (batchIndex === batches.length - 1) { // Last batch.
- if (doValidation) {
- const valOuts = model.testLoop(valF, valIns, batchSize);
- // Porting Notes: In tfjs-layers, valOuts is always an Array.
- for (let i = 0; i < outLabels.length; ++i) {
- const label = outLabels[i];
- const out = valOuts[i];
- keep(out);
- // TODO(cais): Use scope() to avoid ownership.
- epochLogs['val_' + label] = out;
- }
- }
- }
- });
- await callbackList.onBatchEnd(batchIndex, batchLogs);
- disposeTensorsInLogs(batchLogs);
- if (model.stopTraining_) {
- break;
- }
- // TODO(cais): return outs as list of Tensor.
- }
- epochIndexArray1D.dispose();
- }
- // TODO(cais): Run validation at the end of the epoch.
- await callbackList.onEpochEnd(epoch, epochLogs);
- if (model.stopTraining_) {
- break;
- }
- }
- await callbackList.onTrainEnd();
- await model.history.syncData();
- return model.history;
- }
- async function fitTensors(
- // Type `model` as `any` here to avoid circular dependency w/ training.ts.
- // tslint:disable-next-line:no-any
- model, x, y, args = {}) {
- if (model.isTraining) {
- throw new Error('Cannot start training because another fit() call is ongoing.');
- }
- model.isTraining = true;
- let inputs;
- let targets;
- let inputValX;
- let inputValY;
- let valX;
- let valY;
- let sampleWeights;
- try {
- const batchSize = args.batchSize == null ? 32 : args.batchSize;
- checkBatchSize(batchSize);
- // Validate user data.
- // TODO(cais): Support sampleWeight.
- const checkBatchAxis = false;
- const standardizedOuts = await model.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
- inputs = standardizedOuts[0];
- targets = standardizedOuts[1];
- sampleWeights = standardizedOuts[2];
- // Prepare validation data.
- let doValidation = false;
- let valIns;
- if (args.validationData != null && args.validationData.length > 0) {
- doValidation = true;
- if (args.validationData.length === 2) {
- // config.validationData consists of valX and valY.
- inputValX = args.validationData[0];
- inputValY = args.validationData[1];
- }
- else if (args.validationData.length === 3) {
- throw new NotImplementedError('validationData including sample weights is not supported yet.');
- }
- else {
- throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` +
- `or 3 (valX, valY, valSampleWeight) items; ` +
- `${args.validationData} is invalid.`);
- }
- const checkBatchAxis = true;
- const valStandardized = await model.standardizeUserData(inputValX, inputValY, null, /** Unused sample weights. */ null, /** Unused class weights. */ checkBatchAxis, batchSize);
- valX = valStandardized[0];
- valY = valStandardized[1];
- valIns = valX.concat(valY);
- // TODO(cais): Add useLearningPhase data properly.
- }
- else if (args.validationSplit != null && args.validationSplit > 0 &&
- args.validationSplit < 1) {
- doValidation = true;
- // Porting Note: In tfjs-layers, inputs[0] is always a Tensor.
- const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
- const originalBatchSize = inputs[0].shape[0];
- valX = sliceArrays(inputs, splitAt, originalBatchSize);
- inputs = sliceArrays(inputs, 0, splitAt);
- valY = sliceArrays(targets, splitAt, originalBatchSize);
- targets = sliceArrays(targets, 0, splitAt);
- // TODO(cais): Once sampleWeights becomes available, slice it to get
- // valSampleWeights.
- valIns = valX.concat(valY);
- // TODO(cais): Add useLearningPhase data properly.
- }
- else if (args.validationSteps != null) {
- doValidation = true;
- // TODO(cais): Add useLearningPhase.
- }
- const ins = inputs.concat(targets).concat(sampleWeights);
- model.checkTrainableWeightsConsistency();
- // TODO(cais): Handle use_learning_phase and learning_phase?
- // Porting Note: Here we see a key deviation of tfjs-layers from
- // Keras.
- // Due to the imperative nature of tfjs-layers' backend (tfjs-core),
- // we do not construct symbolic computation graphs to embody the
- // training process. Instead, we define a function that performs the
- // training action. In PyKeras, the data (inputs and targets) are fed
- // through graph placeholders. In tfjs-layers, the data are fed as
- // function arguments. Since the function are defined below in the
- // scope, we don't have equivalents of PyKeras's
- // `_make_train_funciton`.
- const trainFunction = model.makeTrainFunction();
- const outLabels = model.getDedupedMetricsNames();
- let valFunction;
- let callbackMetrics;
- if (doValidation) {
- model.makeTestFunction();
- valFunction = model.testFunction;
- callbackMetrics =
- outLabels.slice().concat(outLabels.map(n => 'val_' + n));
- }
- else {
- valFunction = null;
- valIns = [];
- callbackMetrics = outLabels.slice();
- }
- const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
- const out = await fitLoop(model, trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
- return out;
- }
- finally {
- model.isTraining = false;
- // Memory clean up.
- disposeNewTensors(inputs, x);
- disposeNewTensors(targets, y);
- disposeNewTensors(valX, inputValX);
- disposeNewTensors(valY, inputValY);
- if (sampleWeights != null) {
- dispose(sampleWeights);
- }
- }
- // TODO(cais): Add value to outLabels.
- }
- /**
- * Ensure tensors all have a rank of at least 2.
- *
- * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
- * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
- */
- function ensureTensorsRank2OrHigher(tensors) {
- const outs = [];
- if (tensors instanceof Tensor) {
- tensors = [tensors];
- }
- // Make Tensors at least 2D.
- for (let i = 0; i < tensors.length; ++i) {
- const tensor = tensors[i];
- if (tensor.rank === 1) {
- outs.push(expandDims$1(tensor, 1));
- }
- else if (tensor.rank === 0) {
- throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
- '(scalar).');
- }
- else {
- outs.push(tensor);
- }
- }
- return outs;
- }
- /**
- * Compare a set of tensors with a reference (old) set, discard the ones
- * in the new set that are not present in the reference set.
- *
- * This method is used for memory clenaup during calls such as
- * LayersModel.fit().
- *
- * @param tensors New set which may contain Tensors not present in
- * `refTensors`.
- * @param refTensors Reference Tensor set.
- */
- // TODO(cais, kangyizhang): Deduplicate with tfjs-data.
- function disposeNewTensors(tensors, refTensors) {
- if (tensors == null) {
- return;
- }
- const oldTensorIds = [];
- if (refTensors instanceof Tensor) {
- oldTensorIds.push(refTensors.id);
- }
- else if (Array.isArray(refTensors)) {
- refTensors.forEach(t => oldTensorIds.push(t.id));
- }
- else if (refTensors != null) {
- // `oldTensors` is a map from string name to Tensor.
- for (const name in refTensors) {
- const oldTensor = refTensors[name];
- oldTensorIds.push(oldTensor.id);
- }
- }
- const tensorsToDispose = [];
- if (tensors instanceof Tensor) {
- if (oldTensorIds.indexOf(tensors.id) === -1) {
- tensorsToDispose.push(tensors);
- }
- }
- else if (Array.isArray(tensors)) {
- tensors.forEach(t => {
- if (oldTensorIds.indexOf(t.id) === -1) {
- tensorsToDispose.push(t);
- }
- });
- }
- else if (tensors != null) {
- // `oldTensors` is a map from string name to Tensor.
- for (const name in tensors) {
- const tensor = tensors[name];
- if (oldTensorIds.indexOf(tensor.id) === -1) {
- tensorsToDispose.push(tensor);
- }
- }
- }
- tensorsToDispose.forEach(t => {
- if (!t.isDisposed) {
- t.dispose();
- }
- });
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Helper function for polymorphic input data: 1. singleton Tensor.
- */
- function isDataTensor(x) {
- return x instanceof Tensor;
- }
- /**
- * Helper function for polymorphic input data: 2. Array of Tensor.
- */
- function isDataArray(x) {
- return Array.isArray(x);
- }
- /**
- * Helper function for polymorphic input data: 3. "dict" of Tensor.
- */
- function isDataDict(x) {
- return !isDataTensor(x) && !isDataArray(x);
- }
- /**
- * Normalizes inputs and targets provided by users.
- * @param data User-provided input data (polymorphic).
- * @param names An Array of expected Tensor names.
- * @param shapes Optional Array of expected Tensor shapes.
- * @param checkBatchAxis Whether to check that the batch axis of the arrays
- * match the expected value found in `shapes`.
- * @param exceptionPrefix String prefix used for exception formatting.
- * @returns List of standardized input Tensors (one Tensor per model input).
- * @throws ValueError: in case of improperly formatted user data.
- */
- function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
- if (names == null || names.length === 0) {
- // Check for the case where the model expected no data, but some data got
- // sent.
- if (data != null) {
- let gotUnexpectedData = false;
- if (isDataArray(data) && data.length > 0) {
- gotUnexpectedData = true;
- }
- else if (isDataDict(data)) {
- for (const key in data) {
- if (data.hasOwnProperty(key)) {
- gotUnexpectedData = true;
- break;
- }
- }
- }
- else {
- // `data` is a singleton Tensor in this case.
- gotUnexpectedData = true;
- }
- if (gotUnexpectedData) {
- throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
- `but got ${data}`);
- }
- }
- return [];
- }
- if (data == null) {
- return names.map(name => null);
- }
- let arrays;
- if (isDataDict(data)) {
- data = data;
- arrays = [];
- for (const name of names) {
- if (data[name] == null) {
- throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
- `${names}`);
- }
- arrays.push(data[name]);
- }
- }
- else if (isDataArray(data)) {
- data = data;
- if (data.length !== names.length) {
- throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
- `Tensors that you are passing to your model is not the size the ` +
- `model expected. Expected to see ${names.length} Tensor(s), but ` +
- `instead got the following list of Tensor(s): ${data}`);
- }
- arrays = data;
- }
- else {
- data = data;
- if (names.length > 1) {
- throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
- `but only received one Tensor. Found: Tensor with shape ${data.shape}`);
- }
- arrays = [data];
- }
- arrays = ensureTensorsRank2OrHigher(arrays);
- // Check shape compatibility.
- if (shapes != null) {
- for (let i = 0; i < names.length; ++i) {
- if (shapes[i] == null) {
- continue;
- }
- const array = arrays[i];
- if (array.shape.length !== shapes[i].length) {
- throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
- `to have ${shapes[i].length} dimension(s). but got array with ` +
- `shape ${array.shape}`);
- }
- for (let j = 0; j < shapes[i].length; ++j) {
- if (j === 0 && !checkBatchAxis) {
- // Skip the first (batch) axis.
- continue;
- }
- const dim = array.shape[j];
- const refDim = shapes[i][j];
- if (refDim != null && refDim >= 0 && dim !== refDim) {
- throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
- `to have shape [${shapes[i]}], but got array with shape ` +
- `[${array.shape}].`);
- }
- }
- }
- }
- return arrays;
- }
- /**
- * User input validation for Tensors.
- * @param inputs `Array` of `tf.Tensor`s for inputs.
- * @param targets `Array` of `tf.Tensor`s for targets.
- * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
- * @throws ValueError: in case of incorrectly formatted data.
- */
- function checkArrayLengths(inputs, targets, weights) {
- const setX = unique$1(inputs.map(input => input.shape[0]));
- setX.sort();
- const setY = unique$1(targets.map(target => target.shape[0]));
- setY.sort();
- // TODO(cais): Check `weights` as well.
- if (setX.length > 1) {
- throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
- `Got array shapes: ` +
- `${JSON.stringify(inputs.map(input => input.shape))}`);
- }
- if (setY.length > 1) {
- throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
- `Got array shapes: ` +
- `${JSON.stringify(targets.map(target => target.shape))}`);
- }
- if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
- throw new ValueError(`Input Tensors should have the same number of samples as target ` +
- `Tensors. Found ${setX[0]} input sample(s) and ${setY[0]} target ` +
- `sample(s).`);
- }
- }
- /**
- * Validation on the compatibility of targes and loss functions.
- *
- * This helps prevent users from using loss functions incorrectly.
- *
- * @param targets `Array` of `tf.Tensor`s of targets.
- * @param lossFns `Array` of loss functions.
- * @param outputShapes `Array` of shapes of model outputs.
- */
- function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
- // TODO(cais): Dedicated test coverage?
- const keyLosses = [
- meanSquaredError$1, binaryCrossentropy,
- categoricalCrossentropy
- ];
- for (let i = 0; i < targets.length; ++i) {
- const y = targets[i];
- const loss = lossFns[i];
- const shape = outputShapes[i];
- if (loss == null) {
- continue;
- }
- if (loss === categoricalCrossentropy) {
- if (y.shape[y.shape.length - 1] === 1) {
- throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
- `a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
- `expects targets to be binary matrices (1s and 0s) of shape ` +
- `[samples, classes].`);
- // TODO(cais): Example code in error message.
- }
- }
- if (keyLosses.indexOf(loss) !== -1) {
- const slicedYShape = y.shape.slice(1);
- const slicedShape = shape.slice(1);
- for (let j = 0; j < slicedYShape.length; ++j) {
- const targetDim = slicedYShape[j];
- const outDim = slicedShape[j];
- if (outDim != null && targetDim !== outDim) {
- throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
- `output of shape ${shape}, while using a loss function that ` +
- `expects targets to have the same shape as the output.`);
- }
- }
- }
- }
- }
- /**
- * Check inputs provided by the user.
- *
- * Porting Note: This corresponds to _standardize_input_data() in Python
- * Keras. Because of the strong typing in TF.js, we do not need to convert
- * the data. Specifically:
- * 1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
- * example. We don't need to worry about that here because there is no
- * widely popular javascript/typesdcript equivalent of pandas (so far).
- * If one becomes available in the future, we can add support.
- * 2) in PyKeras, inputs can be Python dict. But here we are stipulating
- * that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We
- * may add support for `Object` data inputs in the future when the need
- * arises.
- *
- * Instead, we perform basic checks for number of parameters and shapes.
- *
- * @param data: The input data.
- * @param names: Name for the inputs, from the model.
- * @param shapes: Expected shapes for the input data, from the model.
- * @param checkBatchAxis: Whether the size along the batch axis (i.e., the
- * first dimension) will be checked for matching.
- * @param exceptionPrefix: Execption prefix message, used in generating error
- * messages.
- * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
- */
- function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
- let arrays;
- if (Array.isArray(data)) {
- if (data.length !== names.length) {
- throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
- `Tensors that you are passing to your model is not the size the ` +
- `the model expected. Expected to see ${names.length} Tensor(s),` +
- ` but instead got ${data.length} Tensors(s).`);
- }
- arrays = data;
- }
- else {
- if (names.length > 1) {
- throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
- `but only received one Tensor. Found: array with shape ` +
- `${JSON.stringify(data.shape)}.`);
- }
- arrays = [data];
- }
- if (shapes != null) {
- for (let i = 0; i < names.length; ++i) {
- if (shapes[i] == null) {
- continue;
- }
- const array = arrays[i];
- if (array.shape.length !== shapes[i].length) {
- throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
- `to have ${shapes[i].length} dimension(s), but got array with ` +
- `shape ${JSON.stringify(array.shape)}`);
- }
- for (let j = 0; j < shapes[i].length; ++j) {
- if (j === 0 && !checkBatchAxis) {
- continue;
- }
- const dim = array.shape[j];
- const refDim = shapes[i][j];
- if (refDim != null) {
- if (refDim !== dim) {
- throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
- `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
- `got array with shape ${JSON.stringify(array.shape)}.`);
- }
- }
- }
- }
- }
- }
- /**
- * Maps metric functions to model outputs.
- * @param metrics An shortcut strings name, metric function, `Array` or dict
- * (`Object`) of metric functions.
- * @param outputNames An `Array` of the names of model outputs.
- * @returns An `Array` (one entry per model output) of `Array` of metric
- * functions. For instance, if the model has 2 outputs, and for the first
- * output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
- * and just `binaryAccuracy` for the second output, the `Array` would look
- * like:
- * `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
- * @throws TypeError: incompatible metrics format.
- */
- function collectMetrics(metrics, outputNames) {
- if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
- return outputNames.map(name => []);
- }
- let wrappedMetrics;
- if (typeof metrics === 'string' || typeof metrics === 'function') {
- wrappedMetrics = [metrics];
- }
- else if (Array.isArray(metrics) || typeof metrics === 'object') {
- wrappedMetrics = metrics;
- }
- else {
- throw new TypeError('Type of metrics argument not understood. Expected an string,' +
- `function, Array, or Object, found: ${metrics}`);
- }
- if (Array.isArray(wrappedMetrics)) {
- // We then apply all metrics to all outputs.
- return outputNames.map(name => wrappedMetrics);
- }
- else {
- // In this case, metrics is a dict.
- const nestedMetrics = [];
- for (const name of outputNames) {
- let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
- if (!Array.isArray(outputMetrics)) {
- outputMetrics = [outputMetrics];
- }
- nestedMetrics.push(outputMetrics);
- }
- return nestedMetrics;
- }
- }
- const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
- /**
- * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
- * for training, evaluation, prediction and saving.
- *
- * `tf.LayersModel` is the basic unit of training, inference and evaluation in
- * TensorFlow.js. To create a `tf.LayersModel`, use `tf.LayersModel`.
- *
- * See also:
- * `tf.Sequential`, `tf.loadLayersModel`.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- class LayersModel extends Container {
- constructor(args) {
- super(args);
- this.isTraining = false;
- }
- /**
- * Print a text summary of the model's layers.
- *
- * The summary includes
- * - Name and type of all layers that comprise the model.
- * - Output shape(s) of the layers
- * - Number of weight parameters of each layer
- * - If the model has non-sequential-like topology, the inputs each layer
- * receives
- * - The total number of trainable and non-trainable parameters of the model.
- *
- * ```js
- * const input1 = tf.input({shape: [10]});
- * const input2 = tf.input({shape: [20]});
- * const dense1 = tf.layers.dense({units: 4}).apply(input1);
- * const dense2 = tf.layers.dense({units: 8}).apply(input2);
- * const concat = tf.layers.concatenate().apply([dense1, dense2]);
- * const output =
- * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
- *
- * const model = tf.model({inputs: [input1, input2], outputs: output});
- * model.summary();
- * ```
- *
- * @param lineLength Custom line length, in number of characters.
- * @param positions Custom widths of each of the columns, as either
- * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
- * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
- * right-most (i.e., ending) position of a column.
- * @param printFn Custom print function. Can be used to replace the default
- * `console.log`. For example, you can use `x => {}` to mute the printed
- * messages in the console.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- summary(lineLength, positions, printFn = console.log) {
- if (!this.built) {
- throw new ValueError(`This model has never been called, thus its weights have not been ` +
- `created yet. So no summary can be displayed. Build the model ` +
- `first (e.g., by calling it on some test data).`);
- }
- printSummary(this, lineLength, positions, printFn);
- }
- /**
- * Configures and prepares the model for training and evaluation. Compiling
- * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
- * or `evaluate` on an un-compiled model will throw an error.
- *
- * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
- * metrics to be used for fitting and evaluating this model.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- compile(args) {
- if (args.loss == null) {
- args.loss = [];
- }
- this.loss = args.loss;
- if (typeof args.optimizer === 'string') {
- this.optimizer_ = getOptimizer(args.optimizer);
- this.isOptimizerOwned = true;
- }
- else {
- if (!(args.optimizer instanceof Optimizer)) {
- throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
- }
- this.optimizer_ = args.optimizer;
- this.isOptimizerOwned = false;
- }
- // TODO(cais): Add lossWeights.
- // TODO(cais): Add sampleWeightMode.
- // Prepare loss functions.
- let lossFunctions = [];
- if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
- typeof args.loss !== 'function') {
- args.loss = args.loss;
- for (const name in args.loss) {
- if (this.outputNames.indexOf(name) === -1) {
- throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
- `Only expected the following keys: ${this.outputNames}`);
- }
- }
- for (const name of this.outputNames) {
- if (args.loss[name] == null) {
- console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
- `this was done on purpose, and we will not be expecting data ` +
- `to be passed to ${name} during training`);
- }
- lossFunctions.push(get(args.loss[name]));
- }
- }
- else if (Array.isArray(args.loss)) {
- if (args.loss.length !== this.outputs.length) {
- throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
- `model output. The model has ${this.outputs.length} output(s), ` +
- `but you passed loss=${args.loss}.`);
- }
- const theLosses = args.loss;
- lossFunctions = theLosses.map(l => get(l));
- }
- else {
- const lossFunction = get(args.loss);
- this.outputs.forEach(_ => {
- lossFunctions.push(lossFunction);
- });
- }
- this.lossFunctions = lossFunctions;
- this.feedOutputNames = [];
- this.feedOutputShapes = [];
- this.feedLossFns = [];
- for (let i = 0; i < this.outputs.length; ++i) {
- // TODO(cais): Logic for skipping target(s).
- const shape = this.internalOutputShapes[i];
- const name = this.outputNames[i];
- this.feedOutputNames.push(name);
- this.feedOutputShapes.push(shape);
- this.feedLossFns.push(this.lossFunctions[i]);
- }
- // TODO(cais): Add logic for output masks.
- // TODO(cais): Add logic for sample weights.
- const skipTargetIndices = [];
- // Prepare metrics.
- this.metrics = args.metrics;
- // TODO(cais): Add weightedMetrics.
- this.metricsNames = ['loss'];
- this.metricsTensors = [];
- // Compute total loss.
- // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
- // Here, metricsTensors are TypeScript functions. This difference is due
- // to the difference in symbolic/imperative property of the backends.
- nameScope('loss', () => {
- for (let i = 0; i < this.outputs.length; ++i) {
- if (skipTargetIndices.indexOf(i) !== -1) {
- continue;
- }
- // TODO(cais): Add weightedLoss, sampleWeight and mask.
- // The following line should be weightedLoss
- const weightedLoss = this.lossFunctions[i];
- if (this.outputs.length > 1) {
- this.metricsTensors.push([weightedLoss, i]);
- this.metricsNames.push(this.outputNames[i] + '_loss');
- }
- }
- // Porting Note: Due to the imperative nature of the backend, we calculate
- // the regularizer penalties in the totalLossFunction, instead of here.
- });
- const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
- // TODO(cais): Add nestedWeightedMetrics.
- /**
- * Helper function used in loop below.
- */
- const appendMetric = (outputIndex, metricName, metricTensor) => {
- if (this.outputNames.length > 1) {
- metricName = this.outputNames[outputIndex] + '_' + metricName;
- }
- this.metricsNames.push(metricName);
- this.metricsTensors.push([metricTensor, outputIndex]);
- };
- nameScope('metric', () => {
- for (let i = 0; i < this.outputs.length; ++i) {
- if (skipTargetIndices.indexOf(i) !== -1) {
- continue;
- }
- const outputMetrics = nestedMetrics[i];
- // TODO(cais): Add weights and outputWeightedMetrics.
- // TODO(cais): Add optional arg `weights` to the following function.
- const handleMetrics = (metrics) => {
- const metricNamePrefix = '';
- let metricName;
- let accFn;
- let weightedMetricFn;
- // TODO(cais): Use 'weights_' for weighted metrics.
- for (const metric of metrics) {
- if (typeof metric === 'string' &&
- ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
- -1) {
- const outputShape = this.internalOutputShapes[i];
- if (outputShape[outputShape.length - 1] === 1 ||
- this.lossFunctions[i] === binaryCrossentropy) {
- // case: binary accuracy/crossentropy.
- if (['accuracy', 'acc'].indexOf(metric) !== -1) {
- accFn = binaryAccuracy;
- }
- else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
- accFn = binaryCrossentropy$1;
- }
- }
- else if (this.lossFunctions[i] ===
- sparseCategoricalCrossentropy) {
- // case: categorical accuracy / crossentropy with sparse
- // targets.
- if (['accuracy', 'acc'].indexOf(metric) !== -1) {
- accFn = sparseCategoricalAccuracy;
- }
- else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
- accFn = sparseCategoricalCrossentropy$1;
- }
- }
- else {
- // case: categorical accuracy / crossentropy.
- if (['accuracy', 'acc'].indexOf(metric) !== -1) {
- accFn = categoricalAccuracy;
- }
- else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
- accFn = categoricalCrossentropy$1;
- }
- }
- let suffix;
- if (['accuracy', 'acc'].indexOf(metric) !== -1) {
- suffix = 'acc';
- }
- else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
- suffix = 'ce';
- }
- // TODO(cais): Add weighting actually.
- weightedMetricFn = accFn;
- metricName = metricNamePrefix + suffix;
- }
- else {
- const metricFn = get$1(metric);
- // TODO(cais): Add weighting actually.
- weightedMetricFn = metricFn;
- metricName =
- metricNamePrefix + getLossOrMetricName(metric);
- }
- // TODO(cais): Add weighting and masking to metricResult.
- let metricResult;
- nameScope(metricName, () => {
- metricResult = weightedMetricFn;
- });
- appendMetric(i, metricName, metricResult);
- }
- };
- handleMetrics(outputMetrics);
- // TODO(cais): Call handleMetrics with weights.
- }
- });
- // Porting Notes: Given the imperative backend of tfjs-core,
- // there is no need for constructing the symbolic graph and placeholders.
- this.collectedTrainableWeights = this.trainableWeights;
- }
- /**
- * Check trainable weights count consistency.
- *
- * This will raise a warning if `this.trainableWeights` and
- * `this.collectedTrainableWeights` are inconsistent (i.e., have different
- * numbers of parameters).
- * Inconsistency will typically arise when one modifies `model.trainable`
- * without calling `model.compile()` again.
- */
- checkTrainableWeightsConsistency() {
- if (this.collectedTrainableWeights == null) {
- return;
- }
- if (this.trainableWeights.length !==
- this.collectedTrainableWeights.length) {
- console.warn('Discrepancy between trainableweights and collected trainable ' +
- 'weights. Did you set `model.trainable` without calling ' +
- '`model.compile()` afterwards?');
- }
- }
- /**
- * Returns the loss value & metrics values for the model in test mode.
- *
- * Loss and metrics are specified during `compile()`, which needs to happen
- * before calls to `evaluate()`.
- *
- * Computation is done in batches.
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
- * const result = model.evaluate(
- * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
- * result.print();
- * ```
- *
- * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
- * model has multiple inputs.
- * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
- * model has multiple outputs.
- * @param args A `ModelEvaluateArgs`, containing optional fields.
- *
- * @return `Scalar` test loss (if the model has a single output and no
- * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
- * and/or metrics). The attribute `model.metricsNames`
- * will give you the display labels for the scalar outputs.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- evaluate(x, y, args = {}) {
- const batchSize = args.batchSize == null ? 32 : args.batchSize;
- checkBatchSize(batchSize);
- // TODO(cais): Standardize `config.sampleWeights` as well.
- // Validate user data.
- const checkBatchAxis = true;
- const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
- try {
- // TODO(cais): If uses `useLearningPhase`, set the corresponding element
- // of the input to 0.
- const ins = standardizedOuts[0].concat(standardizedOuts[1]);
- this.makeTestFunction();
- const f = this.testFunction;
- const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
- return singletonOrArray(testOuts);
- }
- finally {
- disposeNewTensors(standardizedOuts[0], x);
- disposeNewTensors(standardizedOuts[1], y);
- }
- }
- // TODO(cais): Add code snippet below once real dataset objects are
- // available.
- /**
- * Evaluate model using a dataset object.
- *
- * Note: Unlike `evaluate()`, this method is asynchronous (`async`);
- *
- * @param dataset A dataset object. Its `iterator()` method is expected
- * to generate a dataset iterator object, the `next()` method of which
- * is expected to produce data batches for evaluation. The return value
- * of the `next()` call ought to contain a boolean `done` field and a
- * `value` field. The `value` field is expected to be an array of two
- * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
- * case is for models with exactly one input and one output (e.g..
- * a sequential model). The latter case is for models with multiple
- * inputs and/or multiple outputs. Of the two items in the array, the
- * first is the input feature(s) and the second is the output target(s).
- * @param args A configuration object for the dataset-based evaluation.
- * @returns Loss and metric values as an Array of `Scalar` objects.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async evaluateDataset(dataset, args) {
- this.makeTestFunction();
- return evaluateDataset(this, dataset, args);
- }
- /**
- * Get number of samples provided for training, evaluation or prediction.
- *
- * @param ins Input `tf.Tensor`.
- * @param batchSize Integer batch size, optional.
- * @param steps Total number of steps (batches of samples) before
- * declaring loop finished. Optional.
- * @param stepsName The public API's parameter name for `steps`.
- * @returns Number of samples provided.
- */
- checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
- let numSamples;
- if (steps != null) {
- numSamples = null;
- if (batchSize != null) {
- throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
- `Got batchSize = ${batchSize}`);
- }
- }
- else if (ins != null) {
- if (Array.isArray(ins)) {
- numSamples = ins[0].shape[0];
- }
- else {
- numSamples = ins.shape[0];
- }
- }
- else {
- throw new ValueError(`Either the input data should have a defined shape, or ` +
- `${stepsName} shoud be specified.`);
- }
- return numSamples;
- }
- /**
- * Execute internal tensors of the model with input data feed.
- * @param inputs Input data feed. Must match the inputs of the model.
- * @param outputs Names of the output tensors to be fetched. Must match
- * names of the SymbolicTensors that belong to the graph.
- * @returns Fetched values for `outputs`.
- */
- execute(inputs, outputs) {
- if (Array.isArray(outputs) && outputs.length === 0) {
- throw new ValueError('`outputs` is an empty Array, which is not allowed.');
- }
- const outputsIsArray = Array.isArray(outputs);
- const outputNames = (outputsIsArray ? outputs : [outputs]);
- const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
- // Format the input into a FeedDict.
- const feedDict = new FeedDict();
- if (inputs instanceof Tensor) {
- inputs = [inputs];
- }
- if (Array.isArray(inputs)) {
- if (inputs.length !== this.inputs.length) {
- throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
- `does not match the number of inputs of this model ` +
- `(${this.inputs.length}).`);
- }
- for (let i = 0; i < this.inputs.length; ++i) {
- feedDict.add(this.inputs[i], inputs[i]);
- }
- }
- else {
- for (const input of this.inputs) {
- const tensorValue = inputs[input.name];
- if (tensorValue == null) {
- throw new ValueError(`No value is provided for the model's input ${input.name}`);
- }
- feedDict.add(input, tensorValue);
- }
- }
- // Run execution.
- const executeOutputs = execute(outputSymbolicTensors, feedDict);
- return outputsIsArray ? executeOutputs : executeOutputs[0];
- }
- /**
- * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
- */
- retrieveSymbolicTensors(symbolicTensorNames) {
- const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
- let outputsRemaining = symbolicTensorNames.length;
- for (const layer of this.layers) {
- const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
- const layerOutputNames = layerOutputs.map(output => output.name);
- for (let i = 0; i < symbolicTensorNames.length; ++i) {
- const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
- if (index !== -1) {
- outputSymbolicTensors[i] = layerOutputs[index];
- outputsRemaining--;
- }
- if (outputsRemaining === 0) {
- break;
- }
- }
- if (outputsRemaining === 0) {
- break;
- }
- }
- if (outputsRemaining > 0) {
- const remainingNames = [];
- outputSymbolicTensors.forEach((tensor, i) => {
- if (tensor == null) {
- remainingNames.push(symbolicTensorNames[i]);
- }
- });
- throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
- `${JSON.stringify(remainingNames)}`);
- }
- return outputSymbolicTensors;
- }
- /**
- * Helper method to loop over some data in batches.
- *
- * Porting Note: Not using the functional approach in the Python equivalent
- * due to the imperative backend.
- * Porting Note: Does not support step mode currently.
- *
- * @param ins: input data
- * @param batchSize: integer batch size.
- * @param verbose: verbosity model
- * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
- * `tf.Tensor` (if multipe outputs).
- */
- predictLoop(ins, batchSize = 32, verbose = false) {
- return tidy(() => {
- const numSamples = this.checkNumSamples(ins);
- if (verbose) {
- throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
- }
- // Sample-based predictions.
- // Porting Note: Tensor currently does not support sliced assignments as
- // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
- // iterating over the batches.
- const batches = makeBatches(numSamples, batchSize);
- const outsBatches = this.outputs.map(output => []);
- // TODO(cais): Can the scope() be pushed down inside the for loop?
- for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
- const batchOuts = tidy(() => {
- const batchStart = batches[batchIndex][0];
- const batchEnd = batches[batchIndex][1];
- // TODO(cais): Take care of the case of the last element is a flag for
- // training/test.
- const insBatch = sliceArrays(ins, batchStart, batchEnd);
- // Construct the feeds for execute();
- const feeds = [];
- if (Array.isArray(insBatch)) {
- for (let i = 0; i < insBatch.length; ++i) {
- feeds.push({ key: this.inputs[i], value: insBatch[i] });
- }
- }
- else {
- feeds.push({ key: this.inputs[0], value: insBatch });
- }
- const feedDict = new FeedDict(feeds);
- return execute(this.outputs, feedDict);
- });
- batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
- }
- return singletonOrArray(outsBatches.map(batches => concat(batches, 0)));
- });
- }
- /**
- * Generates output predictions for the input samples.
- *
- * Computation is done in batches.
- *
- * Note: the "step" mode of predict() is currently not supported.
- * This is because the TensorFlow.js core backend is imperative only.
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
- * ```
- *
- * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
- * the model has multiple inputs.
- * @param args A `ModelPredictArgs` object containing optional fields.
- *
- * @return Prediction results as a `tf.Tensor`(s).
- *
- * @exception ValueError In case of mismatch between the provided input data
- * and the model's expectations, or in case a stateful model receives a
- * number of samples that is not a multiple of the batch size.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- predict(x, args = {}) {
- const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
- checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
- try {
- // TODO(cais): Take care of stateful models.
- // if (this.stateful) ...
- // TODO(cais): Take care of the learning_phase boolean flag.
- // if (this.useLearningPhase) ...
- const batchSize = args.batchSize == null ? 32 : args.batchSize;
- checkBatchSize(batchSize);
- return this.predictLoop(xsRank2OrHigher, batchSize);
- }
- finally {
- disposeNewTensors(xsRank2OrHigher, x);
- }
- }
- /**
- * Returns predictions for a single batch of samples.
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.predictOnBatch(tf.ones([8, 10])).print();
- * ```
- * @param x: Input samples, as a Tensor (for models with exactly one
- * input) or an array of Tensors (for models with more than one input).
- * @return Tensor(s) of predictions
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- predictOnBatch(x) {
- checkInputData(x, this.inputNames, this.feedInputShapes, true);
- // TODO(cais): Take care of the learning_phase boolean flag.
- // if (this.useLearningPhase) ...
- const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
- return this.predictLoop(x, batchSize);
- }
- standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
- // TODO(cais): Add sampleWeight, classWeight
- if (this.optimizer_ == null) {
- throw new RuntimeError('You must compile a model before training/testing. Use ' +
- 'LayersModel.compile(modelCompileArgs).');
- }
- const outputShapes = [];
- for (let i = 0; i < this.feedOutputShapes.length; ++i) {
- const outputShape = this.feedOutputShapes[i];
- const lossFn = this.feedLossFns[i];
- if (lossFn === sparseCategoricalCrossentropy) {
- outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
- }
- else {
- // Porting Note: Because of strong typing `lossFn` must be a function.
- outputShapes.push(outputShape);
- }
- }
- x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
- y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
- // TODO(cais): Standardize sampleWeights & classWeights.
- checkArrayLengths(x, y, null);
- // TODO(cais): Check sampleWeights as well.
- checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
- if (this.stateful && batchSize != null && batchSize > 0) {
- if (x[0].shape[0] % batchSize !== 0) {
- throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
- `number of samples that is divisible by the batch size ` +
- `${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
- }
- }
- return [x, y];
- }
- async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
- const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
- // TODO(cais): Handle sampleWeights.
- if (sampleWeight != null) {
- throw new Error('sample weight is not supported yet.');
- }
- let standardSampleWeights = null;
- if (classWeight != null) {
- const classWeights = standardizeClassWeights(classWeight, this.outputNames);
- standardSampleWeights = [];
- for (let i = 0; i < classWeights.length; ++i) {
- standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
- }
- }
- // TODO(cais): Deal with the case of model.stateful == true.
- return [standardXs, standardYs, standardSampleWeights];
- }
- /**
- * Loop over some test data in batches.
- * @param f A Function returning a list of tensors.
- * @param ins Array of tensors to be fed to `f`.
- * @param batchSize Integer batch size or `null` / `undefined`.
- * @param verbose verbosity mode.
- * @param steps Total number of steps (batches of samples) before
- * declaring test finished. Ignored with the default value of `null` /
- * `undefined`.
- * @returns Array of Scalars.
- */
- testLoop(f, ins, batchSize, verbose = 0, steps) {
- return tidy(() => {
- const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
- const outs = [];
- if (verbose > 0) {
- throw new NotImplementedError('Verbose mode is not implemented yet.');
- }
- // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
- if (steps != null) {
- throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
- }
- else {
- const batches = makeBatches(numSamples, batchSize);
- const indexArray = tensor1d(range$1(0, numSamples));
- for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
- const batchStart = batches[batchIndex][0];
- const batchEnd = batches[batchIndex][1];
- const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
- // TODO(cais): In ins, train flag can be a number, instead of an
- // Tensor? Do we need to handle this in tfjs-layers?
- const insBatch = sliceArraysByIndices(ins, batchIds);
- const batchOuts = f(insBatch);
- if (batchIndex === 0) {
- for (let i = 0; i < batchOuts.length; ++i) {
- outs.push(scalar(0));
- }
- }
- for (let i = 0; i < batchOuts.length; ++i) {
- const batchOut = batchOuts[i];
- outs[i] =
- add$1(outs[i], mul(batchEnd - batchStart, batchOut));
- }
- }
- for (let i = 0; i < outs.length; ++i) {
- outs[i] = div(outs[i], numSamples);
- }
- }
- return outs;
- });
- }
- getDedupedMetricsNames() {
- const outLabels = this.metricsNames;
- // Rename duplicated metrics names (can happen with an output layer
- // shared among multiple dataflows).
- const dedupedOutLabels = [];
- for (let i = 0; i < outLabels.length; ++i) {
- const label = outLabels[i];
- let newLabel = label;
- if (count(outLabels, label) > 1) {
- const dupIndex = count(outLabels.slice(0, i), label);
- newLabel += `_${dupIndex}`;
- }
- dedupedOutLabels.push(newLabel);
- }
- return dedupedOutLabels;
- }
- /**
- * Creates a function that performs the following actions:
- *
- * 1. computes the losses
- * 2. sums them to get the total loss
- * 3. call the optimizer computes the gradients of the LayersModel's
- * trainable weights w.r.t. the total loss and update the variables
- * 4. calculates the metrics
- * 5. returns the values of the losses and metrics.
- */
- makeTrainFunction() {
- return (data) => {
- const lossValues = [];
- const inputs = data.slice(0, this.inputs.length);
- const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
- const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
- const metricsValues = [];
- // Create a function that computes the total loss based on the
- // inputs. This function is used for obtaining gradients through
- // backprop.
- const totalLossFunction = () => {
- const feeds = [];
- for (let i = 0; i < this.inputs.length; ++i) {
- feeds.push({ key: this.inputs[i], value: inputs[i] });
- }
- const feedDict = new FeedDict(feeds);
- const outputs = execute(this.outputs, feedDict, { 'training': true });
- // TODO(cais): Take care of the case of multiple outputs from a
- // single layer?
- let totalLoss;
- for (let i = 0; i < this.lossFunctions.length; ++i) {
- const lossFunction = this.lossFunctions[i];
- let loss = lossFunction(targets[i], outputs[i]);
- if (sampleWeights[i] != null) {
- loss = computeWeightedLoss$1(loss, sampleWeights[i]);
- }
- // TODO(cais): push Scalar instead.
- const meanLoss = mean(loss);
- // TODO(cais): Use a scope() instead, to avoid ownership.
- lossValues.push(meanLoss);
- if (i === 0) {
- totalLoss = loss;
- }
- else {
- totalLoss = add$1(totalLoss, loss);
- }
- }
- // Compute the metrics.
- // TODO(cais): These should probably be calculated outside
- // totalLossFunction to benefit speed?
- for (let i = 0; i < this.metricsTensors.length; ++i) {
- let weightedMetric;
- if (this.outputs.length > 1 && i < this.outputs.length) {
- weightedMetric = lossValues[i];
- }
- else {
- const metric = this.metricsTensors[i][0];
- const outputIndex = this.metricsTensors[i][1];
- weightedMetric =
- mean(metric(targets[outputIndex], outputs[outputIndex]));
- }
- keep(weightedMetric);
- // TODO(cais): Use a scope() instead, to avoid ownership.
- metricsValues.push(weightedMetric);
- }
- totalLoss = mean(totalLoss);
- // Add regularizer penalties.
- this.calculateLosses().forEach(regularizerLoss => {
- totalLoss = add$1(totalLoss, regularizerLoss);
- });
- return totalLoss;
- };
- const variables = this.collectedTrainableWeights.map(param => param.read());
- const returnCost = true;
- const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
- return [totalLossValue].concat(metricsValues);
- };
- }
- /**
- * Create a function which, when invoked with an array of `tf.Tensor`s as a
- * batch of inputs, returns the prespecified loss and metrics of the model
- * under the batch of input data.
- */
- makeTestFunction() {
- this.testFunction = (data) => {
- return tidy(() => {
- const valOutputs = [];
- let totalLoss;
- const inputs = data.slice(0, this.inputs.length);
- const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
- const feeds = [];
- for (let i = 0; i < this.inputs.length; ++i) {
- feeds.push({ key: this.inputs[i], value: inputs[i] });
- }
- const feedDict = new FeedDict(feeds);
- const outputs = execute(this.outputs, feedDict);
- // Compute total loss.
- for (let i = 0; i < this.lossFunctions.length; ++i) {
- const lossFunction = this.lossFunctions[i];
- // TODO(cais): Add sample weighting and replace the simple
- // averaging.
- const loss = mean(lossFunction(targets[i], outputs[i]));
- if (i === 0) {
- totalLoss = loss;
- }
- else {
- totalLoss = add$1(totalLoss, loss);
- }
- valOutputs.push(totalLoss);
- }
- // Compute the metrics.
- for (let i = 0; i < this.metricsTensors.length; ++i) {
- const metric = this.metricsTensors[i][0];
- const outputIndex = this.metricsTensors[i][1];
- // TODO(cais): Replace K.mean() with a proper weighting function.
- const meanMetric = mean(metric(targets[outputIndex], outputs[outputIndex]));
- valOutputs.push(meanMetric);
- }
- return valOutputs;
- });
- };
- }
- /**
- * Trains the model for a fixed number of epochs (iterations on a
- * dataset).
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
- * for (let i = 1; i < 5 ; ++i) {
- * const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
- * batchSize: 4,
- * epochs: 3
- * });
- * console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
- * }
- * ```
- *
- * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
- * model has multiple inputs. If all inputs in the model are named, you
- * can also pass a dictionary mapping input names to `tf.Tensor`s.
- * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
- * the model has multiple outputs. If all outputs in the model are named,
- * you can also pass a dictionary mapping output names to `tf.Tensor`s.
- * @param args A `ModelFitArgs`, containing optional fields.
- *
- * @return A `History` instance. Its `history` attribute contains all
- * information collected during training.
- *
- * @exception ValueError In case of mismatch between the provided input
- * data and what the model expects.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async fit(x, y, args = {}) {
- return fitTensors(this, x, y, args);
- }
- // TODO(cais): Add code snippet below when it's possible to instantiate
- // actual dataset objects.
- /**
- * Trains the model using a dataset object.
- *
- * @param dataset A dataset object. Its `iterator()` method is expected
- * to generate a dataset iterator object, the `next()` method of which
- * is expected to produce data batches for training. The return value
- * of the `next()` call ought to contain a boolean `done` field and a
- * `value` field. The `value` field is expected to be an array of two
- * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
- * case is for models with exactly one input and one output (e.g..
- * a sequential model). The latter case is for models with multiple
- * inputs and/or multiple outputs.
- * Of the two items in the array, the first is the input feature(s) and
- * the second is the output target(s).
- * @param args A `ModelFitDatasetArgs`, containing optional fields.
- *
- * @return A `History` instance. Its `history` attribute contains all
- * information collected during training.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async fitDataset(dataset, args) {
- return fitDataset(this, dataset, args);
- }
- /**
- * Runs a single gradient update on a single batch of data.
- *
- * This method differs from `fit()` and `fitDataset()` in the following
- * regards:
- * - It operates on exactly one batch of data.
- * - It returns only the loss and matric values, instead of
- * returning the batch-by-batch loss and metric values.
- * - It doesn't support fine-grained options such as verbosity and
- * callbacks.
- *
- * @param x Input data. It could be one of the following:
- * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
- * multiple inputs).
- * - An Object mapping input names to corresponding `tf.Tensor` (if the
- * model has named inputs).
- * @param y Target darta. It could be either a `tf.Tensor` a multiple
- * `tf.Tensor`s. It should be consistent with `x`.
- * @returns Training loss or losses (in case the model has
- * multiple outputs), along with metrics (if any), as numbers.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async trainOnBatch(x, y) {
- // TODO(cais): Support sampleWeight and classWeight.
- // TODO(cais): Support Dataset objects.
- const standardizeOut = await this.standardizeUserData(x, y);
- const inputs = standardizeOut[0];
- const targets = standardizeOut[1];
- const trainFunction = this.makeTrainFunction();
- const losses = trainFunction(inputs.concat(targets));
- const lossValues = [];
- for (const loss of losses) {
- const v = await loss.data();
- lossValues.push(v[0]);
- }
- dispose(losses);
- return singletonOrArray(lossValues);
- }
- /**
- * Extract weight values of the model.
- *
- * @param config: An instance of `io.SaveConfig`, which specifies
- * model-saving options such as whether only trainable weights are to be
- * saved.
- * @returns A `NamedTensorMap` mapping original weight names (i.e.,
- * non-uniqueified weight names) to their values.
- */
- getNamedWeights(config) {
- const namedWeights = [];
- const trainableOnly = config != null && config.trainableOnly;
- const weights = trainableOnly ? this.trainableWeights : this.weights;
- const weightValues = this.getWeights(trainableOnly);
- for (let i = 0; i < weights.length; ++i) {
- if (trainableOnly && !weights[i].trainable) {
- // Optionally skip non-trainable weights.
- continue;
- }
- namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
- }
- return namedWeights;
- }
- /**
- * Setter used for force stopping of LayersModel.fit() (i.e., training).
- *
- * Example:
- *
- * ```js
- * const input = tf.input({shape: [10]});
- * const output = tf.layers.dense({units: 1}).apply(input);
- * const model = tf.model({inputs: [input], outputs: [output]});
- * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
- * const xs = tf.ones([8, 10]);
- * const ys = tf.zeros([8, 1]);
- *
- * const history = await model.fit(xs, ys, {
- * epochs: 10,
- * callbacks: {
- * onEpochEnd: async (epoch, logs) => {
- * if (epoch === 2) {
- * model.stopTraining = true;
- * }
- * }
- * }
- * });
- *
- * // There should be only 3 values in the loss array, instead of 10
- * values,
- * // due to the stopping after 3 epochs.
- * console.log(history.history.loss);
- * ```
- */
- set stopTraining(stop) {
- this.stopTraining_ = stop;
- }
- get stopTraining() {
- return this.stopTraining_;
- }
- get optimizer() {
- return this.optimizer_;
- }
- set optimizer(optimizer) {
- if (this.optimizer_ !== optimizer) {
- this.optimizer_ = optimizer;
- this.isOptimizerOwned = false;
- }
- }
- dispose() {
- const result = super.dispose();
- if (result.refCountAfterDispose === 0 && this.optimizer != null &&
- this.isOptimizerOwned) {
- const numTensorsBeforeOptmizerDisposal = memory().numTensors;
- this.optimizer_.dispose();
- result.numDisposedVariables +=
- numTensorsBeforeOptmizerDisposal - memory().numTensors;
- }
- return result;
- }
- getLossIdentifiers() {
- let lossNames;
- if (typeof this.loss === 'string') {
- lossNames = toSnakeCase(this.loss);
- }
- else if (Array.isArray(this.loss)) {
- for (const loss of this.loss) {
- if (typeof loss !== 'string') {
- throw new Error('Serialization of non-string loss is not supported.');
- }
- }
- lossNames = this.loss.map(name => toSnakeCase(name));
- }
- else {
- const outputNames = Object.keys(this.loss);
- lossNames = {};
- const losses = this.loss;
- for (const outputName of outputNames) {
- if (typeof losses[outputName] === 'string') {
- lossNames[outputName] =
- toSnakeCase(losses[outputName]);
- }
- else {
- throw new Error('Serialization of non-string loss is not supported.');
- }
- }
- }
- return lossNames;
- }
- getMetricIdentifiers() {
- if (typeof this.metrics === 'string' ||
- typeof this.metrics === 'function') {
- return [toSnakeCase(getLossOrMetricName(this.metrics))];
- }
- else if (Array.isArray(this.metrics)) {
- return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric)));
- }
- else {
- const metricsIdentifiers = {};
- for (const key in this.metrics) {
- metricsIdentifiers[key] =
- toSnakeCase(getLossOrMetricName(this.metrics[key]));
- }
- return metricsIdentifiers;
- }
- }
- getTrainingConfig() {
- return {
- loss: this.getLossIdentifiers(),
- metrics: this.getMetricIdentifiers(),
- optimizer_config: {
- class_name: this.optimizer.getClassName(),
- config: this.optimizer.getConfig()
- }
- };
- // TODO(cais): Add weight_metrics when they are supported.
- // TODO(cais): Add sample_weight_mode when it's supported.
- // TODO(cais): Add loss_weights when it's supported.
- }
- loadTrainingConfig(trainingConfig) {
- if (trainingConfig.weighted_metrics != null) {
- throw new Error('Loading weight_metrics is not supported yet.');
- }
- if (trainingConfig.loss_weights != null) {
- throw new Error('Loading loss_weights is not supported yet.');
- }
- if (trainingConfig.sample_weight_mode != null) {
- throw new Error('Loading sample_weight_mode is not supported yet.');
- }
- const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
- const optimizer = deserialize(tsConfig);
- let loss;
- if (typeof trainingConfig.loss === 'string') {
- loss = toCamelCase(trainingConfig.loss);
- }
- else if (Array.isArray(trainingConfig.loss)) {
- loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
- }
- else if (trainingConfig.loss != null) {
- loss = {};
- for (const key in trainingConfig.loss) {
- loss[key] = toCamelCase(trainingConfig.loss[key]);
- }
- }
- let metrics;
- if (Array.isArray(trainingConfig.metrics)) {
- metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
- }
- else if (trainingConfig.metrics != null) {
- metrics = {};
- for (const key in trainingConfig.metrics) {
- metrics[key] = toCamelCase(trainingConfig.metrics[key]);
- }
- }
- this.compile({ loss, metrics, optimizer });
- }
- /**
- * Save the configuration and/or weights of the LayersModel.
- *
- * An `IOHandler` is an object that has a `save` method of the proper
- * signature defined. The `save` method manages the storing or
- * transmission of serialized data ("artifacts") that represent the
- * model's topology and weights onto or via a specific medium, such as
- * file downloads, local storage, IndexedDB in the web browser and HTTP
- * requests to a server. TensorFlow.js provides `IOHandler`
- * implementations for a number of frequently used saving mediums, such as
- * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
- * for more details.
- *
- * This method also allows you to refer to certain types of `IOHandler`s
- * as URL-like string shortcuts, such as 'localstorage://' and
- * 'indexeddb://'.
- *
- * Example 1: Save `model`'s topology and weights to browser [local
- * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('localstorage://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 2. Saving `model`'s topology and weights to browser
- * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('indexeddb://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 3. Saving `model`'s topology and weights as two files
- * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
- * browser.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * const saveResults = await model.save('downloads://my-model-1');
- * ```
- *
- * Example 4. Send `model`'s topology and weights to an HTTP server.
- * See the documentation of `tf.io.http` for more details
- * including specifying request parameters and implementation of the
- * server.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * const saveResults = await model.save('http://my-server/model/upload');
- * ```
- *
- * @param handlerOrURL An instance of `IOHandler` or a URL-like,
- * scheme-based string shortcut for `IOHandler`.
- * @param config Options for saving the model.
- * @returns A `Promise` of `SaveResult`, which summarizes the result of
- * the saving, such as byte sizes of the saved artifacts for the model's
- * topology and weight values.
- *
- * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
- */
- async save(handlerOrURL, config) {
- if (typeof handlerOrURL === 'string') {
- const handlers = getSaveHandlers(handlerOrURL);
- if (handlers.length === 0) {
- throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
- }
- else if (handlers.length > 1) {
- throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
- `URL '${handlerOrURL}'`);
- }
- handlerOrURL = handlers[0];
- }
- if (handlerOrURL.save == null) {
- throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
- 'provided does not have the `save` attribute defined.');
- }
- const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config));
- const returnString = false;
- const unusedArg = null;
- const modelConfig = this.toJSON(unusedArg, returnString);
- const modelArtifacts = {
- modelTopology: modelConfig,
- format: LAYERS_MODEL_FORMAT_NAME,
- generatedBy: `TensorFlow.js tfjs-layers v${version$1}`,
- convertedBy: null,
- };
- const includeOptimizer = config == null ? false : config.includeOptimizer;
- if (includeOptimizer && this.optimizer != null) {
- modelArtifacts.trainingConfig = this.getTrainingConfig();
- const weightType = 'optimizer';
- const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType);
- weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
- weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
- }
- if (this.userDefinedMetadata != null) {
- // Check serialized size of user-defined metadata.
- const checkSize = true;
- checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
- modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
- }
- modelArtifacts.weightData = weightDataAndSpecs.data;
- modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
- return handlerOrURL.save(modelArtifacts);
- }
- /**
- * Set user-defined metadata.
- *
- * The set metadata will be serialized together with the topology
- * and weights of the model during `save()` calls.
- *
- * @param setUserDefinedMetadata
- */
- setUserDefinedMetadata(userDefinedMetadata) {
- checkUserDefinedMetadata(userDefinedMetadata, this.name);
- this.userDefinedMetadata = userDefinedMetadata;
- }
- /**
- * Get user-defined metadata.
- *
- * The metadata is supplied via one of the two routes:
- * 1. By calling `setUserDefinedMetadata()`.
- * 2. Loaded during model loading (if the model is constructed
- * via `tf.loadLayersModel()`.)
- *
- * If no user-defined metadata is available from either of the
- * two routes, this function will return `undefined`.
- */
- getUserDefinedMetadata() {
- return this.userDefinedMetadata;
- }
- }
- // The class name is 'Model' rather than 'LayersModel' for backwards
- // compatibility since this class name shows up in the serialization format.
- /** @nocollapse */
- LayersModel.className = 'Model';
- registerClass(LayersModel);
- /**
- * A `tf.Functional` is an alias to `tf.LayersModel`.
- *
- * See also:
- * `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
- */
- /** @doc {heading: 'Models', subheading: 'Classes'} */
- class Functional extends LayersModel {
- }
- Functional.className = 'Functional';
- registerClass(Functional);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Parses a JSON model configuration file and returns a model instance.
- *
- * ```js
- * // This example shows how to serialize a model using `toJSON()` and
- * // deserialize it as another model using `tf.models.modelFromJSON()`.
- * // Note: this example serializes and deserializes only the topology
- * // of the model; the weights of the loaded model will be different
- * // from those of the the original model, due to random weight
- * // initialization.
- * // To load the topology and weights of a model, use `tf.loadLayersModel()`.
- * const model1 = tf.sequential();
- * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
- * // Serialize `model1` as a JSON object.
- * const model1JSON = model1.toJSON(null, false);
- * model1.summary();
- *
- * const model2 = await tf.models.modelFromJSON(model1JSON);
- * model2.summary();
- * ```
- *
- * @param modelAndWeightsConfig JSON object or string encoding a model and
- * weights configuration. It can also be only the topology JSON of the
- * model, in which case the weights will not be loaded.
- * @param custom_objects Optional dictionary mapping names
- * (strings) to custom classes or functions to be
- * considered during deserialization.
- * @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled).
- */
- async function modelFromJSON(modelAndWeightsConfig, customObjects) {
- if (!('modelTopology' in modelAndWeightsConfig)) {
- modelAndWeightsConfig = { modelTopology: modelAndWeightsConfig };
- }
- modelAndWeightsConfig = modelAndWeightsConfig;
- let modelTopology = modelAndWeightsConfig.modelTopology;
- if (modelTopology['model_config'] != null) {
- // If the model-topology JSON contains a 'model_config' field, then it is
- // a full model JSON (e.g., from `keras.Model.save()`), which contains
- // not only the model's architecture in its 'model_config' field, but
- // additional information such as the model's optimizer. We use only the
- // 'model_config' field currently.
- modelTopology = modelTopology['model_config'];
- }
- const tsConfig = convertPythonicToTs(modelTopology);
- const model = deserialize(tsConfig, customObjects);
- if (modelAndWeightsConfig.weightsManifest != null) {
- // Load the weight values keyed by the original tensor names in the model
- // file that was loaded. These should match the keys of the weight
- // manifest.
- const weightValues = await loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(weight => weight.originalName));
- // Map the weights to the unique tensor names generated during model loading
- const uniqueWeightValues = {};
- for (const weight of model.weights) {
- uniqueWeightValues[weight.originalName] =
- weightValues[weight.originalName];
- }
- model.loadWeights(uniqueWeightValues);
- // Dispose temporary weight values.
- dispose(weightValues);
- }
- return model;
- }
- /**
- * Load a model, including its topology and optionally weights. See the
- * Tutorial named "How to import a Keras Model" for usage examples.
- *
- * Example 1: Save `model`'s topology and weights to browser [local
- * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('localstorage://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 2. Saving `model`'s topology and weights to browser
- * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('indexeddb://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 3. Load a model from user-selected files from HTML
- * [file input
- * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
- *
- * ```js
- * // Note: this code snippet will not work without the HTML elements in the
- * // page
- * const jsonUpload = document.getElementById('json-upload');
- * const weightsUpload = document.getElementById('weights-upload');
- *
- * const model = await tf.loadLayersModel(
- * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
- * ```
- *
- * Example 4. Load a model from an HTTP server.
- *
- * ```js
- * const model = await
- * tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
- * model.summary();
- * ```
- *
- * @param pathOrIOHandler Can be either of the two formats
- * 1. A string path to the `ModelAndWeightsConfig` JSON describing
- * the model in the canonical TensorFlow.js format. This path will be
- * interpreted as a relative HTTP path, to which `fetch` will be used to
- * request the model topology and weight manifest JSON.
- * The content of the JSON file is assumed to be a JSON object with the
- * following fields and values:
- * - 'modelTopology': A JSON object that can be either of:
- * 1. a model architecture JSON consistent with the format of the return
- * value of `keras.Model.to_json()`
- * 2. a full model JSON in the format of `keras.models.save_model()`.
- * - 'weightsManifest': A TensorFlow.js weights manifest.
- * See the Python converter function `save_model()` for more details.
- * It is also assumed that model weights can be accessed from relative
- * paths described by the `paths` fields in weights manifest.
- * 2. An `tf.io.IOHandler` object that loads model artifacts with its `load`
- * method.
- * @param options Optional configuration arguments for the model loading,
- * including:
- * - `strict`: Require that the provided weights exactly match those required
- * by the layers. Default true. Passing false means that both extra
- * weights and missing weights will be silently ignored.
- * - `onProgress`: A progress callback of the form:
- * `(fraction: number) => void`. This callback can be used to monitor the
- * model-loading process.
- * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
- * loaded.
- */
- async function loadLayersModelInternal(pathOrIOHandler, options) {
- if (options == null) {
- options = {};
- }
- if (typeof pathOrIOHandler === 'string') {
- const handlers = getLoadHandlers(pathOrIOHandler, options);
- if (handlers.length === 0) {
- // For backward compatibility: if no load handler can be found,
- // assume it is a relative http path.
- // TODO(cais): Reformat the args into a single `LoadOptions` once the core
- // is refactored.
- handlers.push(browserHTTPRequest(pathOrIOHandler, options));
- }
- else if (handlers.length > 1) {
- throw new ValueError(`Found more than one (${handlers.length}) load handlers for ` +
- `URL '${pathOrIOHandler}'`);
- }
- pathOrIOHandler = handlers[0];
- }
- return loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options);
- }
- /**
- * Load a model and optionally its weights, using an IOHandler object.
- *
- * @param handler The instance of `IOHandler` to be used during the model
- * loading.
- * @param customObjects Any optional custom objects to be used during model
- * loading.
- * @param strict Whether the weight loading will be done in strict mode.
- * Default: `true`.
- */
- async function loadLayersModelFromIOHandler(handler, customObjects, options) {
- if (options == null) {
- options = {};
- }
- if (handler.load == null) {
- throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' +
- 'does not have the `load` method implemented.');
- }
- const artifacts = await handler.load();
- let modelTopology = artifacts.modelTopology;
- if (modelTopology['model_config'] != null) {
- modelTopology = modelTopology['model_config'];
- }
- const strict = options.strict == null ? true : options.strict;
- // If weights are provided and the weight-loading mode is strict, use
- // fast weight initialization. This skips costly initializers such as
- // 'orthogonal' and saves unnecessary computation in cases where
- // the initialized weight values will immediately be overwritten by
- // loaded weight values.
- const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
- const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
- const trainingConfig = artifacts.trainingConfig;
- if (trainingConfig != null) {
- model.loadTrainingConfig(trainingConfig);
- }
- if (artifacts.userDefinedMetadata != null) {
- model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
- }
- // If weightData is present, load the weights into the model.
- if (artifacts.weightData != null) {
- // Loading weights requires weightSpecs.
- if (artifacts.weightSpecs == null) {
- throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' +
- 'Therefore loading of weights cannot proceed.');
- }
- const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs);
- model.loadWeights(modelWeights, strict);
- if (model.optimizer != null && optimizerWeights.length > 0) {
- await model.optimizer.setWeights(optimizerWeights);
- }
- // Dispose temporary weight values.
- dispose(modelWeights);
- dispose(optimizerWeights.map(w => w.tensor));
- }
- return model;
- }
- function decodeModelAndOptimizerWeights(buffer, specs) {
- const name2Tensor = decodeWeights(buffer, specs);
- const modelWeights = {};
- const optimizerWeights = [];
- specs.forEach(spec => {
- if (spec.group === 'optimizer') {
- optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] });
- }
- else {
- modelWeights[spec.name] = name2Tensor[spec.name];
- }
- });
- return { modelWeights, optimizerWeights };
- }
- /**
- * A model with a stack of layers, feeding linearly from one to the next.
- *
- * `tf.sequential` is a factory function that creates an instance of
- * `tf.Sequential`.
- *
- * ```js
- * // Define a model for linear regression.
- * const model = tf.sequential();
- * model.add(tf.layers.dense({units: 1, inputShape: [1]}));
- *
- * // Prepare the model for training: Specify the loss and the optimizer.
- * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
- *
- * // Generate some synthetic data for training.
- * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
- * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
- *
- * // Train the model using the data then do inference on a data point the
- * // model hasn't seen:
- * await model.fit(xs, ys);
- * model.predict(tf.tensor2d([5], [1, 1])).print();
- * ```
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- class Sequential extends LayersModel {
- constructor(args) {
- super({ inputs: [], outputs: [] });
- args = args || {};
- this.trainable = true;
- this.built = false;
- // Set model name.
- this.name = (args.name != null) ? args.name : getUid('sequential_');
- // Add to the model any layers passed to the constructor.
- if (args.layers != null) {
- for (const layer of args.layers) {
- this.add(layer);
- }
- }
- }
- // Helper function to Sequential.add Throws if the new output shape will be
- // invalid.
- checkShape(layer) {
- const shape = layer.inboundNodes[0].outputTensors[0].shape;
- if (shape.some(x => x < 0)) {
- throw new ValueError('Negative dimension size caused by adding layer ' +
- `${layer.name} with input shape [` +
- `${layer.inboundNodes[0].inputTensors[0].shape}]`);
- }
- }
- /**
- * Adds a layer instance on top of the layer stack.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.dense({units: 8, inputShape: [1]}));
- * model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
- * model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
- * // Note that the untrained model is random at this point.
- * model.predict(tf.randomNormal([10, 1])).print();
- * ```
- * @param layer Layer instance.
- *
- * @exception ValueError In case the `layer` argument does not know its
- * input shape.
- * @exception ValueError In case the `layer` argument has multiple output
- * tensors, or is already connected somewhere else (forbidden in
- * `Sequential` models).
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- add(layer) {
- const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
- let modelLayer;
- if (isLayerModelInstance) {
- modelLayer = layer;
- if (modelLayer.outputs.length !== 1) {
- throw new ValueError('All layers in a Sequential model ' +
- 'should have a single output tensor. ' +
- 'For multi-output layers, ' +
- 'use the functional API.');
- }
- if (modelLayer.inputs.length !== 1) {
- throw new ValueError('All layers in a Sequential model ' +
- 'should have a single input tensor. ' +
- 'For multi-input layers, ' +
- 'use the functional API.');
- }
- }
- if (this.outputs.length === 0) {
- // first layer in model: check that it is an input layer
- if (layer.inboundNodes.length === 0) {
- // create an input layer
- if (layer.batchInputShape == null) {
- throw new ValueError('The first layer in a Sequential model must ' +
- 'get an `inputShape` or `batchInputShape` argument.');
- }
- // Instantiate the input layer.
- const x = Input({
- batchShape: layer.batchInputShape,
- dtype: layer.dtype,
- name: layer.name + '_input'
- });
- // This will build the current layer and create the node connecting
- // the current layer to the input layer we just created.
- layer.apply(x);
- }
- if (isLayerModelInstance) {
- this.outputs = modelLayer.outputs;
- this.inputs = modelLayer.inputs;
- }
- else {
- if (layer.inboundNodes.length !== 1) {
- throw new ValueError('A layer added to a Sequential model must not already be ' +
- `connected somewhere else. LayersModel received layer ${layer.name} ` +
- `which has ${layer.inboundNodes.length} pre-existing inbound ` +
- 'connections.');
- }
- if (layer.inboundNodes[0].outputTensors.length !== 1) {
- throw new ValueError('All layers in a Sequential model ' +
- 'should have a single output tensor. ' +
- 'For multi-output layers, ' +
- 'use the functional API.');
- }
- this.checkShape(layer);
- this.outputs = [layer.inboundNodes[0].outputTensors[0]];
- this.inputs = getSourceInputs(this.outputs[0]);
- }
- this.inboundNodes = [];
- // We create an input node, which we will keep updated
- // as we add more layers.
- // (This call has side effects.)
- // tslint:disable-next-line:no-unused-expression
- new Node({
- outboundLayer: this,
- inboundLayers: [],
- nodeIndices: [],
- tensorIndices: [],
- inputTensors: this.inputs,
- outputTensors: this.outputs,
- // no model-level masking for now
- inputMasks: pyListRepeat(null, this.inputs.length),
- outputMasks: [null],
- inputShapes: this.inputs.map(x => x.shape),
- outputShapes: this.outputs[0].shape
- });
- }
- else {
- const outputTensor = layer.apply(this.outputs[0]);
- if (Array.isArray(outputTensor)) {
- throw new TypeError('All layers in a Sequential model ' +
- 'should have a single output tensor. ' +
- 'For multi-output layers, ' +
- 'use the functional API.');
- }
- this.checkShape(layer);
- this.outputs = [outputTensor];
- // update self.inbound_nodes
- this.inboundNodes[0].outputTensors = this.outputs;
- this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
- }
- this.layers.push(layer);
- this.built = false;
- }
- /**
- * Removes the last layer in the model.
- *
- * @exception TypeError if there are no layers in the model.
- */
- pop() {
- if (this.layers.length === 0) {
- throw new TypeError('There are no layers in the model.');
- }
- this.layers.pop();
- if (this.layers.length === 0) {
- this.outputs = [];
- this.inboundNodes = [];
- this.outboundNodes = [];
- }
- else {
- const lastLayerIndex = this.layers.length - 1;
- this.layers[lastLayerIndex].outboundNodes = [];
- this.outputs = [this.layers[lastLayerIndex].output];
- // update self.inbound_nodes
- this.inboundNodes[0].outputTensors = this.outputs;
- this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
- }
- }
- call(inputs, kwargs) {
- if (this.model == null) {
- this.build();
- }
- return this.model.call(inputs, kwargs);
- }
- build(inputShape) {
- // Call `getExactlyOneShape` without using its return value,
- // to verify that exactly one input shape is provided.
- getExactlyOneShape(inputShape);
- if (this.inputs.length === 0 || this.outputs.length === 0) {
- throw new TypeError('Sequential model cannot be built: model is empty.' +
- ' Add some layers first.');
- }
- // actually create the model
- this.model = new LayersModel({
- inputs: this.inputs,
- outputs: this.outputs[0],
- name: this.name + '_model'
- });
- this.model.trainable = this.trainable;
- // mirror model attributes
- this.supportsMasking = this.model.supportsMasking;
- // TODO(michaelterry): Add caches
- this.inputLayers = this.model.inputLayers;
- this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
- this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
- this.outputLayers = this.model.outputLayers;
- this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
- this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
- this.nodesByDepth = this.model.nodesByDepth;
- this.containerNodes = this.model.containerNodes;
- this.outputNames = this.model.outputNames;
- this.inputNames = this.model.inputNames;
- // TODO(michaelterry): Add feedInputNames, feedInputs, if needed.
- // TODO(michaelterry): Add callbackModel if needed.
- this.built = true;
- }
- countParams() {
- if (!this.built) {
- this.build();
- }
- return super.countParams();
- }
- /**
- * Print a text summary of the Sequential model's layers.
- *
- * The summary includes
- * - Name and type of all layers that comprise the model.
- * - Output shape(s) of the layers
- * - Number of weight parameters of each layer
- * - The total number of trainable and non-trainable parameters of the
- * model.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(
- * tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
- * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
- *
- * model.summary();
- * ```
- *
- * @param lineLength Custom line length, in number of characters.
- * @param positions Custom widths of each of the columns, as either
- * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
- * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
- * right-most (i.e., ending) position of a column.
- * @param printFn Custom print function. Can be used to replace the default
- * `console.log`. For example, you can use `x => {}` to mute the printed
- * messages in the console.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- summary(lineLength, positions, printFn = console.log) {
- if (!this.built) {
- this.build();
- }
- super.summary(lineLength, positions, printFn);
- }
- /**
- * Sets the weights of the model.
- *
- * @param weights Should be a list of Tensors with shapes and types matching
- * the output of `model.getWeights()`.
- */
- setWeights(weights) {
- if (this.model == null) {
- this.build();
- }
- this.model.setWeights(weights);
- }
- /**
- * Returns the loss value & metrics values for the model in test mode.
- *
- * Loss and metrics are specified during `compile()`, which needs to happen
- * before calls to `evaluate()`.
- *
- * Computation is done in batches.
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
- * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
- * batchSize: 4,
- * });
- * result.print();
- * ```
- *
- * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
- * model has multiple inputs.
- * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
- * model has multiple outputs.
- * @param args A `ModelEvaluateConfig`, containing optional fields.
- *
- * @return `Scalar` test loss (if the model has a single output and no
- * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
- * and/or metrics). The attribute `model.metricsNames`
- * will give you the display labels for the scalar outputs.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- evaluate(x, y, args = {}) {
- if (!this.built) {
- throw new RuntimeError('The model needs to be compiled before being used.');
- }
- return this.model.evaluate(x, y, args);
- }
- // TODO(cais): Add code snippet below once real dataset objects are
- // available.
- /**
- * Evaluate model using a dataset object.
- *
- * Note: Unlike `evaluate()`, this method is asynchronous (`async`);
- *
- * @param dataset A dataset object. Its `iterator()` method is expected
- * to generate a dataset iterator object, the `next()` method of which
- * is expected to produce data batches for evaluation. The return value
- * of the `next()` call ought to contain a boolean `done` field and a
- * `value` field. The `value` field is expected to be an array of two
- * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
- * case is for models with exactly one input and one output (e.g..
- * a sequential model). The latter case is for models with multiple
- * inputs and/or multiple outputs. Of the two items in the array, the
- * first is the input feature(s) and the second is the output target(s).
- * @param args A configuration object for the dataset-based evaluation.
- * @returns Loss and metric values as an Array of `Scalar` objects.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async evaluateDataset(dataset, args) {
- if (!this.built) {
- throw new RuntimeError('The model needs to be compiled before being used.');
- }
- return this.model.evaluateDataset(dataset, args);
- }
- /**
- * Generates output predictions for the input samples.
- *
- * Computation is done in batches.
- *
- * Note: the "step" mode of predict() is currently not supported.
- * This is because the TensorFow.js core backend is imperative only.
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.predict(tf.ones([2, 10])).print();
- * ```
- *
- * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
- * the model has multiple inputs.
- * @param conifg A `ModelPredictConfig` object containing optional fields.
- *
- * @return `tf.Tensor`(s) of predictions.
- *
- * @exception ValueError In case of mismatch between the provided input data
- * and the model's expectations, or in case a stateful model receives a
- * number of samples that is not a multiple of the batch size.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- predict(x, args = {}) {
- if (this.model == null) {
- this.build();
- }
- return this.model.predict(x, args);
- }
- /**
- * Returns predictions for a single batch of samples.
- *
- * @param x: Input samples, as a Tensor, or list of Tensors (if the model
- * has multiple inputs).
- * @return Tensor(s) of predictions
- */
- predictOnBatch(x) {
- if (this.model == null) {
- this.build();
- }
- return this.model.predictOnBatch(x);
- }
- /**
- * See `LayersModel.compile`.
- *
- * @param args
- */
- compile(args) {
- this.build();
- this.model.compile(args);
- this.optimizer_ = this.model.optimizer;
- // tslint:disable-next-line:no-any
- this.isOptimizerOwned = this.model.isOptimizerOwned;
- this.loss = this.model.loss;
- this.metrics = this.model.metrics;
- // TODO(cais): Add this.lossWeights, this.sampleWeightMode,
- // this.weightedMetrics, this.targets.
- this.metricsTensors = this.model.metricsTensors;
- this.metricsNames = this.model.metricsNames;
- // TODO(cais): Add sampleWeights.
- }
- get optimizer() {
- return this.model == null ? undefined : this.model.optimizer;
- }
- set optimizer(optimizer) {
- this.model.optimizer = optimizer;
- }
- /**
- * Trains the model for a fixed number of epochs (iterations on a dataset).
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
- * });
- * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
- * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
- * batchSize: 4,
- * epochs: 3
- * });
- * console.log(history.history.loss[0]);
- * ```
- *
- * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
- * model has multiple inputs. If all inputs in the model are named, you can
- * also pass a dictionary mapping input names to `tf.Tensor`s.
- * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
- * the model has multiple outputs. If all outputs in the model are named, you
- * can also pass a dictionary mapping output names to `tf.Tensor`s.
- * @param args A `ModelFitConfig`, containing optional fields.
- *
- * @return A `History` instance. Its `history` attribute contains all
- * information collected during training.
- *
- * @exception ValueError In case of mismatch between the provided input data
- * and what the model expects.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async fit(x, y, args = {}) {
- if (!this.built) {
- throw new RuntimeError('The model needs to be compiled before ' +
- 'being used.');
- }
- return this.model.fit(x, y, args);
- }
- /**
- * Trains the model using a dataset object.
- *
- * ```js
- * const xArray = [
- * [1, 1, 1, 1, 1, 1, 1, 1, 1],
- * [1, 1, 1, 1, 1, 1, 1, 1, 1],
- * [1, 1, 1, 1, 1, 1, 1, 1, 1],
- * [1, 1, 1, 1, 1, 1, 1, 1, 1],
- * ];
- * const yArray = [1, 1, 1, 1];
- * // Create a dataset from the JavaScript array.
- * const xDataset = tf.data.array(xArray);
- * const yDataset = tf.data.array(yArray);
- * // Zip combines the `x` and `y` Datasets into a single Dataset, the
- * // iterator of which will return an object containing of two tensors,
- * // corresponding to `x` and `y`. The call to `batch(4)` will bundle
- * // four such samples into a single object, with the same keys now pointing
- * // to tensors that hold 4 examples, organized along the batch dimension.
- * // The call to `shuffle(4)` causes each iteration through the dataset to
- * // happen in a different order. The size of the shuffle window is 4.
- * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
- * .batch(4)
- * .shuffle(4);
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 1, inputShape: [9]})]
- * });
- * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
- * const history = await model.fitDataset(xyDataset, {
- * epochs: 4,
- * callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
- * });
- * ```
- *
- * @param dataset A dataset object. Its `iterator()` method is expected to
- * generate a dataset iterator object, the `next()` method of which is
- * expected to produce data batches for evaluation. The return value of the
- * `next()` call ought to contain a boolean `done` field and a `value`
- * field.
- *
- * The `value` field is expected to be an object of with fields
- * `xs` and `ys`, which point to the feature tensor and the target tensor,
- * respectively. This case is for models with exactly one input and one
- * output (e.g.. a sequential model). For example:
- * ```js
- * {value: {xs: xsTensor, ys: ysTensor}, done: false}
- * ```
- *
- * If the model has multiple inputs, the `xs` field of `value` should
- * be an object mapping input names to their respective feature tensors.
- * For example:
- * ```js
- * {
- * value: {
- * xs: {
- * input_1: xsTensor1,
- * input_2: xsTensor2
- * },
- * ys: ysTensor
- * },
- * done: false
- * }
- * ```
- * If the model has multiple outputs, the `ys` field of `value` should
- * be an object mapping output names to their respective target tensors.
- * For example:
- * ```js
- * {
- * value: {
- * xs: xsTensor,
- * ys: {
- * output_1: ysTensor1,
- * output_2: ysTensor2
- * },
- * },
- * done: false
- * }
- * ```
- * @param args A `ModelFitDatasetArgs`, containing optional fields.
- *
- * @return A `History` instance. Its `history` attribute contains all
- * information collected during training.
- *
- * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
- */
- async fitDataset(dataset, args) {
- if (!this.built) {
- throw new RuntimeError('The model needs to be compiled before ' +
- 'being used.');
- }
- return this.model.fitDataset(dataset, args);
- }
- /**
- * Runs a single gradient update on a single batch of data.
- *
- * This method differs from `fit()` and `fitDataset()` in the following
- * regards:
- * - It operates on exactly one batch of data.
- * - It returns only the loss and matric values, instead of
- * returning the batch-by-batch loss and metric values.
- * - It doesn't support fine-grained options such as verbosity and
- * callbacks.
- *
- * @param x Input data. It could be one of the following:
- * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
- * multiple inputs).
- * - An Object mapping input names to corresponding `tf.Tensor` (if the
- * model has named inputs).
- * @param y Target darta. It could be either a `tf.Tensor` a multiple
- * `tf.Tensor`s. It should be consistent with `x`.
- * @returns Training loss or losses (in case the model has
- * multiple outputs), along with metrics (if any), as numbers.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async trainOnBatch(x, y) {
- return this.model.trainOnBatch(x, y);
- }
- /* See parent class for JsDoc */
- /** @nocollapse */
- static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
- let configArray;
- let extraModelConfig = {};
- if (config instanceof Array) {
- if (!(config[0].className != null) ||
- config[0]['className'] === 'Merge') {
- throw new ValueError('Legacy serialization format not supported yet.');
- }
- configArray = config;
- }
- else {
- assert(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` +
- `it must be an Object that contains the 'layers' field.`);
- configArray = config['layers'];
- delete config['layers'];
- extraModelConfig = config;
- }
- const model = new cls(extraModelConfig);
- if (!(model instanceof Sequential)) {
- throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`);
- }
- for (const conf of configArray) {
- const customObjects = undefined;
- const layer = deserialize(conf, customObjects, fastWeightInit);
- if (fastWeightInit) {
- layer.setFastWeightInitDuringBuild(true);
- }
- model.add(layer);
- }
- return model;
- }
- /**
- * Setter used for force stopping of LayersModel.fit() (i.e., training).
- *
- * Example:
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.dense({units: 1, inputShape: [10]}));
- * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
- * const xs = tf.ones([8, 10]);
- * const ys = tf.zeros([8, 1]);
- *
- * const history = await model.fit(xs, ys, {
- * epochs: 10,
- * callbacks: {
- * onEpochEnd: async (epoch, logs) => {
- * if (epoch === 2) {
- * model.stopTraining = true;
- * }
- * }
- * }
- * });
- *
- * // There should be only 3 values in the loss array, instead of 10 values,
- * // due to the stopping after 3 epochs.
- * console.log(history.history.loss);
- * ```
- */
- set stopTraining(stop) {
- // TODO(cais): When refactoring to remove the composition pattern happens,
- // remove this method overriding.
- if (this.model == null) {
- throw new ValueError('Cannot set the stopTraining property of a sequential model before ' +
- 'it is compiled.');
- }
- this.model.stopTraining = stop;
- }
- get stopTraining() {
- if (this.model == null) {
- throw new ValueError('Cannot get the stopTraining property of a sequential model before ' +
- 'it is compiled.');
- }
- return this.model.stopTraining;
- }
- // TODO(cais): Override get trainableWeights() here
- // tslint:disable-next-line:no-any
- getConfig() {
- // NOTE(cais): We override the return type of getConfig() to `any` here,
- // because the `Sequential` class is a special case among `Container`
- // subtypes in that its getConfig() method returns an Array (not a
- // dict).
- const layers = [];
- for (const layer of this.layers) {
- const dict = {};
- dict['className'] = layer.getClassName();
- dict['config'] = layer.getConfig();
- layers.push(dict);
- }
- return { name: this.name, layers };
- }
- }
- /** @nocollapse */
- Sequential.className = 'Sequential';
- registerClass(Sequential);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // TODO(cais): Add doc string to all the public static functions in this
- // class; include exectuable JavaScript code snippets where applicable
- // (b/74074458).
- // LayersModel and related factory methods.
- /**
- * A model is a data structure that consists of `Layers` and defines inputs
- * and outputs.
- *
- * The key difference between `tf.model` and `tf.sequential` is that
- * `tf.model` is more generic, supporting an arbitrary graph (without
- * cycles) of layers. `tf.sequential` is less generic and supports only a linear
- * stack of layers.
- *
- * When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers
- * are used to wire input(s) to output(s).
- *
- * For example, the following code snippet defines a model consisting of
- * two `dense` layers, with 10 and 4 units, respectively.
- *
- * ```js
- * // Define input, which has a size of 5 (not including batch dimension).
- * const input = tf.input({shape: [5]});
- *
- * // First dense layer uses relu activation.
- * const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'});
- * // Second dense layer uses softmax activation.
- * const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'});
- *
- * // Obtain the output symbolic tensor by applying the layers on the input.
- * const output = denseLayer2.apply(denseLayer1.apply(input));
- *
- * // Create the model based on the inputs.
- * const model = tf.model({inputs: input, outputs: output});
- *
- * // The model can be used for training, evaluation and prediction.
- * // For example, the following line runs prediction with the model on
- * // some fake data.
- * model.predict(tf.ones([2, 5])).print();
- * ```
- * See also:
- * `tf.sequential`, `tf.loadLayersModel`.
- *
- * @doc {heading: 'Models', subheading: 'Creation'}
- */
- function model(args) {
- return new LayersModel(args);
- }
- /**
- * Creates a `tf.Sequential` model. A sequential model is any model where the
- * outputs of one layer are the inputs to the next layer, i.e. the model
- * topology is a simple 'stack' of layers, with no branching or skipping.
- *
- * This means that the first layer passed to a `tf.Sequential` model should have
- * a defined input shape. What that means is that it should have received an
- * `inputShape` or `batchInputShape` argument, or for some type of layers
- * (recurrent, Dense...) an `inputDim` argument.
- *
- * The key difference between `tf.model` and `tf.sequential` is that
- * `tf.sequential` is less generic, supporting only a linear stack of layers.
- * `tf.model` is more generic and supports an arbitrary graph (without
- * cycles) of layers.
- *
- * Examples:
- *
- * ```js
- * const model = tf.sequential();
- *
- * // First layer must have an input shape defined.
- * model.add(tf.layers.dense({units: 32, inputShape: [50]}));
- * // Afterwards, TF.js does automatic shape inference.
- * model.add(tf.layers.dense({units: 4}));
- *
- * // Inspect the inferred shape of the model's output, which equals
- * // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
- * // 2nd is the output size of the model's last layer.
- * console.log(JSON.stringify(model.outputs[0].shape));
- * ```
- *
- * It is also possible to specify a batch size (with potentially undetermined
- * batch dimension, denoted by "null") for the first layer using the
- * `batchInputShape` key. The following example is equivalent to the above:
- *
- * ```js
- * const model = tf.sequential();
- *
- * // First layer must have a defined input shape
- * model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]}));
- * // Afterwards, TF.js does automatic shape inference.
- * model.add(tf.layers.dense({units: 4}));
- *
- * // Inspect the inferred shape of the model's output.
- * console.log(JSON.stringify(model.outputs[0].shape));
- * ```
- *
- * You can also use an `Array` of already-constructed `Layer`s to create
- * a `tf.Sequential` model:
- *
- * ```js
- * const model = tf.sequential({
- * layers: [tf.layers.dense({units: 32, inputShape: [50]}),
- * tf.layers.dense({units: 4})]
- * });
- * console.log(JSON.stringify(model.outputs[0].shape));
- * ```
- *
- * @doc {heading: 'Models', subheading: 'Creation'}
- */
- function sequential(config) {
- return new Sequential(config);
- }
- /**
- * Load a model composed of Layer objects, including its topology and optionally
- * weights. See the Tutorial named "How to import a Keras Model" for usage
- * examples.
- *
- * This method is applicable to:
- *
- * 1. Models created with the `tf.layers.*`, `tf.sequential`, and
- * `tf.model` APIs of TensorFlow.js and later saved with the
- * `tf.LayersModel.save` method.
- * 2. Models converted from Keras or TensorFlow tf.keras using the
- * [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
- *
- * This mode is *not* applicable to TensorFlow `SavedModel`s or their converted
- * forms. For those models, use `tf.loadGraphModel`.
- *
- * Example 1. Load a model from an HTTP server.
- *
- * ```js
- * const model = await tf.loadLayersModel(
- * 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
- * model.summary();
- * ```
- *
- * Example 2: Save `model`'s topology and weights to browser [local
- * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('localstorage://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 3. Saving `model`'s topology and weights to browser
- * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
- * then load it back.
- *
- * ```js
- * const model = tf.sequential(
- * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
- * console.log('Prediction from original model:');
- * model.predict(tf.ones([1, 3])).print();
- *
- * const saveResults = await model.save('indexeddb://my-model-1');
- *
- * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
- * console.log('Prediction from loaded model:');
- * loadedModel.predict(tf.ones([1, 3])).print();
- * ```
- *
- * Example 4. Load a model from user-selected files from HTML
- * [file input
- * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
- *
- * ```js
- * // Note: this code snippet will not work without the HTML elements in the
- * // page
- * const jsonUpload = document.getElementById('json-upload');
- * const weightsUpload = document.getElementById('weights-upload');
- *
- * const model = await tf.loadLayersModel(
- * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
- * ```
- *
- * @param pathOrIOHandler Can be either of the two formats
- * 1. A string path to the `ModelAndWeightsConfig` JSON describing
- * the model in the canonical TensorFlow.js format. For file://
- * (tfjs-node-only), http:// and https:// schemas, the path can be
- * either absolute or relative.
- * 2. An `tf.io.IOHandler` object that loads model artifacts with its `load`
- * method.
- * @param options Optional configuration arguments for the model loading,
- * including:
- * - `strict`: Require that the provided weights exactly match those required
- * by the layers. Default true. Passing false means that both extra
- * weights and missing weights will be silently ignored.
- * - `onProgress`: A function of the signature `(fraction: number) => void',
- * that can be used as the progress callback for the model loading.
- * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
- * loaded.
- *
- * @doc {heading: 'Models', subheading: 'Loading'}
- */
- function loadLayersModel(pathOrIOHandler, options) {
- if (options == null) {
- options = {};
- }
- return loadLayersModelInternal(pathOrIOHandler, options);
- }
- /**
- * Used to instantiate an input to a model as a `tf.SymbolicTensor`.
- *
- * Users should call the `input` factory function for
- * consistency with other generator functions.
- *
- * Example:
- *
- * ```js
- * // Defines a simple logistic regression model with 32 dimensional input
- * // and 3 dimensional output.
- * const x = tf.input({shape: [32]});
- * const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x);
- * const model = tf.model({inputs: x, outputs: y});
- * model.predict(tf.ones([2, 32])).print();
- * ```
- *
- * Note: `input` is only necessary when using `model`. When using
- * `sequential`, specify `inputShape` for the first layer or use `inputLayer`
- * as the first layer.
- *
- * @doc {heading: 'Models', subheading: 'Inputs'}
- */
- function input(config) {
- return Input(config);
- }
- function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
- CallbackConstructorRegistry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Base class for Activations.
- *
- * Special note: due to cross-language compatibility reasons, the
- * static readonly className field in this family of classes must be set to
- * the initialLowerCamelCase name of the activation.
- */
- class Activation extends Serializable {
- getConfig() {
- return {};
- }
- }
- /**
- * Exponential linear unit (ELU).
- * Reference: https://arxiv.org/abs/1511.07289
- */
- class Elu$1 extends Activation {
- /**
- * Calculate the activation function.
- *
- * @param x: Input.
- * @param alpha: Scaling factor the negative section.
- * @return Output of the ELU activation.
- */
- apply(x, alpha = 1) {
- return elu$1(x, alpha);
- }
- }
- /** @nocollapse */
- Elu$1.className = 'elu';
- registerClass(Elu$1);
- /**
- * Scaled Exponential Linear Unit. (Klambauer et al., 2017).
- * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515
- * Notes:
- * - To be used together with the initialization "lecunNormal".
- * - To be used together with the dropout variant "AlphaDropout".
- */
- class Selu$1 extends Activation {
- apply(x) {
- return selu(x);
- }
- }
- /** @nocollapse */
- Selu$1.className = 'selu';
- registerClass(Selu$1);
- /**
- * Rectified linear unit
- */
- class Relu$1 extends Activation {
- apply(x) {
- return relu(x);
- }
- }
- /** @nocollapse */
- Relu$1.className = 'relu';
- registerClass(Relu$1);
- /**
- * Rectified linear unit activation maxing out at 6.0.
- */
- class Relu6$1 extends Activation {
- apply(x) {
- return tidy(() => minimum(6.0, relu(x)));
- }
- }
- /** @nocollapse */
- Relu6$1.className = 'relu6';
- registerClass(Relu6$1);
- //* Linear activation (no-op) */
- class Linear extends Activation {
- apply(x) {
- return x;
- }
- }
- /** @nocollapse */
- Linear.className = 'linear';
- registerClass(Linear);
- /**
- * Sigmoid activation function.
- */
- class Sigmoid$1 extends Activation {
- apply(x) {
- return sigmoid(x);
- }
- }
- /** @nocollapse */
- Sigmoid$1.className = 'sigmoid';
- registerClass(Sigmoid$1);
- /**
- * Segment-wise linear approximation of sigmoid.
- */
- class HardSigmoid extends Activation {
- apply(x) {
- return hardSigmoid(x);
- }
- }
- /** @nocollapse */
- HardSigmoid.className = 'hardSigmoid';
- registerClass(HardSigmoid);
- /**
- * Softplus activation function.
- */
- class Softplus$1 extends Activation {
- apply(x) {
- return softplus(x);
- }
- }
- /** @nocollapse */
- Softplus$1.className = 'softplus';
- registerClass(Softplus$1);
- /**
- * Softsign activation function.
- */
- class Softsign extends Activation {
- apply(x) {
- return softsign(x);
- }
- }
- /** @nocollapse */
- Softsign.className = 'softsign';
- registerClass(Softsign);
- /**
- * Hyperbolic tangent function.
- */
- class Tanh$1 extends Activation {
- apply(x) {
- return tanh$1(x);
- }
- }
- /** @nocollapse */
- Tanh$1.className = 'tanh';
- registerClass(Tanh$1);
- /**
- * Softmax activation function
- */
- class Softmax$1 extends Activation {
- /**
- * Calculate the activation function.
- *
- * @param x Tensor.
- * @param axis Integer, axis along which the softmax normalization is applied.
- * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
- * an error.
- *
- * @returns a Tensor of the same shape as x
- *
- * @throws ValueError: In case `dim(x) < 2`.
- */
- apply(x, axis = (-1)) {
- return softmax(x, axis);
- }
- }
- /** @nocollapse */
- Softmax$1.className = 'softmax';
- registerClass(Softmax$1);
- /**
- * Log softmax activation function
- */
- class LogSoftmax$1 extends Activation {
- /**
- * Calculate the activation function of log softmax:
- * log( exp(x_i) / sum(exp(x)) )
- *
- * @param x Tensor.
- * @param axis Integer, axis along which the softmax normalization is applied.
- * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
- * an error.
- *
- * @returns a Tensor of the same shape as x
- *
- * @throws ValueError: In case `dim(x) < 2`.
- */
- apply(x, axis = (-1)) {
- return logSoftmax(x, axis);
- }
- }
- /** @nocollapse */
- LogSoftmax$1.className = 'logSoftmax';
- registerClass(LogSoftmax$1);
- /**
- * Swish activation function
- */
- class Swish extends Activation {
- /**
- * Calculate the activation function.
- *
- * @param x Tensor.
- * @param alpha Scaling factor for the sigmoid function.
- * @returns a Tensor of the same shape as x
- */
- apply(x, alpha = 1) {
- return tidy(() => sigmoid(x.mul(alpha)).mul(x));
- }
- }
- /** @nocollapse */
- Swish.className = 'swish';
- registerClass(Swish);
- function serializeActivation(activation) {
- return activation.getClassName();
- }
- function deserializeActivation(config, customObjects = {}) {
- return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
- }
- function getActivation(identifier) {
- if (identifier == null) {
- const config = {};
- config['className'] = 'linear';
- config['config'] = {};
- return deserializeActivation(config);
- }
- if (typeof identifier === 'string') {
- const config = {};
- config['className'] = identifier;
- config['config'] = {};
- return deserializeActivation(config);
- }
- else if (identifier instanceof Activation) {
- return identifier;
- }
- else {
- return deserializeActivation(identifier);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- function assertObjectArgs(args) {
- if (args != null && typeof args !== 'object') {
- throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
- `object, but received: ${args}`);
- }
- }
- /**
- * Regularizer base class.
- */
- class Regularizer extends Serializable {
- }
- class L1L2 extends Regularizer {
- constructor(args) {
- super();
- assertObjectArgs(args);
- this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
- this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
- this.hasL1 = this.l1 !== 0;
- this.hasL2 = this.l2 !== 0;
- }
- /**
- * Porting note: Renamed from __call__.
- * @param x Variable of which to calculate the regularization score.
- */
- apply(x) {
- return tidy(() => {
- let regularization = zeros([1]);
- if (this.hasL1) {
- regularization = add$1(regularization, sum$1(mul(this.l1, abs(x))));
- }
- if (this.hasL2) {
- regularization =
- add$1(regularization, sum$1(mul(this.l2, square$1(x))));
- }
- return regularization.asScalar();
- });
- }
- getConfig() {
- return { 'l1': this.l1, 'l2': this.l2 };
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls({ l1: config['l1'], l2: config['l2'] });
- }
- }
- /** @nocollapse */
- L1L2.className = 'L1L2';
- registerClass(L1L2);
- function l1(args) {
- assertObjectArgs(args);
- return new L1L2({ l1: args != null ? args.l1 : null, l2: 0 });
- }
- function l2(args) {
- assertObjectArgs(args);
- return new L1L2({ l2: args != null ? args.l2 : null, l1: 0 });
- }
- // Maps the JavaScript-like identifier keys to the corresponding keras symbols.
- const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
- 'l1l2': 'L1L2'
- };
- function serializeRegularizer(constraint) {
- return serializeKerasObject(constraint);
- }
- function deserializeRegularizer(config, customObjects = {}) {
- return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
- }
- function getRegularizer(identifier) {
- if (identifier == null) {
- return null;
- }
- if (typeof identifier === 'string') {
- const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
- REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
- identifier;
- const config = { className, config: {} };
- return deserializeRegularizer(config);
- }
- else if (identifier instanceof Regularizer) {
- return identifier;
- }
- else {
- return deserializeRegularizer(identifier);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class ReLU extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.supportsMasking = true;
- if (args != null) {
- this.maxValue = args.maxValue;
- }
- }
- call(inputs, kwargs) {
- inputs = getExactlyOneTensor(inputs);
- let output = relu(inputs);
- if (this.maxValue != null) {
- output = clipByValue(output, 0, this.maxValue);
- }
- return output;
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const config = { maxValue: this.maxValue };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- ReLU.className = 'ReLU';
- registerClass(ReLU);
- class LeakyReLU extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.DEFAULT_ALPHA = 0.3;
- if (args == null) {
- args = {};
- }
- this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;
- }
- call(inputs, kwargs) {
- const x = getExactlyOneTensor(inputs);
- return leakyRelu(x, this.alpha);
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const config = { alpha: this.alpha };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- LeakyReLU.className = 'LeakyReLU';
- registerClass(LeakyReLU);
- class PReLU extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.DEFAULT_ALPHA_INITIALIZER = 'zeros';
- if (args == null) {
- args = {};
- }
- this.supportsMasking = true;
- this.alphaInitializer =
- getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);
- this.alphaRegularizer = getRegularizer(args.alphaRegularizer);
- this.alphaConstraint = getConstraint(args.alphaConstraint);
- if (args.sharedAxes == null) {
- this.sharedAxes = null;
- }
- else if (Array.isArray(args.sharedAxes)) {
- this.sharedAxes = args.sharedAxes;
- }
- else if (typeof args.sharedAxes === 'number') {
- this.sharedAxes = [args.sharedAxes];
- }
- else {
- throw new ValueError(`Expected sharedAxes to be a number or an array of numbers, ` +
- `but got ${args.sharedAxes}`);
- }
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const paramShape = inputShape.slice(1);
- if (this.sharedAxes != null) {
- for (const i of this.sharedAxes) {
- paramShape[i - 1] = 1;
- }
- }
- this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
- // Set input spec.
- const axes = {};
- if (this.sharedAxes != null) {
- for (let i = 1; i < inputShape.length; ++i) {
- axes[i] = inputShape[i];
- }
- }
- this.inputSpec = [new InputSpec({
- ndim: inputShape.length,
- axes,
- })];
- this.built = true;
- }
- call(inputs, kwargs) {
- inputs = getExactlyOneTensor(inputs);
- return prelu(inputs, this.alpha.read());
- }
- getConfig() {
- const config = {
- alphaInitializer: serializeInitializer(this.alphaInitializer),
- alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
- alphaConstraint: serializeConstraint(this.alphaConstraint),
- sharedAxes: this.sharedAxes
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- PReLU.className = 'PReLU';
- registerClass(PReLU);
- class ELU extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.DEFAULT_ALPHA = 1.0;
- if (args == null) {
- args = {};
- }
- if (args.alpha != null && args.alpha !== this.DEFAULT_ALPHA) {
- throw new NotImplementedError(`Non-default alpha value (${args.alpha}) is not supported by the ` +
- `ELU layer yet.`);
- }
- this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;
- }
- call(inputs, kwargs) {
- const x = getExactlyOneTensor(inputs);
- return elu(x);
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const config = { alpha: this.alpha };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- ELU.className = 'ELU';
- registerClass(ELU);
- class ThresholdedReLU extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.DEFAULT_THETA = 1.0;
- if (args == null) {
- args = {};
- }
- this.theta = args.theta == null ? this.DEFAULT_THETA : args.theta;
- }
- call(inputs, kwargs) {
- const x = getExactlyOneTensor(inputs);
- return x.mul(cast$1(x.greater(this.theta), 'float32'));
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const config = { theta: this.theta };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- ThresholdedReLU.className = 'ThresholdedReLU';
- registerClass(ThresholdedReLU);
- class Softmax$2 extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.DEFAULT_AXIS = 1.0;
- if (args == null) {
- args = {};
- }
- this.softmax = new Softmax$1().apply;
- this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
- }
- call(inputs, kwargs) {
- const x = getExactlyOneTensor(inputs);
- return this.softmax(x, this.axis);
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const config = { axis: this.axis };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Softmax$2.className = 'Softmax';
- registerClass(Softmax$2);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Transforms a single number of array of numbers into an array of numbers.
- * @param value
- * @param n: The size of the tuple to be returned.
- * @param name: Name of the parameter, used for generating error messages.
- * @returns An array of numbers.
- */
- function normalizeArray(value, n, name) {
- if (typeof value === 'number') {
- return pyListRepeat(value, n);
- }
- else {
- if (value.length !== n) {
- throw new ValueError(`The ${name} argument must be an integer or tuple of ${n} integers.` +
- ` Received: ${value.length} elements.`);
- }
- for (let i = 0; i < n; ++i) {
- const singleValue = value[i];
- if (!isInteger(singleValue)) {
- throw new ValueError(`The ${name} argument must be an integer or tuple of ${n}` +
- ` integers. Received: ${JSON.stringify(value)} including a` +
- ` non-integer number ${singleValue}`);
- }
- }
- return value;
- }
- }
- /**
- * Determines output length of a convolution given input length.
- * @param inputLength
- * @param filterSize
- * @param padding
- * @param stride
- * @param dilation: dilation rate.
- */
- function convOutputLength(inputLength, filterSize, padding, stride, dilation = 1) {
- if (inputLength == null) {
- return inputLength;
- }
- const dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
- let outputLength;
- if (padding === 'same') {
- outputLength = inputLength;
- }
- else { // VALID
- outputLength = inputLength - dilatedFilterSize + 1;
- }
- return Math.floor((outputLength + stride - 1) / stride);
- }
- function deconvLength(dimSize, strideSize, kernelSize, padding) {
- if (dimSize == null) {
- return null;
- }
- if (padding === 'valid') {
- dimSize = dimSize * strideSize + max$1([kernelSize - strideSize, 0]);
- }
- else if (padding === 'same') {
- dimSize = dimSize * strideSize;
- }
- else {
- throw new ValueError(`Unsupport padding mode: ${padding}.`);
- }
- return dimSize;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Transpose and cast the input before the conv2d.
- * @param x Input image tensor.
- * @param dataFormat
- */
- function preprocessConv2DInput(x, dataFormat) {
- // TODO(cais): Cast type to float32 if not.
- return tidy(() => {
- checkDataFormat(dataFormat);
- if (dataFormat === 'channelsFirst') {
- return transpose(x, [0, 2, 3, 1]); // NCHW -> NHWC.
- }
- else {
- return x;
- }
- });
- }
- /**
- * Transpose and cast the input before the conv3d.
- * @param x Input image tensor.
- * @param dataFormat
- */
- function preprocessConv3DInput(x, dataFormat) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- if (dataFormat === 'channelsFirst') {
- return transpose(x, [0, 2, 3, 4, 1]); // NCDHW -> NDHWC.
- }
- else {
- return x;
- }
- });
- }
- /**
- * 1D-convolution with bias added.
- *
- * Porting Note: This function does not exist in the Python Keras backend.
- * It is exactly the same as `conv2d`, except the added `bias`.
- *
- * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
- * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
- * @param bias Bias, rank-3, of shape `[outDepth]`.
- * @param strides
- * @param padding Padding mode.
- * @param dataFormat Data format.
- * @param dilationRate
- * @returns The result of the 1D convolution.
- * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
- */
- function conv1dWithBias(x, kernel, bias, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
- return tidy(() => {
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- checkDataFormat(dataFormat);
- // Check the ranks of x, kernel and bias.
- if (x.shape.length !== 3) {
- throw new ValueError(`The input of a conv1dWithBias operation should be 3, but is ` +
- `${x.shape.length} instead.`);
- }
- if (kernel.shape.length !== 3) {
- throw new ValueError(`The kernel for a conv1dWithBias operation should be 3, but is ` +
- `${kernel.shape.length} instead`);
- }
- if (bias != null && bias.shape.length !== 1) {
- throw new ValueError(`The bias for a conv1dWithBias operation should be 1, but is ` +
- `${kernel.shape.length} instead`);
- }
- // TODO(cais): Support CAUSAL padding mode.
- if (dataFormat === 'channelsFirst') {
- x = transpose(x, [0, 2, 1]); // NCW -> NWC.
- }
- if (padding === 'causal') {
- throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' +
- 'implemented yet.');
- }
- let y = conv1d(x, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NWC', dilationRate);
- if (bias != null) {
- y = biasAdd(y, bias);
- }
- return y;
- });
- }
- /**
- * 1D-convolution.
- *
- * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
- * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.s
- * @param strides
- * @param padding Padding mode.
- * @param dataFormat Data format.
- * @param dilationRate
- * @returns The result of the 1D convolution.
- * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
- */
- function conv1d$1(x, kernel, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- return conv1dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
- });
- }
- /**
- * 2D Convolution
- * @param x
- * @param kernel kernel of the convolution.
- * @param strides strides array.
- * @param padding padding mode. Default to 'valid'.
- * @param dataFormat data format. Defaults to 'channelsLast'.
- * @param dilationRate dilation rate array.
- * @returns Result of the 2D pooling.
- */
- function conv2d$2(x, kernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- return conv2dWithBiasActivation(x, kernel, null, strides, padding, dataFormat, dilationRate);
- });
- }
- /**
- * 2D Convolution with an added bias and optional activation.
- * Note: This function does not exist in the Python Keras Backend. This function
- * is exactly the same as `conv2d`, except the added `bias`.
- */
- function conv2dWithBiasActivation(x, kernel, bias, strides = [1, 1], padding = 'valid', dataFormat, dilationRate, activation = null) {
- return tidy(() => {
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- checkDataFormat(dataFormat);
- if (x.rank !== 3 && x.rank !== 4) {
- throw new ValueError(`conv2dWithBiasActivation expects input to be of rank 3 or 4, ` +
- `but received ${x.rank}.`);
- }
- if (kernel.rank !== 3 && kernel.rank !== 4) {
- throw new ValueError(`conv2dWithBiasActivation expects kernel to be of rank 3 or 4, ` +
- `but received ${x.rank}.`);
- }
- let y = preprocessConv2DInput(x, dataFormat);
- if (padding === 'causal') {
- throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' +
- 'implemented yet.');
- }
- y = conv2d$1({
- x: y,
- filter: kernel,
- strides: strides,
- pad: padding === 'same' ? 'same' : 'valid',
- dilations: dilationRate,
- dataFormat: 'NHWC',
- bias,
- activation
- });
- if (dataFormat === 'channelsFirst') {
- y = transpose(y, [0, 3, 1, 2]);
- }
- return y;
- });
- }
- /**
- * 3D Convolution.
- * @param x
- * @param kernel kernel of the convolution.
- * @param strides strides array.
- * @param padding padding mode. Default to 'valid'.
- * @param dataFormat data format. Defaults to 'channelsLast'.
- * @param dilationRate dilation rate array.
- * @returns Result of the 3D convolution.
- */
- function conv3d$1(x, kernel, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- return conv3dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
- });
- }
- /**
- * 3D Convolution with an added bias.
- * Note: This function does not exist in the Python Keras Backend. This function
- * is exactly the same as `conv3d`, except the added `bias`.
- */
- function conv3dWithBias(x, kernel, bias, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
- return tidy(() => {
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- checkDataFormat(dataFormat);
- if (x.rank !== 4 && x.rank !== 5) {
- throw new ValueError(`conv3dWithBias expects input to be of rank 4 or 5, but received ` +
- `${x.rank}.`);
- }
- if (kernel.rank !== 4 && kernel.rank !== 5) {
- throw new ValueError(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ` +
- `${x.rank}.`);
- }
- let y = preprocessConv3DInput(x, dataFormat);
- if (padding === 'causal') {
- throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' +
- 'implemented yet.');
- }
- y = conv3d(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate);
- if (bias != null) {
- y = biasAdd(y, bias);
- }
- if (dataFormat === 'channelsFirst') {
- y = transpose(y, [0, 4, 1, 2, 3]);
- }
- return y;
- });
- }
- /**
- * Abstract convolution layer.
- */
- class BaseConv extends Layer {
- constructor(rank, args) {
- super(args);
- this.bias = null;
- this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
- this.DEFAULT_BIAS_INITIALIZER = 'zeros';
- BaseConv.verifyArgs(args);
- this.rank = rank;
- assertPositiveInteger(this.rank, 'rank');
- if (this.rank !== 1 && this.rank !== 2 && this.rank !== 3) {
- throw new NotImplementedError(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is ` +
- `not implemented yet.`);
- }
- this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
- this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, 'strides');
- this.padding = args.padding == null ? 'valid' : args.padding;
- checkPaddingMode(this.padding);
- this.dataFormat =
- args.dataFormat == null ? 'channelsLast' : args.dataFormat;
- checkDataFormat(this.dataFormat);
- this.activation = getActivation(args.activation);
- this.useBias = args.useBias == null ? true : args.useBias;
- this.biasInitializer =
- getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
- this.biasConstraint = getConstraint(args.biasConstraint);
- this.biasRegularizer = getRegularizer(args.biasRegularizer);
- this.activityRegularizer = getRegularizer(args.activityRegularizer);
- this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate');
- if (this.rank === 1 &&
- (Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)) {
- throw new ValueError(`dilationRate must be a number or an array of a single number ` +
- `for 1D convolution, but received ` +
- `${JSON.stringify(this.dilationRate)}`);
- }
- else if (this.rank === 2) {
- if (typeof this.dilationRate === 'number') {
- this.dilationRate = [this.dilationRate, this.dilationRate];
- }
- else if (this.dilationRate.length !== 2) {
- throw new ValueError(`dilationRate must be a number or array of two numbers for 2D ` +
- `convolution, but received ${JSON.stringify(this.dilationRate)}`);
- }
- }
- else if (this.rank === 3) {
- if (typeof this.dilationRate === 'number') {
- this.dilationRate =
- [this.dilationRate, this.dilationRate, this.dilationRate];
- }
- else if (this.dilationRate.length !== 3) {
- throw new ValueError(`dilationRate must be a number or array of three numbers for 3D ` +
- `convolution, but received ${JSON.stringify(this.dilationRate)}`);
- }
- }
- }
- static verifyArgs(args) {
- // Check config.kernelSize type and shape.
- assert$1('kernelSize' in args, `required key 'kernelSize' not in config`);
- if (typeof args.kernelSize !== 'number' &&
- !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) {
- throw new ValueError(`BaseConv expects config.kernelSize to be number or number[] with ` +
- `length 1, 2, or 3, but received ${JSON.stringify(args.kernelSize)}.`);
- }
- }
- getConfig() {
- const config = {
- kernelSize: this.kernelSize,
- strides: this.strides,
- padding: this.padding,
- dataFormat: this.dataFormat,
- dilationRate: this.dilationRate,
- activation: serializeActivation(this.activation),
- useBias: this.useBias,
- biasInitializer: serializeInitializer(this.biasInitializer),
- biasRegularizer: serializeRegularizer(this.biasRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- biasConstraint: serializeConstraint(this.biasConstraint)
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /**
- * Abstract nD convolution layer. Ancestor of convolution layers which reduce
- * across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D.
- */
- class Conv extends BaseConv {
- constructor(rank, args) {
- super(rank, args);
- this.kernel = null;
- Conv.verifyArgs(args);
- this.filters = args.filters;
- assertPositiveInteger(this.filters, 'filters');
- this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.kernelConstraint = getConstraint(args.kernelConstraint);
- this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
- if (inputShape[channelAxis] == null) {
- throw new ValueError(`The channel dimension of the input should be defined. ` +
- `Found ${inputShape[channelAxis]}`);
- }
- const inputDim = inputShape[channelAxis];
- const kernelShape = this.kernelSize.concat([inputDim, this.filters]);
- this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- this.inputSpec = [{ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } }];
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = getExactlyOneTensor(inputs);
- let outputs;
- const biasValue = this.bias == null ? null : this.bias.read();
- const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
- if (fusedActivationName != null && this.rank === 2) {
- outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate, fusedActivationName);
- }
- else {
- if (this.rank === 1) {
- outputs = conv1dWithBias(inputs, this.kernel.read(), biasValue, this.strides[0], this.padding, this.dataFormat, this.dilationRate[0]);
- }
- else if (this.rank === 2) {
- // TODO(cais): Move up to constructor.
- outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
- }
- else if (this.rank === 3) {
- outputs = conv3dWithBias(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
- }
- else {
- throw new NotImplementedError('convolutions greater than 3D are not implemented yet.');
- }
- if (this.activation != null) {
- outputs = this.activation.apply(outputs);
- }
- }
- return outputs;
- });
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const newSpace = [];
- const space = (this.dataFormat === 'channelsLast') ?
- inputShape.slice(1, inputShape.length - 1) :
- inputShape.slice(2);
- for (let i = 0; i < space.length; ++i) {
- const newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? this.dilationRate :
- this.dilationRate[i]);
- newSpace.push(newDim);
- }
- let outputShape = [inputShape[0]];
- if (this.dataFormat === 'channelsLast') {
- outputShape = outputShape.concat(newSpace);
- outputShape.push(this.filters);
- }
- else {
- outputShape.push(this.filters);
- outputShape = outputShape.concat(newSpace);
- }
- return outputShape;
- }
- getConfig() {
- const config = {
- filters: this.filters,
- kernelInitializer: serializeInitializer(this.kernelInitializer),
- kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
- kernelConstraint: serializeConstraint(this.kernelConstraint)
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- static verifyArgs(args) {
- // Check config.filters type, shape, and value.
- if (!('filters' in args) || typeof args.filters !== 'number' ||
- args.filters < 1) {
- throw new ValueError(`Convolution layer expected config.filters to be a 'number' > 0 ` +
- `but got ${JSON.stringify(args.filters)}`);
- }
- }
- }
- class Conv2D$1 extends Conv {
- constructor(args) {
- super(2, args);
- Conv2D$1.verifyArgs(args);
- }
- getConfig() {
- const config = super.getConfig();
- delete config['rank'];
- return config;
- }
- static verifyArgs(args) {
- // config.kernelSize must be a number or array of numbers.
- if ((typeof args.kernelSize !== 'number') &&
- !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 2)) {
- throw new ValueError(`Conv2D expects config.kernelSize to be number or number[] with ` +
- `length 1 or 2, but received ${JSON.stringify(args.kernelSize)}.`);
- }
- }
- }
- /** @nocollapse */
- Conv2D$1.className = 'Conv2D';
- registerClass(Conv2D$1);
- class Conv3D$1 extends Conv {
- constructor(args) {
- super(3, args);
- Conv3D$1.verifyArgs(args);
- }
- getConfig() {
- const config = super.getConfig();
- delete config['rank'];
- return config;
- }
- static verifyArgs(args) {
- // config.kernelSize must be a number or array of numbers.
- if (typeof args.kernelSize !== 'number') {
- if (!(Array.isArray(args.kernelSize) &&
- (args.kernelSize.length === 1 || args.kernelSize.length === 3))) {
- throw new ValueError(`Conv3D expects config.kernelSize to be number or` +
- ` [number, number, number], but received ${JSON.stringify(args.kernelSize)}.`);
- }
- }
- }
- }
- /** @nocollapse */
- Conv3D$1.className = 'Conv3D';
- registerClass(Conv3D$1);
- class Conv2DTranspose extends Conv2D$1 {
- constructor(args) {
- super(args);
- this.inputSpec = [new InputSpec({ ndim: 4 })];
- if (this.padding !== 'same' && this.padding !== 'valid') {
- throw new ValueError(`Conv2DTranspose currently supports only padding modes 'same' ` +
- `and 'valid', but received padding mode ${this.padding}`);
- }
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- if (inputShape.length !== 4) {
- throw new ValueError('Input should have rank 4; Received input shape: ' +
- JSON.stringify(inputShape));
- }
- const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
- if (inputShape[channelAxis] == null) {
- throw new ValueError('The channel dimension of the inputs should be defined. ' +
- 'Found `None`.');
- }
- const inputDim = inputShape[channelAxis];
- const kernelShape = this.kernelSize.concat([this.filters, inputDim]);
- this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- // Set input spec.
- this.inputSpec =
- [new InputSpec({ ndim: 4, axes: { [channelAxis]: inputDim } })];
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- let input = getExactlyOneTensor(inputs);
- if (input.shape.length !== 4) {
- throw new ValueError(`Conv2DTranspose.call() expects input tensor to be rank-4, but ` +
- `received a tensor of rank-${input.shape.length}`);
- }
- const inputShape = input.shape;
- const batchSize = inputShape[0];
- let hAxis;
- let wAxis;
- if (this.dataFormat === 'channelsFirst') {
- hAxis = 2;
- wAxis = 3;
- }
- else {
- hAxis = 1;
- wAxis = 2;
- }
- const height = inputShape[hAxis];
- const width = inputShape[wAxis];
- const kernelH = this.kernelSize[0];
- const kernelW = this.kernelSize[1];
- const strideH = this.strides[0];
- const strideW = this.strides[1];
- // Infer the dynamic output shape.
- const outHeight = deconvLength(height, strideH, kernelH, this.padding);
- const outWidth = deconvLength(width, strideW, kernelW, this.padding);
- // Porting Note: We don't branch based on `this.dataFormat` here,
- // because
- // the tjfs-core function `conv2dTranspose` called below always
- // assumes channelsLast.
- const outputShape = [batchSize, outHeight, outWidth, this.filters];
- if (this.dataFormat !== 'channelsLast') {
- input = transpose(input, [0, 2, 3, 1]);
- }
- let outputs = conv2dTranspose(input, this.kernel.read(), outputShape, this.strides, this.padding);
- if (this.dataFormat !== 'channelsLast') {
- outputs = transpose(outputs, [0, 3, 1, 2]);
- }
- if (this.bias != null) {
- outputs =
- biasAdd(outputs, this.bias.read(), this.dataFormat);
- }
- if (this.activation != null) {
- outputs = this.activation.apply(outputs);
- }
- return outputs;
- });
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const outputShape = inputShape.slice();
- let channelAxis;
- let heightAxis;
- let widthAxis;
- if (this.dataFormat === 'channelsFirst') {
- channelAxis = 1;
- heightAxis = 2;
- widthAxis = 3;
- }
- else {
- channelAxis = 3;
- heightAxis = 1;
- widthAxis = 2;
- }
- const kernelH = this.kernelSize[0];
- const kernelW = this.kernelSize[1];
- const strideH = this.strides[0];
- const strideW = this.strides[1];
- outputShape[channelAxis] = this.filters;
- outputShape[heightAxis] =
- deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
- outputShape[widthAxis] =
- deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
- return outputShape;
- }
- getConfig() {
- const config = super.getConfig();
- delete config['dilationRate'];
- return config;
- }
- }
- /** @nocollapse */
- Conv2DTranspose.className = 'Conv2DTranspose';
- registerClass(Conv2DTranspose);
- class SeparableConv extends Conv {
- constructor(rank, config) {
- super(rank, config);
- this.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform';
- this.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform';
- this.depthwiseKernel = null;
- this.pointwiseKernel = null;
- if (config.filters == null) {
- throw new ValueError('The `filters` configuration field is required by SeparableConv, ' +
- 'but is unspecified.');
- }
- if (config.kernelInitializer != null || config.kernelRegularizer != null ||
- config.kernelConstraint != null) {
- throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' +
- 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' +
- 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' +
- 'pointwiseRegularizer and pointwiseConstraint instead.');
- }
- if (config.padding != null && config.padding !== 'same' &&
- config.padding !== 'valid') {
- throw new ValueError(`SeparableConv${this.rank}D supports only padding modes: ` +
- `'same' and 'valid', but received ${JSON.stringify(config.padding)}`);
- }
- this.depthMultiplier =
- config.depthMultiplier == null ? 1 : config.depthMultiplier;
- this.depthwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER);
- this.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
- this.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
- this.pointwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER);
- this.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
- this.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- if (inputShape.length < this.rank + 2) {
- throw new ValueError(`Inputs to SeparableConv${this.rank}D should have rank ` +
- `${this.rank + 2}, but received input shape: ` +
- `${JSON.stringify(inputShape)}`);
- }
- const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
- if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
- throw new ValueError(`The channel dimension of the inputs should be defined, ` +
- `but found ${JSON.stringify(inputShape[channelAxis])}`);
- }
- const inputDim = inputShape[channelAxis];
- const depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
- const pointwiseKernelShape = [];
- for (let i = 0; i < this.rank; ++i) {
- pointwiseKernelShape.push(1);
- }
- pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
- const trainable = true;
- this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
- this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
- }
- else {
- this.bias = null;
- }
- this.inputSpec =
- [new InputSpec({ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } })];
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = getExactlyOneTensor(inputs);
- let output;
- if (this.rank === 1) {
- throw new NotImplementedError('1D separable convolution is not implemented yet.');
- }
- else if (this.rank === 2) {
- if (this.dataFormat === 'channelsFirst') {
- inputs = transpose(inputs, [0, 2, 3, 1]); // NCHW -> NHWC.
- }
- output = separableConv2d(inputs, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, 'NHWC');
- }
- if (this.useBias) {
- output = biasAdd(output, this.bias.read(), this.dataFormat);
- }
- if (this.activation != null) {
- output = this.activation.apply(output);
- }
- if (this.dataFormat === 'channelsFirst') {
- output = transpose(output, [0, 3, 1, 2]); // NHWC -> NCHW.
- }
- return output;
- });
- }
- getConfig() {
- const config = super.getConfig();
- delete config['rank'];
- delete config['kernelInitializer'];
- delete config['kernelRegularizer'];
- delete config['kernelConstraint'];
- config['depthwiseInitializer'] =
- serializeInitializer(this.depthwiseInitializer);
- config['pointwiseInitializer'] =
- serializeInitializer(this.pointwiseInitializer);
- config['depthwiseRegularizer'] =
- serializeRegularizer(this.depthwiseRegularizer);
- config['pointwiseRegularizer'] =
- serializeRegularizer(this.pointwiseRegularizer);
- config['depthwiseConstraint'] =
- serializeConstraint(this.depthwiseConstraint);
- config['pointwiseConstraint'] =
- serializeConstraint(this.pointwiseConstraint);
- return config;
- }
- }
- /** @nocollapse */
- SeparableConv.className = 'SeparableConv';
- class SeparableConv2D extends SeparableConv {
- constructor(args) {
- super(2, args);
- }
- }
- /** @nocollapse */
- SeparableConv2D.className = 'SeparableConv2D';
- registerClass(SeparableConv2D);
- class Conv1D extends Conv {
- constructor(args) {
- super(1, args);
- Conv1D.verifyArgs(args);
- this.inputSpec = [{ ndim: 3 }];
- }
- getConfig() {
- const config = super.getConfig();
- delete config['rank'];
- delete config['dataFormat'];
- return config;
- }
- static verifyArgs(args) {
- // config.kernelSize must be a number or array of numbers.
- if (typeof args.kernelSize !== 'number' &&
- !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 1)) {
- throw new ValueError(`Conv1D expects config.kernelSize to be number or number[] with ` +
- `length 1, but received ${JSON.stringify(args.kernelSize)}.`);
- }
- }
- }
- /** @nocollapse */
- Conv1D.className = 'Conv1D';
- registerClass(Conv1D);
- class Cropping2D extends Layer {
- constructor(args) {
- super(args);
- if (typeof args.cropping === 'number') {
- this.cropping =
- [[args.cropping, args.cropping], [args.cropping, args.cropping]];
- }
- else if (typeof args.cropping[0] === 'number') {
- this.cropping = [
- [args.cropping[0], args.cropping[0]],
- [args.cropping[1], args.cropping[1]]
- ];
- }
- else {
- this.cropping = args.cropping;
- }
- this.dataFormat =
- args.dataFormat === undefined ? 'channelsLast' : args.dataFormat;
- this.inputSpec = [{ ndim: 4 }];
- }
- computeOutputShape(inputShape) {
- if (this.dataFormat === 'channelsFirst') {
- return [
- inputShape[0], inputShape[1],
- inputShape[2] - this.cropping[0][0] - this.cropping[0][1],
- inputShape[3] - this.cropping[1][0] - this.cropping[1][1]
- ];
- }
- else {
- return [
- inputShape[0],
- inputShape[1] - this.cropping[0][0] - this.cropping[0][1],
- inputShape[2] - this.cropping[1][0] - this.cropping[1][1], inputShape[3]
- ];
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = getExactlyOneTensor(inputs);
- if (this.dataFormat === 'channelsLast') {
- const hSliced = sliceAlongAxis(inputs, this.cropping[0][0], inputs.shape[1] - this.cropping[0][0] - this.cropping[0][1], 2);
- return sliceAlongAxis(hSliced, this.cropping[1][0], inputs.shape[2] - this.cropping[1][1] - this.cropping[1][0], 3);
- }
- else {
- const hSliced = sliceAlongAxis(inputs, this.cropping[0][0], inputs.shape[2] - this.cropping[0][0] - this.cropping[0][1], 3);
- return sliceAlongAxis(hSliced, this.cropping[1][0], inputs.shape[3] - this.cropping[1][1] - this.cropping[1][0], 4);
- }
- });
- }
- getConfig() {
- const config = { cropping: this.cropping, dataFormat: this.dataFormat };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Cropping2D.className = 'Cropping2D';
- registerClass(Cropping2D);
- class UpSampling2D extends Layer {
- constructor(args) {
- super(args);
- this.DEFAULT_SIZE = [2, 2];
- this.inputSpec = [{ ndim: 4 }];
- this.size = args.size == null ? this.DEFAULT_SIZE : args.size;
- this.dataFormat =
- args.dataFormat == null ? 'channelsLast' : args.dataFormat;
- }
- computeOutputShape(inputShape) {
- if (this.dataFormat === 'channelsFirst') {
- const height = inputShape[2] == null ? null : this.size[0] * inputShape[2];
- const width = inputShape[3] == null ? null : this.size[1] * inputShape[3];
- return [inputShape[0], inputShape[1], height, width];
- }
- else {
- const height = inputShape[1] == null ? null : this.size[0] * inputShape[1];
- const width = inputShape[2] == null ? null : this.size[1] * inputShape[2];
- return [inputShape[0], height, width, inputShape[3]];
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- let input = getExactlyOneTensor(inputs);
- const inputShape = input.shape;
- if (this.dataFormat === 'channelsFirst') {
- input = transpose(input, [0, 2, 3, 1]);
- const height = this.size[0] * inputShape[2];
- const width = this.size[1] * inputShape[3];
- const resized = input.resizeNearestNeighbor([height, width]);
- return transpose(resized, [0, 3, 1, 2]);
- }
- else {
- const height = this.size[0] * inputShape[1];
- const width = this.size[1] * inputShape[2];
- return input.resizeNearestNeighbor([height, width]);
- }
- });
- }
- getConfig() {
- const config = { size: this.size, dataFormat: this.dataFormat };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- UpSampling2D.className = 'UpSampling2D';
- registerClass(UpSampling2D);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * 2D convolution with separable filters.
- * @param x Input tensor.
- * @param depthwiseKernel Convolution kernel for depthwise convolution.
- * @param strides Strides (Array of two integers).
- * @param padding Padding model.
- * @param dataFormat Data format.
- * @param dilationRate Array of two integers, dilation rates for the separable
- * convolution.
- * @returns Output tensor.
- * @throws ValueError If depthwiseKernel is not a 4D array.
- */
- function depthwiseConv2d$2(x, depthwiseKernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
- return tidy(() => {
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- checkDataFormat(dataFormat);
- let y = preprocessConv2DInput(x, dataFormat);
- if (x.rank !== 4) {
- throw new ValueError(`Input for depthwiseConv2d is required to be 4-D, but is instead ` +
- `${x.rank}-D`);
- }
- if (depthwiseKernel.rank !== 4) {
- throw new ValueError(`depthwiseKernel is required to be 4-D, but is instead ` +
- `${depthwiseKernel.rank}-D`);
- }
- y = depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
- if (dataFormat === 'channelsFirst') {
- y = transpose(y, [0, 3, 1, 2]);
- }
- return y;
- });
- }
- class DepthwiseConv2D extends BaseConv {
- constructor(args) {
- super(2, args);
- this.depthwiseKernel = null;
- this.depthMultiplier =
- args.depthMultiplier == null ? 1 : args.depthMultiplier;
- this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
- this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- if (inputShape.length < 4) {
- throw new ValueError(`Inputs to DepthwiseConv2D should have rank 4. ` +
- `Received input shape: ${JSON.stringify(inputShape)}.`);
- }
- const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
- if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
- throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
- `be defined, but is not (${inputShape[channelAxis]}).`);
- }
- const inputDim = inputShape[channelAxis];
- const depthwiseKernelShape = [
- this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier
- ];
- this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- else {
- this.bias = null;
- }
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = getExactlyOneTensor(inputs);
- let outputs = depthwiseConv2d$2(inputs, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
- // TODO(cais): Add support for dilation.
- if (this.useBias) {
- outputs = biasAdd(outputs, this.bias.read(), this.dataFormat);
- }
- if (this.activation != null) {
- outputs = this.activation.apply(outputs);
- }
- return outputs;
- });
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
- const cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
- const outFilters = this.dataFormat === 'channelsFirst' ?
- inputShape[1] * this.depthMultiplier :
- inputShape[3] * this.depthMultiplier;
- const outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
- const outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
- if (this.dataFormat === 'channelsFirst') {
- return [inputShape[0], outFilters, outRows, outCols];
- }
- else {
- // In this case, assume 'channelsLast'.
- return [inputShape[0], outRows, outCols, outFilters];
- }
- }
- getConfig() {
- const config = super.getConfig();
- config['depthMultiplier'] = this.depthMultiplier;
- config['depthwiseInitializer'] =
- serializeInitializer(this.depthwiseInitializer);
- config['depthwiseRegularizer'] =
- serializeRegularizer(this.depthwiseRegularizer);
- config['depthwiseConstraint'] =
- serializeConstraint(this.depthwiseRegularizer);
- return config;
- }
- }
- /** @nocollapse */
- DepthwiseConv2D.className = 'DepthwiseConv2D';
- registerClass(DepthwiseConv2D);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Standardize `apply()` args to a single list of tensor inputs.
- *
- * When running a model loaded from file, the input tensors `initialState` and
- * `constants` are passed to `RNN.apply()` as part of `inputs` instead of the
- * dedicated kwargs fields. `inputs` consists of
- * `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this
- * case.
- * This method makes sure that arguments are
- * separated and that `initialState` and `constants` are `Array`s of tensors
- * (or None).
- *
- * @param inputs Tensor or `Array` of tensors.
- * @param initialState Tensor or `Array` of tensors or `null`/`undefined`.
- * @param constants Tensor or `Array` of tensors or `null`/`undefined`.
- * @returns An object consisting of
- * inputs: A tensor.
- * initialState: `Array` of tensors or `null`.
- * constants: `Array` of tensors or `null`.
- * @throws ValueError, if `inputs` is an `Array` but either `initialState` or
- * `constants` is provided.
- */
- function standardizeArgs(inputs, initialState, constants, numConstants) {
- if (Array.isArray(inputs)) {
- if (initialState != null || constants != null) {
- throw new ValueError('When inputs is an array, neither initialState or constants ' +
- 'should be provided');
- }
- if (numConstants != null) {
- constants = inputs.slice(inputs.length - numConstants, inputs.length);
- inputs = inputs.slice(0, inputs.length - numConstants);
- }
- if (inputs.length > 1) {
- initialState = inputs.slice(1, inputs.length);
- }
- inputs = inputs[0];
- }
- function toListOrNull(x) {
- if (x == null || Array.isArray(x)) {
- return x;
- }
- else {
- return [x];
- }
- }
- initialState = toListOrNull(initialState);
- constants = toListOrNull(constants);
- return { inputs, initialState, constants };
- }
- /**
- * Iterates over the time dimension of a tensor.
- *
- * @param stepFunction RNN step function.
- * Parameters:
- * inputs: tensor with shape `[samples, ...]` (no time dimension),
- * representing input for the batch of samples at a certain time step.
- * states: an Array of tensors.
- * Returns:
- * outputs: tensor with shape `[samples, outputDim]` (no time dimension).
- * newStates: list of tensors, same length and shapes as `states`. The first
- * state in the list must be the output tensor at the previous timestep.
- * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
- * least 3D).
- * @param initialStates Tensor with shape `[samples, outputDim]` (no time
- * dimension), containing the initial values of the states used in the step
- * function.
- * @param goBackwards If `true`, do the iteration over the time dimension in
- * reverse order and return the reversed sequence.
- * @param mask Binary tensor with shape `[sample, time, 1]`, with a zero for
- * every element that is masked.
- * @param constants An Array of constant values passed at each step.
- * @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
- * applicable to this imperative deeplearn.js backend. Its value is ignored.
- * @param needPerStepOutputs Whether the per-step outputs are to be
- * concatenated into a single tensor and returned (as the second return
- * value). Default: `false`. This arg is included so that the relatively
- * expensive concatenation of the stepwise outputs can be omitted unless
- * the stepwise outputs need to be kept (e.g., for an LSTM layer of which
- * `returnSequence` is `true`.)
- * @returns An Array: `[lastOutput, outputs, newStates]`.
- * lastOutput: the lastest output of the RNN, of shape `[samples, ...]`.
- * outputs: tensor with shape `[samples, time, ...]` where each entry
- * `output[s, t]` is the output of the step function at time `t` for sample
- * `s`. This return value is provided if and only if the
- * `needPerStepOutputs` is set as `true`. If it is set as `false`, this
- * return value will be `undefined`.
- * newStates: Array of tensors, latest states returned by the step function,
- * of shape `(samples, ...)`.
- * @throws ValueError If input dimension is less than 3.
- *
- * TODO(nielsene): This needs to be tidy-ed.
- */
- function rnn(stepFunction, inputs, initialStates, goBackwards = false, mask, constants, unroll = false, needPerStepOutputs = false) {
- return tidy(() => {
- const ndim = inputs.shape.length;
- if (ndim < 3) {
- throw new ValueError(`Input should be at least 3D, but is ${ndim}D.`);
- }
- // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
- // ...].
- const axes = [1, 0].concat(range$1(2, ndim));
- inputs = transpose(inputs, axes);
- if (constants != null) {
- throw new NotImplementedError('The rnn() functoin of the deeplearn.js backend does not support ' +
- 'constants yet.');
- }
- // Porting Note: the unroll option is ignored by the imperative backend.
- if (unroll) {
- console.warn('Backend rnn(): the unroll = true option is not applicable to the ' +
- 'imperative deeplearn.js backend.');
- }
- if (mask != null) {
- mask = mask.asType('bool').asType('float32');
- if (mask.rank === ndim - 1) {
- mask = expandDims(mask, -1);
- }
- mask = transpose(mask, axes);
- }
- if (goBackwards) {
- inputs = reverse(inputs, 0);
- if (mask != null) {
- mask = reverse(mask, 0);
- }
- }
- // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
- // (tf.while_loop). But for the imperative deeplearn.js backend, we just
- // use the usual TypeScript control flow to iterate over the time steps in
- // the inputs.
- // Porting Note: PyKeras patches a "_use_learning_phase" attribute to
- // outputs.
- // This is not idiomatic in TypeScript. The info regarding whether we are
- // in a learning (i.e., training) phase for RNN is passed in a different
- // way.
- const perStepOutputs = [];
- let lastOutput;
- let states = initialStates;
- const timeSteps = inputs.shape[0];
- const perStepInputs = unstack(inputs);
- let perStepMasks;
- if (mask != null) {
- perStepMasks = unstack(mask);
- }
- for (let t = 0; t < timeSteps; ++t) {
- const currentInput = perStepInputs[t];
- const stepOutputs = tidy(() => stepFunction(currentInput, states));
- if (mask == null) {
- lastOutput = stepOutputs[0];
- states = stepOutputs[1];
- }
- else {
- const maskedOutputs = tidy(() => {
- const stepMask = perStepMasks[t];
- const negStepMask = onesLike(stepMask).sub(stepMask);
- // TODO(cais): Would tfc.where() be better for performance?
- const output = stepOutputs[0].mul(stepMask).add(states[0].mul(negStepMask));
- const newStates = states.map((state, i) => {
- return stepOutputs[1][i].mul(stepMask).add(state.mul(negStepMask));
- });
- return { output, newStates };
- });
- lastOutput = maskedOutputs.output;
- states = maskedOutputs.newStates;
- }
- if (needPerStepOutputs) {
- perStepOutputs.push(lastOutput);
- }
- }
- let outputs;
- if (needPerStepOutputs) {
- const axis = 1;
- outputs = stack(perStepOutputs, axis);
- }
- return [lastOutput, outputs, states];
- });
- }
- class RNN extends Layer {
- constructor(args) {
- super(args);
- let cell;
- if (args.cell == null) {
- throw new ValueError('cell property is missing for the constructor of RNN.');
- }
- else if (Array.isArray(args.cell)) {
- cell = new StackedRNNCells({ cells: args.cell });
- }
- else {
- cell = args.cell;
- }
- if (cell.stateSize == null) {
- throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' +
- 'integers, one integer per RNN state).');
- }
- this.cell = cell;
- this.returnSequences =
- args.returnSequences == null ? false : args.returnSequences;
- this.returnState = args.returnState == null ? false : args.returnState;
- this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
- this._stateful = args.stateful == null ? false : args.stateful;
- this.unroll = args.unroll == null ? false : args.unroll;
- this.supportsMasking = true;
- this.inputSpec = [new InputSpec({ ndim: 3 })];
- this.stateSpec = null;
- this.states_ = null;
- // TODO(cais): Add constantsSpec and numConstants.
- this.numConstants = null;
- // TODO(cais): Look into the use of initial_state in the kwargs of the
- // constructor.
- this.keptStates = [];
- }
- // Porting Note: This is the equivalent of `RNN.states` property getter in
- // PyKeras.
- getStates() {
- if (this.states_ == null) {
- const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
- return range$1(0, numStates).map(x => null);
- }
- else {
- return this.states_;
- }
- }
- // Porting Note: This is the equivalent of the `RNN.states` property setter in
- // PyKeras.
- setStates(states) {
- this.states_ = states;
- }
- computeOutputShape(inputShape) {
- if (isArrayOfShapes(inputShape)) {
- inputShape = inputShape[0];
- }
- inputShape = inputShape;
- // TODO(cais): Remove the casting once stacked RNN cells become supported.
- let stateSize = this.cell.stateSize;
- if (!Array.isArray(stateSize)) {
- stateSize = [stateSize];
- }
- const outputDim = stateSize[0];
- let outputShape;
- if (this.returnSequences) {
- outputShape = [inputShape[0], inputShape[1], outputDim];
- }
- else {
- outputShape = [inputShape[0], outputDim];
- }
- if (this.returnState) {
- const stateShape = [];
- for (const dim of stateSize) {
- stateShape.push([inputShape[0], dim]);
- }
- return [outputShape].concat(stateShape);
- }
- else {
- return outputShape;
- }
- }
- computeMask(inputs, mask) {
- return tidy(() => {
- if (Array.isArray(mask)) {
- mask = mask[0];
- }
- const outputMask = this.returnSequences ? mask : null;
- if (this.returnState) {
- const stateMask = this.states.map(s => null);
- return [outputMask].concat(stateMask);
- }
- else {
- return outputMask;
- }
- });
- }
- /**
- * Get the current state tensors of the RNN.
- *
- * If the state hasn't been set, return an array of `null`s of the correct
- * length.
- */
- get states() {
- if (this.states_ == null) {
- const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
- const output = [];
- for (let i = 0; i < numStates; ++i) {
- output.push(null);
- }
- return output;
- }
- else {
- return this.states_;
- }
- }
- set states(s) {
- this.states_ = s;
- }
- build(inputShape) {
- // Note inputShape will be an Array of Shapes of initial states and
- // constants if these are passed in apply().
- const constantShape = null;
- if (this.numConstants != null) {
- throw new NotImplementedError('Constants support is not implemented in RNN yet.');
- }
- if (isArrayOfShapes(inputShape)) {
- inputShape = inputShape[0];
- }
- inputShape = inputShape;
- const batchSize = this.stateful ? inputShape[0] : null;
- const inputDim = inputShape.slice(2);
- this.inputSpec[0] = new InputSpec({ shape: [batchSize, null, ...inputDim] });
- // Allow cell (if RNNCell Layer) to build before we set or validate
- // stateSpec.
- const stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
- if (constantShape != null) {
- throw new NotImplementedError('Constants support is not implemented in RNN yet.');
- }
- else {
- this.cell.build(stepInputShape);
- }
- // Set or validate stateSpec.
- let stateSize;
- if (Array.isArray(this.cell.stateSize)) {
- stateSize = this.cell.stateSize;
- }
- else {
- stateSize = [this.cell.stateSize];
- }
- if (this.stateSpec != null) {
- if (!arraysEqual(this.stateSpec.map(spec => spec.shape[spec.shape.length - 1]), stateSize)) {
- throw new ValueError(`An initialState was passed that is not compatible with ` +
- `cell.stateSize. Received stateSpec=${this.stateSpec}; ` +
- `However cell.stateSize is ${this.cell.stateSize}`);
- }
- }
- else {
- this.stateSpec =
- stateSize.map(dim => new InputSpec({ shape: [null, dim] }));
- }
- if (this.stateful) {
- this.resetStates();
- }
- }
- /**
- * Reset the state tensors of the RNN.
- *
- * If the `states` argument is `undefined` or `null`, will set the
- * state tensor(s) of the RNN to all-zero tensors of the appropriate
- * shape(s).
- *
- * If `states` is provided, will set the state tensors of the RNN to its
- * value.
- *
- * @param states Optional externally-provided initial states.
- * @param training Whether this call is done during training. For stateful
- * RNNs, this affects whether the old states are kept or discarded. In
- * particular, if `training` is `true`, the old states will be kept so
- * that subsequent backpropgataion through time (BPTT) may work properly.
- * Else, the old states will be discarded.
- */
- resetStates(states, training = false) {
- tidy(() => {
- if (!this.stateful) {
- throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
- }
- const batchSize = this.inputSpec[0].shape[0];
- if (batchSize == null) {
- throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
- 'the batch size of your input tensors: \n' +
- '- If using a Sequential model, specify the batch size by ' +
- 'passing a `batchInputShape` option to your first layer.\n' +
- '- If using the functional API, specify the batch size by ' +
- 'passing a `batchShape` option to your Input layer.');
- }
- // Initialize state if null.
- if (this.states_ == null) {
- if (Array.isArray(this.cell.stateSize)) {
- this.states_ =
- this.cell.stateSize.map(dim => zeros([batchSize, dim]));
- }
- else {
- this.states_ = [zeros([batchSize, this.cell.stateSize])];
- }
- }
- else if (states == null) {
- // Dispose old state tensors.
- dispose(this.states_);
- // For stateful RNNs, fully dispose kept old states.
- if (this.keptStates != null) {
- dispose(this.keptStates);
- this.keptStates = [];
- }
- if (Array.isArray(this.cell.stateSize)) {
- this.states_ =
- this.cell.stateSize.map(dim => zeros([batchSize, dim]));
- }
- else {
- this.states_[0] = zeros([batchSize, this.cell.stateSize]);
- }
- }
- else {
- if (!Array.isArray(states)) {
- states = [states];
- }
- if (states.length !== this.states_.length) {
- throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
- `but it received ${states.length} state value(s). Input ` +
- `received: ${states}`);
- }
- if (training === true) {
- // Store old state tensors for complete disposal later, i.e., during
- // the next no-arg call to this method. We do not dispose the old
- // states immediately because that BPTT (among other things) require
- // them.
- this.keptStates.push(this.states_.slice());
- }
- else {
- dispose(this.states_);
- }
- for (let index = 0; index < this.states_.length; ++index) {
- const value = states[index];
- const dim = Array.isArray(this.cell.stateSize) ?
- this.cell.stateSize[index] :
- this.cell.stateSize;
- const expectedShape = [batchSize, dim];
- if (!arraysEqual(value.shape, expectedShape)) {
- throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
- `expected shape=${expectedShape}, received shape=${value.shape}`);
- }
- this.states_[index] = value;
- }
- }
- this.states_ = this.states_.map(state => keep(state.clone()));
- });
- }
- apply(inputs, kwargs) {
- // TODO(cais): Figure out whether initialState is in kwargs or inputs.
- let initialState = kwargs == null ? null : kwargs['initialState'];
- let constants = kwargs == null ? null : kwargs['constants'];
- if (kwargs == null) {
- kwargs = {};
- }
- const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
- inputs = standardized.inputs;
- initialState = standardized.initialState;
- constants = standardized.constants;
- // If any of `initial_state` or `constants` are specified and are
- // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
- // the input_spec to include them.
- let additionalInputs = [];
- let additionalSpecs = [];
- if (initialState != null) {
- kwargs['initialState'] = initialState;
- additionalInputs = additionalInputs.concat(initialState);
- this.stateSpec = [];
- for (const state of initialState) {
- this.stateSpec.push(new InputSpec({ shape: state.shape }));
- }
- // TODO(cais): Use the following instead.
- // this.stateSpec = initialState.map(state => new InputSpec({shape:
- // state.shape}));
- additionalSpecs = additionalSpecs.concat(this.stateSpec);
- }
- if (constants != null) {
- kwargs['constants'] = constants;
- additionalInputs = additionalInputs.concat(constants);
- // TODO(cais): Add this.constantsSpec.
- this.numConstants = constants.length;
- }
- const isTensor = additionalInputs[0] instanceof SymbolicTensor;
- if (isTensor) {
- // Compute full input spec, including state and constants.
- const fullInput = [inputs].concat(additionalInputs);
- const fullInputSpec = this.inputSpec.concat(additionalSpecs);
- // Perform the call with temporarily replaced inputSpec.
- const originalInputSpec = this.inputSpec;
- this.inputSpec = fullInputSpec;
- const output = super.apply(fullInput, kwargs);
- this.inputSpec = originalInputSpec;
- return output;
- }
- else {
- return super.apply(inputs, kwargs);
- }
- }
- // tslint:disable-next-line:no-any
- call(inputs, kwargs) {
- // Input shape: `[samples, time (padded with zeros), input_dim]`.
- // Note that the .build() method of subclasses **must** define
- // this.inputSpec and this.stateSpec owith complete input shapes.
- return tidy(() => {
- const mask = kwargs == null ? null : kwargs['mask'];
- const training = kwargs == null ? null : kwargs['training'];
- let initialState = kwargs == null ? null : kwargs['initialState'];
- inputs = getExactlyOneTensor(inputs);
- if (initialState == null) {
- if (this.stateful) {
- initialState = this.states_;
- }
- else {
- initialState = this.getInitialState(inputs);
- }
- }
- const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
- if (initialState.length !== numStates) {
- throw new ValueError(`RNN Layer has ${numStates} state(s) but was passed ` +
- `${initialState.length} initial state(s).`);
- }
- if (this.unroll) {
- console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
- }
- const cellCallKwargs = { training };
- // TODO(cais): Add support for constants.
- const step = (inputs, states) => {
- // `inputs` and `states` are concatenated to form a single `Array` of
- // `tf.Tensor`s as the input to `cell.call()`.
- const outputs = this.cell.call([inputs].concat(states), cellCallKwargs);
- // Marshall the return value into output and new states.
- return [outputs[0], outputs.slice(1)];
- };
- // TODO(cais): Add support for constants.
- const rnnOutputs = rnn(step, inputs, initialState, this.goBackwards, mask, null, this.unroll, this.returnSequences);
- const lastOutput = rnnOutputs[0];
- const outputs = rnnOutputs[1];
- const states = rnnOutputs[2];
- if (this.stateful) {
- this.resetStates(states, training);
- }
- const output = this.returnSequences ? outputs : lastOutput;
- // TODO(cais): Porperty set learning phase flag.
- if (this.returnState) {
- return [output].concat(states);
- }
- else {
- return output;
- }
- });
- }
- getInitialState(inputs) {
- return tidy(() => {
- // Build an all-zero tensor of shape [samples, outputDim].
- // [Samples, timeSteps, inputDim].
- let initialState = zeros(inputs.shape);
- // [Samples].
- initialState = sum$1(initialState, [1, 2]);
- initialState = expandDims$1(initialState); // [Samples, 1].
- if (Array.isArray(this.cell.stateSize)) {
- return this.cell.stateSize.map(dim => dim > 1 ? tile$2(initialState, [1, dim]) : initialState);
- }
- else {
- return this.cell.stateSize > 1 ?
- [tile$2(initialState, [1, this.cell.stateSize])] :
- [initialState];
- }
- });
- }
- get trainableWeights() {
- if (!this.trainable) {
- return [];
- }
- // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
- return this.cell.trainableWeights;
- }
- get nonTrainableWeights() {
- // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
- if (!this.trainable) {
- return this.cell.weights;
- }
- return this.cell.nonTrainableWeights;
- }
- setFastWeightInitDuringBuild(value) {
- super.setFastWeightInitDuringBuild(value);
- if (this.cell != null) {
- this.cell.setFastWeightInitDuringBuild(value);
- }
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = {
- returnSequences: this.returnSequences,
- returnState: this.returnState,
- goBackwards: this.goBackwards,
- stateful: this.stateful,
- unroll: this.unroll,
- };
- if (this.numConstants != null) {
- config['numConstants'] = this.numConstants;
- }
- const cellConfig = this.cell.getConfig();
- if (this.getClassName() === RNN.className) {
- config['cell'] = {
- 'className': this.cell.getClassName(),
- 'config': cellConfig,
- };
- }
- // this order is necessary, to prevent cell name from replacing layer name
- return { ...cellConfig, ...baseConfig, ...config };
- }
- /** @nocollapse */
- static fromConfig(cls, config, customObjects = {}) {
- const cellConfig = config['cell'];
- const cell = deserialize(cellConfig, customObjects);
- return new cls(Object.assign(config, { cell }));
- }
- }
- /** @nocollapse */
- RNN.className = 'RNN';
- registerClass(RNN);
- // Porting Note: This is a common parent class for RNN cells. There is no
- // equivalent of this in PyKeras. Having a common parent class forgoes the
- // need for `has_attr(cell, ...)` checks or its TypeScript equivalent.
- /**
- * An RNNCell layer.
- *
- * @doc {heading: 'Layers', subheading: 'Classes'}
- */
- class RNNCell extends Layer {
- }
- class SimpleRNNCell extends RNNCell {
- constructor(args) {
- super(args);
- this.DEFAULT_ACTIVATION = 'tanh';
- this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
- this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
- this.DEFAULT_BIAS_INITIALIZER = 'zeros';
- this.units = args.units;
- assertPositiveInteger(this.units, `units`);
- this.activation = getActivation(args.activation == null ? this.DEFAULT_ACTIVATION : args.activation);
- this.useBias = args.useBias == null ? true : args.useBias;
- this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
- this.biasInitializer =
- getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
- this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
- this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
- this.biasRegularizer = getRegularizer(args.biasRegularizer);
- this.kernelConstraint = getConstraint(args.kernelConstraint);
- this.recurrentConstraint = getConstraint(args.recurrentConstraint);
- this.biasConstraint = getConstraint(args.biasConstraint);
- this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
- this.recurrentDropout = min$1([
- 1,
- max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
- ]);
- this.stateSize = this.units;
- this.dropoutMask = null;
- this.recurrentDropoutMask = null;
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- // TODO(cais): Use regularizer.
- this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- else {
- this.bias = null;
- }
- this.built = true;
- }
- // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
- // `inputs` and `states`. Here, the two tensors are combined into an
- // `Tensor[]` Array as the first input argument.
- // Similarly, PyKeras' equivalent of this method returns two values:
- // `output` and `[output]`. Here the two are combined into one length-2
- // `Tensor[]`, consisting of `output` repeated.
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = inputs;
- if (inputs.length !== 2) {
- throw new ValueError(`SimpleRNNCell expects 2 input Tensors, got ${inputs.length}.`);
- }
- let prevOutput = inputs[1];
- inputs = inputs[0];
- const training = kwargs['training'] == null ? false : kwargs['training'];
- if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
- this.dropoutMask = generateDropoutMask({
- ones: () => onesLike(inputs),
- rate: this.dropout,
- training
- });
- }
- if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
- this.recurrentDropoutMask == null) {
- this.recurrentDropoutMask = generateDropoutMask({
- ones: () => onesLike(prevOutput),
- rate: this.recurrentDropout,
- training
- });
- }
- let h;
- const dpMask = this.dropoutMask;
- const recDpMask = this.recurrentDropoutMask;
- if (dpMask != null) {
- h = dot$1(mul(inputs, dpMask), this.kernel.read());
- }
- else {
- h = dot$1(inputs, this.kernel.read());
- }
- if (this.bias != null) {
- h = biasAdd(h, this.bias.read());
- }
- if (recDpMask != null) {
- prevOutput = mul(prevOutput, recDpMask);
- }
- let output = add$1(h, dot$1(prevOutput, this.recurrentKernel.read()));
- if (this.activation != null) {
- output = this.activation.apply(output);
- }
- // TODO(cais): Properly set learning phase on output tensor?
- return [output, output];
- });
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = {
- units: this.units,
- activation: serializeActivation(this.activation),
- useBias: this.useBias,
- kernelInitializer: serializeInitializer(this.kernelInitializer),
- recurrentInitializer: serializeInitializer(this.recurrentInitializer),
- biasInitializer: serializeInitializer(this.biasInitializer),
- kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
- recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
- biasRegularizer: serializeRegularizer(this.biasRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- kernelConstraint: serializeConstraint(this.kernelConstraint),
- recurrentConstraint: serializeConstraint(this.recurrentConstraint),
- biasConstraint: serializeConstraint(this.biasConstraint),
- dropout: this.dropout,
- recurrentDropout: this.recurrentDropout,
- };
- return { ...baseConfig, ...config };
- }
- }
- /** @nocollapse */
- SimpleRNNCell.className = 'SimpleRNNCell';
- registerClass(SimpleRNNCell);
- class SimpleRNN extends RNN {
- constructor(args) {
- args.cell = new SimpleRNNCell(args);
- super(args);
- // TODO(cais): Add activityRegularizer.
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (this.cell.dropoutMask != null) {
- dispose(this.cell.dropoutMask);
- this.cell.dropoutMask = null;
- }
- if (this.cell.recurrentDropoutMask != null) {
- dispose(this.cell.recurrentDropoutMask);
- this.cell.recurrentDropoutMask = null;
- }
- const mask = kwargs == null ? null : kwargs['mask'];
- const training = kwargs == null ? null : kwargs['training'];
- const initialState = kwargs == null ? null : kwargs['initialState'];
- return super.call(inputs, { mask, training, initialState });
- });
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config);
- }
- }
- /** @nocollapse */
- SimpleRNN.className = 'SimpleRNN';
- registerClass(SimpleRNN);
- class GRUCell extends RNNCell {
- constructor(args) {
- super(args);
- this.DEFAULT_ACTIVATION = 'tanh';
- this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
- this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
- this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
- this.DEFAULT_BIAS_INITIALIZER = 'zeros';
- if (args.resetAfter) {
- throw new ValueError(`GRUCell does not support reset_after parameter set to true.`);
- }
- this.units = args.units;
- assertPositiveInteger(this.units, 'units');
- this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
- args.activation);
- this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
- this.DEFAULT_RECURRENT_ACTIVATION :
- args.recurrentActivation);
- this.useBias = args.useBias == null ? true : args.useBias;
- this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
- this.biasInitializer =
- getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
- this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
- this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
- this.biasRegularizer = getRegularizer(args.biasRegularizer);
- this.kernelConstraint = getConstraint(args.kernelConstraint);
- this.recurrentConstraint = getConstraint(args.recurrentConstraint);
- this.biasConstraint = getConstraint(args.biasConstraint);
- this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
- this.recurrentDropout = min$1([
- 1,
- max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
- ]);
- this.implementation = args.implementation;
- this.stateSize = this.units;
- this.dropoutMask = null;
- this.recurrentDropoutMask = null;
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const inputDim = inputShape[inputShape.length - 1];
- this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- else {
- this.bias = null;
- }
- // Porting Notes: Unlike the PyKeras implementation, we perform slicing
- // of the weights and bias in the call() method, at execution time.
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = inputs;
- if (inputs.length !== 2) {
- throw new ValueError(`GRUCell expects 2 input Tensors (inputs, h, c), got ` +
- `${inputs.length}.`);
- }
- const training = kwargs['training'] == null ? false : kwargs['training'];
- let hTMinus1 = inputs[1]; // Previous memory state.
- inputs = inputs[0];
- // Note: For superior performance, TensorFlow.js always uses
- // implementation 2, regardless of the actual value of
- // config.implementation.
- if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
- this.dropoutMask = generateDropoutMask({
- ones: () => onesLike(inputs),
- rate: this.dropout,
- training,
- count: 3
- });
- }
- if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
- this.recurrentDropoutMask == null) {
- this.recurrentDropoutMask = generateDropoutMask({
- ones: () => onesLike(hTMinus1),
- rate: this.recurrentDropout,
- training,
- count: 3
- });
- }
- const dpMask = this.dropoutMask;
- const recDpMask = this.recurrentDropoutMask;
- let z;
- let r;
- let hh;
- if (0 < this.dropout && this.dropout < 1) {
- inputs = mul(inputs, dpMask[0]);
- }
- let matrixX = dot$1(inputs, this.kernel.read());
- if (this.useBias) {
- matrixX = biasAdd(matrixX, this.bias.read());
- }
- if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
- hTMinus1 = mul(hTMinus1, recDpMask[0]);
- }
- const recurrentKernelValue = this.recurrentKernel.read();
- const [rk1, rk2] = split(recurrentKernelValue, [2 * this.units, this.units], recurrentKernelValue.rank - 1);
- const matrixInner = dot$1(hTMinus1, rk1);
- const [xZ, xR, xH] = split(matrixX, 3, matrixX.rank - 1);
- const [recurrentZ, recurrentR] = split(matrixInner, 2, matrixInner.rank - 1);
- z = this.recurrentActivation.apply(add$1(xZ, recurrentZ));
- r = this.recurrentActivation.apply(add$1(xR, recurrentR));
- const recurrentH = dot$1(mul(r, hTMinus1), rk2);
- hh = this.activation.apply(add$1(xH, recurrentH));
- const h = add$1(mul(z, hTMinus1), mul(add$1(1, neg(z)), hh));
- // TODO(cais): Add use_learning_phase flag properly.
- return [h, h];
- });
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = {
- units: this.units,
- activation: serializeActivation(this.activation),
- recurrentActivation: serializeActivation(this.recurrentActivation),
- useBias: this.useBias,
- kernelInitializer: serializeInitializer(this.kernelInitializer),
- recurrentInitializer: serializeInitializer(this.recurrentInitializer),
- biasInitializer: serializeInitializer(this.biasInitializer),
- kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
- recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
- biasRegularizer: serializeRegularizer(this.biasRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- kernelConstraint: serializeConstraint(this.kernelConstraint),
- recurrentConstraint: serializeConstraint(this.recurrentConstraint),
- biasConstraint: serializeConstraint(this.biasConstraint),
- dropout: this.dropout,
- recurrentDropout: this.recurrentDropout,
- implementation: this.implementation,
- resetAfter: false
- };
- return { ...baseConfig, ...config };
- }
- }
- /** @nocollapse */
- GRUCell.className = 'GRUCell';
- registerClass(GRUCell);
- class GRU extends RNN {
- constructor(args) {
- if (args.implementation === 0) {
- console.warn('`implementation=0` has been deprecated, and now defaults to ' +
- '`implementation=1`. Please update your layer call.');
- }
- args.cell = new GRUCell(args);
- super(args);
- // TODO(cais): Add activityRegularizer.
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (this.cell.dropoutMask != null) {
- dispose(this.cell.dropoutMask);
- this.cell.dropoutMask = null;
- }
- if (this.cell.recurrentDropoutMask != null) {
- dispose(this.cell.recurrentDropoutMask);
- this.cell.recurrentDropoutMask = null;
- }
- const mask = kwargs == null ? null : kwargs['mask'];
- const training = kwargs == null ? null : kwargs['training'];
- const initialState = kwargs == null ? null : kwargs['initialState'];
- return super.call(inputs, { mask, training, initialState });
- });
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- if (config['implmentation'] === 0) {
- config['implementation'] = 1;
- }
- return new cls(config);
- }
- }
- /** @nocollapse */
- GRU.className = 'GRU';
- registerClass(GRU);
- class LSTMCell extends RNNCell {
- constructor(args) {
- super(args);
- this.DEFAULT_ACTIVATION = 'tanh';
- this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
- this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
- this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
- this.DEFAULT_BIAS_INITIALIZER = 'zeros';
- this.units = args.units;
- assertPositiveInteger(this.units, 'units');
- this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
- args.activation);
- this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
- this.DEFAULT_RECURRENT_ACTIVATION :
- args.recurrentActivation);
- this.useBias = args.useBias == null ? true : args.useBias;
- this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
- this.biasInitializer =
- getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
- this.unitForgetBias = args.unitForgetBias;
- this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
- this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
- this.biasRegularizer = getRegularizer(args.biasRegularizer);
- this.kernelConstraint = getConstraint(args.kernelConstraint);
- this.recurrentConstraint = getConstraint(args.recurrentConstraint);
- this.biasConstraint = getConstraint(args.biasConstraint);
- this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
- this.recurrentDropout = min$1([
- 1,
- max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
- ]);
- this.implementation = args.implementation;
- this.stateSize = [this.units, this.units];
- this.dropoutMask = null;
- this.recurrentDropoutMask = null;
- }
- build(inputShape) {
- var _a;
- inputShape = getExactlyOneShape(inputShape);
- const inputDim = inputShape[inputShape.length - 1];
- this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
- let biasInitializer;
- if (this.useBias) {
- if (this.unitForgetBias) {
- const capturedBiasInit = this.biasInitializer;
- const capturedUnits = this.units;
- biasInitializer = new (_a = class CustomInit extends Initializer {
- apply(shape, dtype) {
- // TODO(cais): More informative variable names?
- const bI = capturedBiasInit.apply([capturedUnits]);
- const bF = (new Ones()).apply([capturedUnits]);
- const bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
- return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH);
- }
- },
- /** @nocollapse */
- _a.className = 'CustomInit',
- _a)();
- }
- else {
- biasInitializer = this.biasInitializer;
- }
- this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- else {
- this.bias = null;
- }
- // Porting Notes: Unlike the PyKeras implementation, we perform slicing
- // of the weights and bias in the call() method, at execution time.
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- const training = kwargs['training'] == null ? false : kwargs['training'];
- inputs = inputs;
- if (inputs.length !== 3) {
- throw new ValueError(`LSTMCell expects 3 input Tensors (inputs, h, c), got ` +
- `${inputs.length}.`);
- }
- let hTMinus1 = inputs[1]; // Previous memory state.
- const cTMinus1 = inputs[2]; // Previous carry state.
- inputs = inputs[0];
- if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
- this.dropoutMask = generateDropoutMask({
- ones: () => onesLike(inputs),
- rate: this.dropout,
- training,
- count: 4
- });
- }
- if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
- this.recurrentDropoutMask == null) {
- this.recurrentDropoutMask = generateDropoutMask({
- ones: () => onesLike(hTMinus1),
- rate: this.recurrentDropout,
- training,
- count: 4
- });
- }
- const dpMask = this.dropoutMask;
- const recDpMask = this.recurrentDropoutMask;
- // Note: For superior performance, TensorFlow.js always uses
- // implementation 2 regardless of the actual value of
- // config.implementation.
- let i;
- let f;
- let c;
- let o;
- if (0 < this.dropout && this.dropout < 1) {
- inputs = mul(inputs, dpMask[0]);
- }
- let z = dot$1(inputs, this.kernel.read());
- if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
- hTMinus1 = mul(hTMinus1, recDpMask[0]);
- }
- z = add$1(z, dot$1(hTMinus1, this.recurrentKernel.read()));
- if (this.useBias) {
- z = biasAdd(z, this.bias.read());
- }
- const [z0, z1, z2, z3] = split(z, 4, z.rank - 1);
- i = this.recurrentActivation.apply(z0);
- f = this.recurrentActivation.apply(z1);
- c = add$1(mul(f, cTMinus1), mul(i, this.activation.apply(z2)));
- o = this.recurrentActivation.apply(z3);
- const h = mul(o, this.activation.apply(c));
- // TODO(cais): Add use_learning_phase flag properly.
- return [h, h, c];
- });
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = {
- units: this.units,
- activation: serializeActivation(this.activation),
- recurrentActivation: serializeActivation(this.recurrentActivation),
- useBias: this.useBias,
- kernelInitializer: serializeInitializer(this.kernelInitializer),
- recurrentInitializer: serializeInitializer(this.recurrentInitializer),
- biasInitializer: serializeInitializer(this.biasInitializer),
- unitForgetBias: this.unitForgetBias,
- kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
- recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
- biasRegularizer: serializeRegularizer(this.biasRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- kernelConstraint: serializeConstraint(this.kernelConstraint),
- recurrentConstraint: serializeConstraint(this.recurrentConstraint),
- biasConstraint: serializeConstraint(this.biasConstraint),
- dropout: this.dropout,
- recurrentDropout: this.recurrentDropout,
- implementation: this.implementation,
- };
- return { ...baseConfig, ...config };
- }
- }
- /** @nocollapse */
- LSTMCell.className = 'LSTMCell';
- registerClass(LSTMCell);
- class LSTM extends RNN {
- constructor(args) {
- if (args.implementation === 0) {
- console.warn('`implementation=0` has been deprecated, and now defaults to ' +
- '`implementation=1`. Please update your layer call.');
- }
- args.cell = new LSTMCell(args);
- super(args);
- // TODO(cais): Add activityRegularizer.
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (this.cell.dropoutMask != null) {
- dispose(this.cell.dropoutMask);
- this.cell.dropoutMask = null;
- }
- if (this.cell.recurrentDropoutMask != null) {
- dispose(this.cell.recurrentDropoutMask);
- this.cell.recurrentDropoutMask = null;
- }
- const mask = kwargs == null ? null : kwargs['mask'];
- const training = kwargs == null ? null : kwargs['training'];
- const initialState = kwargs == null ? null : kwargs['initialState'];
- return super.call(inputs, { mask, training, initialState });
- });
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- if (config['implmentation'] === 0) {
- config['implementation'] = 1;
- }
- return new cls(config);
- }
- }
- /** @nocollapse */
- LSTM.className = 'LSTM';
- registerClass(LSTM);
- class StackedRNNCells extends RNNCell {
- constructor(args) {
- super(args);
- this.cells = args.cells;
- }
- get stateSize() {
- // States are a flat list in reverse order of the cell stack.
- // This allows perserving the requirement `stack.statesize[0] ===
- // outputDim`. E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`,
- // assuming one LSTM has states `[h, c]`.
- const stateSize = [];
- for (const cell of this.cells.slice().reverse()) {
- if (Array.isArray(cell.stateSize)) {
- stateSize.push(...cell.stateSize);
- }
- else {
- stateSize.push(cell.stateSize);
- }
- }
- return stateSize;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = inputs;
- let states = inputs.slice(1);
- // Recover per-cell states.
- const nestedStates = [];
- for (const cell of this.cells.slice().reverse()) {
- if (Array.isArray(cell.stateSize)) {
- nestedStates.push(states.splice(0, cell.stateSize.length));
- }
- else {
- nestedStates.push(states.splice(0, 1));
- }
- }
- nestedStates.reverse();
- // Call the cells in order and store the returned states.
- const newNestedStates = [];
- let callInputs;
- for (let i = 0; i < this.cells.length; ++i) {
- const cell = this.cells[i];
- states = nestedStates[i];
- // TODO(cais): Take care of constants.
- if (i === 0) {
- callInputs = [inputs[0]].concat(states);
- }
- else {
- callInputs = [callInputs[0]].concat(states);
- }
- callInputs = cell.call(callInputs, kwargs);
- newNestedStates.push(callInputs.slice(1));
- }
- // Format the new states as a flat list in reverse cell order.
- states = [];
- for (const cellStates of newNestedStates.slice().reverse()) {
- states.push(...cellStates);
- }
- return [callInputs[0]].concat(states);
- });
- }
- build(inputShape) {
- if (isArrayOfShapes(inputShape)) {
- // TODO(cais): Take care of input constants.
- // const constantShape = inputShape.slice(1);
- inputShape = inputShape[0];
- }
- inputShape = inputShape;
- let outputDim;
- this.cells.forEach((cell, i) => {
- nameScope(`RNNCell_${i}`, () => {
- // TODO(cais): Take care of input constants.
- cell.build(inputShape);
- if (Array.isArray(cell.stateSize)) {
- outputDim = cell.stateSize[0];
- }
- else {
- outputDim = cell.stateSize;
- }
- inputShape = [inputShape[0], outputDim];
- });
- });
- this.built = true;
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const getCellConfig = (cell) => {
- return {
- 'className': cell.getClassName(),
- 'config': cell.getConfig(),
- };
- };
- const cellConfigs = this.cells.map(getCellConfig);
- const config = { 'cells': cellConfigs };
- return { ...baseConfig, ...config };
- }
- /** @nocollapse */
- static fromConfig(cls, config, customObjects = {}) {
- const cells = [];
- for (const cellConfig of config['cells']) {
- cells.push(deserialize(cellConfig, customObjects));
- }
- return new cls({ cells });
- }
- get trainableWeights() {
- if (!this.trainable) {
- return [];
- }
- const weights = [];
- for (const cell of this.cells) {
- weights.push(...cell.trainableWeights);
- }
- return weights;
- }
- get nonTrainableWeights() {
- const weights = [];
- for (const cell of this.cells) {
- weights.push(...cell.nonTrainableWeights);
- }
- if (!this.trainable) {
- const trainableWeights = [];
- for (const cell of this.cells) {
- trainableWeights.push(...cell.trainableWeights);
- }
- return trainableWeights.concat(weights);
- }
- return weights;
- }
- /**
- * Retrieve the weights of a the model.
- *
- * @returns A flat `Array` of `tf.Tensor`s.
- */
- getWeights() {
- const weights = [];
- for (const cell of this.cells) {
- weights.push(...cell.weights);
- }
- return batchGetValue(weights);
- }
- /**
- * Set the weights of the model.
- *
- * @param weights An `Array` of `tf.Tensor`s with shapes and types matching
- * the output of `getWeights()`.
- */
- setWeights(weights) {
- const tuples = [];
- for (const cell of this.cells) {
- const numParams = cell.weights.length;
- const inputWeights = weights.splice(numParams);
- for (let i = 0; i < cell.weights.length; ++i) {
- tuples.push([cell.weights[i], inputWeights[i]]);
- }
- }
- batchSetValue(tuples);
- }
- }
- /** @nocollapse */
- StackedRNNCells.className = 'StackedRNNCells';
- registerClass(StackedRNNCells);
- function generateDropoutMask(args) {
- const { ones, rate, training = false, count = 1 } = args;
- const droppedInputs = () => dropout$1(ones(), rate);
- const createMask = () => inTrainPhase(droppedInputs, ones, training);
- // just in case count is provided with null or undefined
- if (!count || count <= 1) {
- return keep(createMask().clone());
- }
- const masks = Array(count).fill(undefined).map(createMask);
- return masks.map(m => keep(m.clone()));
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class ConvRNN2DCell extends RNNCell {
- }
- /**
- * Base class for convolutional-recurrent layers.
- */
- class ConvRNN2D extends RNN {
- constructor(args) {
- if (args.unroll) {
- throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.');
- }
- if (Array.isArray(args.cell)) {
- throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.');
- }
- super(args);
- this.inputSpec = [new InputSpec({ ndim: 5 })];
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (this.cell.dropoutMask != null) {
- dispose(this.cell.dropoutMask);
- this.cell.dropoutMask = null;
- }
- if (this.cell.recurrentDropoutMask != null) {
- dispose(this.cell.recurrentDropoutMask);
- this.cell.recurrentDropoutMask = null;
- }
- if (kwargs && kwargs['constants']) {
- throw new ValueError('ConvRNN2D cell does not support constants');
- }
- const mask = kwargs == null ? null : kwargs['mask'];
- const training = kwargs == null ? null : kwargs['training'];
- const initialState = kwargs == null ? null : kwargs['initialState'];
- return super.call(inputs, { mask, training, initialState });
- });
- }
- computeOutputShape(inputShape) {
- let outShape = this.computeSingleOutputShape(inputShape);
- if (!this.returnSequences) {
- outShape = [outShape[0], ...outShape.slice(2)];
- }
- if (this.returnState) {
- outShape =
- [outShape, ...Array(2).fill([inputShape[0], ...outShape.slice(-3)])];
- }
- return outShape;
- }
- getInitialState(inputs) {
- return tidy(() => {
- const { stateSize } = this.cell;
- const inputShape = inputs.shape;
- const outputShape = this.computeSingleOutputShape(inputShape);
- const stateShape = [outputShape[0], ...outputShape.slice(2)];
- const initialState = zeros(stateShape);
- if (Array.isArray(stateSize)) {
- return Array(stateSize.length).fill(initialState);
- }
- return [initialState];
- });
- }
- resetStates(states, training = false) {
- tidy(() => {
- if (!this.stateful) {
- throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
- }
- const inputShape = this.inputSpec[0].shape;
- const outputShape = this.computeSingleOutputShape(inputShape);
- const stateShape = [outputShape[0], ...outputShape.slice(2)];
- const batchSize = inputShape[0];
- if (batchSize == null) {
- throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
- 'the batch size of your input tensors: \n' +
- '- If using a Sequential model, specify the batch size by ' +
- 'passing a `batchInputShape` option to your first layer.\n' +
- '- If using the functional API, specify the batch size by ' +
- 'passing a `batchShape` option to your Input layer.');
- }
- // Initialize state if null.
- if (this.getStates() == null) {
- if (Array.isArray(this.cell.stateSize)) {
- this.states_ = this.cell.stateSize.map(() => zeros(stateShape));
- }
- else {
- this.states_ = [zeros(stateShape)];
- }
- }
- else if (states == null) {
- // Dispose old state tensors.
- dispose(this.states_);
- // For stateful RNNs, fully dispose kept old states.
- if (this.keptStates != null) {
- dispose(this.keptStates);
- this.keptStates = [];
- }
- if (Array.isArray(this.cell.stateSize)) {
- this.states_ = this.cell.stateSize.map(() => zeros(stateShape));
- }
- else {
- this.states_[0] = zeros(stateShape);
- }
- }
- else {
- if (!Array.isArray(states)) {
- states = [states];
- }
- if (states.length !== this.states_.length) {
- throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
- `but it received ${states.length} state value(s). Input ` +
- `received: ${states}`);
- }
- if (training) {
- // Store old state tensors for complete disposal later, i.e., during
- // the next no-arg call to this method. We do not dispose the old
- // states immediately because that BPTT (among other things) require
- // them.
- this.keptStates.push(this.states_.slice());
- }
- else {
- dispose(this.states_);
- }
- for (let index = 0; index < this.states_.length; ++index) {
- const value = states[index];
- const expectedShape = stateShape;
- if (!arraysEqual(value.shape, expectedShape)) {
- throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
- `expected shape=${expectedShape}, received shape=${value.shape}`);
- }
- this.states_[index] = value;
- }
- }
- this.states_ = this.states_.map(state => keep(state.clone()));
- });
- }
- computeSingleOutputShape(inputShape) {
- const { dataFormat, filters, kernelSize, padding, strides, dilationRate } = this.cell;
- const isChannelsFirst = dataFormat === 'channelsFirst';
- const h = inputShape[isChannelsFirst ? 3 : 2];
- const w = inputShape[isChannelsFirst ? 4 : 3];
- const hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
- const wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
- const outShape = [
- ...inputShape.slice(0, 2),
- ...(isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters])
- ];
- return outShape;
- }
- }
- /** @nocollapse */
- ConvRNN2D.className = 'ConvRNN2D';
- class ConvLSTM2DCell extends LSTMCell {
- constructor(args) {
- const { filters, kernelSize, strides, padding, dataFormat, dilationRate, } = args;
- super({ ...args, units: filters });
- this.filters = filters;
- assertPositiveInteger(this.filters, 'filters');
- this.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize');
- this.kernelSize.forEach(size => assertPositiveInteger(size, 'kernelSize'));
- this.strides = normalizeArray(strides || 1, 2, 'strides');
- this.strides.forEach(stride => assertPositiveInteger(stride, 'strides'));
- this.padding = padding || 'valid';
- checkPaddingMode(this.padding);
- this.dataFormat = dataFormat || 'channelsLast';
- checkDataFormat(this.dataFormat);
- this.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate');
- this.dilationRate.forEach(rate => assertPositiveInteger(rate, 'dilationRate'));
- }
- build(inputShape) {
- var _a;
- inputShape = getExactlyOneShape(inputShape);
- const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
- if (inputShape[channelAxis] == null) {
- throw new ValueError(`The channel dimension of the input should be defined. ` +
- `Found ${inputShape[channelAxis]}`);
- }
- const inputDim = inputShape[channelAxis];
- const numOfKernels = 4;
- const kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
- this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- const recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
- this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
- if (this.useBias) {
- let biasInitializer;
- if (this.unitForgetBias) {
- const init = this.biasInitializer;
- const filters = this.filters;
- biasInitializer = new (_a = class CustomInit extends Initializer {
- apply(shape, dtype) {
- const biasI = init.apply([filters]);
- const biasF = ones$1([filters]);
- const biasCAndO = init.apply([filters * 2]);
- return concatenate([biasI, biasF, biasCAndO]);
- }
- },
- /** @nocollapse */
- _a.className = 'CustomInit',
- _a)();
- }
- else {
- biasInitializer = this.biasInitializer;
- }
- this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (inputs.length !== 3) {
- throw new ValueError(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ` +
- `${inputs.length}.`);
- }
- const training = kwargs['training'] || false;
- const x = inputs[0]; // Current input
- const hTMinus1 = inputs[1]; // Previous memory state.
- const cTMinus1 = inputs[2]; // Previous carry state.
- const numOfKernels = 4;
- if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
- this.dropoutMask = generateDropoutMask({
- ones: () => onesLike(x),
- rate: this.dropout,
- training,
- count: numOfKernels
- });
- }
- const dropoutMask = this.dropoutMask;
- const applyDropout = (x, mask, index) => {
- if (!mask || !mask[index]) {
- return x;
- }
- return mul(mask[index], x);
- };
- let xI = applyDropout(x, dropoutMask, 0);
- let xF = applyDropout(x, dropoutMask, 1);
- let xC = applyDropout(x, dropoutMask, 2);
- let xO = applyDropout(x, dropoutMask, 3);
- if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
- this.recurrentDropoutMask == null) {
- this.recurrentDropoutMask = generateDropoutMask({
- ones: () => onesLike(hTMinus1),
- rate: this.recurrentDropout,
- training,
- count: numOfKernels
- });
- }
- const recDropoutMask = this.recurrentDropoutMask;
- let hI = applyDropout(hTMinus1, recDropoutMask, 0);
- let hF = applyDropout(hTMinus1, recDropoutMask, 1);
- let hC = applyDropout(hTMinus1, recDropoutMask, 2);
- let hO = applyDropout(hTMinus1, recDropoutMask, 3);
- const kernelChannelAxis = 3;
- const [kernelI, kernelF, kernelC, kernelO] = split(this.kernel.read(), numOfKernels, kernelChannelAxis);
- const [biasI, biasF, biasC, biasO] = this.useBias ?
- split(this.bias.read(), numOfKernels) :
- [null, null, null, null];
- xI = this.inputConv(xI, kernelI, biasI, this.padding);
- xF = this.inputConv(xF, kernelF, biasF, this.padding);
- xC = this.inputConv(xC, kernelC, biasC, this.padding);
- xO = this.inputConv(xO, kernelO, biasO, this.padding);
- const [recKernelI, recKernelF, recKernelC, recKernelO] = split(this.recurrentKernel.read(), numOfKernels, kernelChannelAxis);
- hI = this.recurrentConv(hI, recKernelI);
- hF = this.recurrentConv(hF, recKernelF);
- hC = this.recurrentConv(hC, recKernelC);
- hO = this.recurrentConv(hO, recKernelO);
- const i = this.recurrentActivation.apply(add$1(xI, hI));
- const f = this.recurrentActivation.apply(add$1(xF, hF));
- const c = add$1(mul(f, cTMinus1), mul(i, this.activation.apply(add$1(xC, hC))));
- const h = mul(this.recurrentActivation.apply(add$1(xO, hO)), this.activation.apply(c));
- return [h, h, c];
- });
- }
- getConfig() {
- const { 'units': _, ...baseConfig } = super.getConfig();
- const config = {
- filters: this.filters,
- kernelSize: this.kernelSize,
- padding: this.padding,
- dataFormat: this.dataFormat,
- dilationRate: this.dilationRate,
- strides: this.strides,
- };
- return { ...baseConfig, ...config };
- }
- inputConv(x, w, b, padding) {
- const out = conv2d(x, w, this.strides, (padding || 'valid'), this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate);
- if (b) {
- return biasAdd(out, b, this.dataFormat);
- }
- return out;
- }
- recurrentConv(x, w) {
- const strides = 1;
- return conv2d(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC');
- }
- }
- /** @nocollapse */
- ConvLSTM2DCell.className = 'ConvLSTM2DCell';
- registerClass(ConvLSTM2DCell);
- class ConvLSTM2D extends ConvRNN2D {
- constructor(args) {
- const cell = new ConvLSTM2DCell(args);
- super({ ...args, cell });
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- return new cls(config);
- }
- }
- /** @nocollapse */
- ConvLSTM2D.className = 'ConvLSTM2D';
- registerClass(ConvLSTM2D);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class Dropout extends Layer {
- constructor(args) {
- super(args);
- this.rate = Math.max(Math.min(args.rate, 1), 0);
- // So that the scalar doesn't get tidied up between executions.
- this.noiseShape = args.noiseShape;
- this.seed = args.seed;
- this.supportsMasking = true;
- }
- getNoiseShape(input) {
- if (this.noiseShape == null) {
- return this.noiseShape;
- }
- const inputShape = input.shape;
- const noiseShape = [];
- for (let i = 0; i < this.noiseShape.length; ++i) {
- noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
- }
- return noiseShape;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- if (0 < this.rate && this.rate < 1) {
- const training = kwargs['training'] == null ? false : kwargs['training'];
- const noiseShape = this.getNoiseShape(input);
- const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training);
- return output;
- }
- return inputs;
- });
- }
- getConfig() {
- const config = {
- rate: this.rate,
- noiseShape: this.noiseShape,
- seed: this.seed,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- dispose() {
- return super.dispose();
- }
- }
- /** @nocollapse */
- Dropout.className = 'Dropout';
- registerClass(Dropout);
- class SpatialDropout1D extends Dropout {
- constructor(args) {
- super(args);
- this.inputSpec = [{ ndim: 3 }];
- }
- getNoiseShape(input) {
- const inputShape = input.shape;
- return [inputShape[0], 1, inputShape[2]];
- }
- }
- /** @nocollapse */
- SpatialDropout1D.className = 'SpatialDropout1D';
- registerClass(SpatialDropout1D);
- class Dense extends Layer {
- constructor(args) {
- super(args);
- // Default activation: Linear (none).
- this.activation = null;
- this.useBias = true;
- this.kernel = null;
- this.bias = null;
- this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
- this.DEFAULT_BIAS_INITIALIZER = 'zeros';
- if (args.batchInputShape == null && args.inputShape == null &&
- args.inputDim != null) {
- // This logic is copied from Layer's constructor, since we can't
- // do exactly what the Python constructor does for Dense().
- let batchSize = null;
- if (args.batchSize != null) {
- batchSize = args.batchSize;
- }
- this.batchInputShape = [batchSize, args.inputDim];
- }
- this.units = args.units;
- assertPositiveInteger(this.units, 'units');
- this.activation = getActivation(args.activation);
- if (args.useBias != null) {
- this.useBias = args.useBias;
- }
- this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
- this.biasInitializer =
- getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
- this.kernelConstraint = getConstraint(args.kernelConstraint);
- this.biasConstraint = getConstraint(args.biasConstraint);
- this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
- this.biasRegularizer = getRegularizer(args.biasRegularizer);
- this.activityRegularizer = getRegularizer(args.activityRegularizer);
- this.supportsMasking = true;
- this.inputSpec = [{ minNDim: 2 }];
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const inputLastDim = inputShape[inputShape.length - 1];
- if (this.kernel == null) {
- this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
- if (this.useBias) {
- this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
- }
- }
- this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }];
- this.built = true;
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const outputShape = inputShape.slice();
- outputShape[outputShape.length - 1] = this.units;
- return outputShape;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- // Dense layer accepts only a single input.
- const input = getExactlyOneTensor(inputs);
- const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
- let output;
- if (fusedActivationName != null) {
- output = dot$1(input, this.kernel.read(), fusedActivationName, this.bias ? this.bias.read() : null);
- }
- else {
- output = dot$1(input, this.kernel.read());
- if (this.bias != null) {
- output = biasAdd(output, this.bias.read());
- }
- if (this.activation != null) {
- output = this.activation.apply(output);
- }
- }
- return output;
- });
- }
- getConfig() {
- const config = {
- units: this.units,
- activation: serializeActivation(this.activation),
- useBias: this.useBias,
- kernelInitializer: serializeInitializer(this.kernelInitializer),
- biasInitializer: serializeInitializer(this.biasInitializer),
- kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
- biasRegularizer: serializeRegularizer(this.biasRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- kernelConstraint: serializeConstraint(this.kernelConstraint),
- biasConstraint: serializeConstraint(this.biasConstraint)
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Dense.className = 'Dense';
- registerClass(Dense);
- class Flatten extends Layer {
- constructor(args) {
- args = args || {};
- super(args);
- this.inputSpec = [{ minNDim: 3 }];
- this.dataFormat = args.dataFormat;
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- for (const dim of inputShape.slice(1)) {
- if (dim == null) {
- throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` +
- `(got ${inputShape.slice(1)}). Make sure to pass a complete ` +
- `"input_shape" or "batch_input_shape" argument to the first ` +
- `layer in your model.`);
- }
- }
- return [inputShape[0], arrayProd(inputShape, 1)];
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- let input = getExactlyOneTensor(inputs);
- if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
- const permutation = [0];
- for (let i = 2; i < input.rank; ++i) {
- permutation.push(i);
- }
- permutation.push(1);
- input = input.transpose(permutation);
- }
- return batchFlatten(input);
- });
- }
- getConfig() {
- const config = {};
- if (this.dataFormat != null) {
- config['dataFormat'] = this.dataFormat;
- }
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Flatten.className = 'Flatten';
- registerClass(Flatten);
- class Activation$1 extends Layer {
- constructor(args) {
- super(args);
- this.supportsMasking = true;
- this.activation = getActivation(args.activation);
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- return this.activation.apply(input);
- });
- }
- getConfig() {
- const config = { activation: serializeActivation(this.activation) };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Activation$1.className = 'Activation';
- registerClass(Activation$1);
- class RepeatVector extends Layer {
- constructor(args) {
- super(args);
- this.n = args.n;
- this.inputSpec = [{ ndim: 2 }];
- }
- computeOutputShape(inputShape) {
- return [inputShape[0], this.n, inputShape[1]];
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = getExactlyOneTensor(inputs);
- return repeat(inputs, this.n);
- });
- }
- getConfig() {
- const config = {
- n: this.n,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- RepeatVector.className = 'RepeatVector';
- registerClass(RepeatVector);
- class Reshape$1 extends Layer {
- constructor(args) {
- super(args);
- this.targetShape = args.targetShape;
- // Make sure that all unknown dimensions are represented as `null`.
- for (let i = 0; i < this.targetShape.length; ++i) {
- if (this.isUnknown(this.targetShape[i])) {
- this.targetShape[i] = null;
- }
- }
- }
- isUnknown(dim) {
- return dim < 0 || dim == null;
- }
- /**
- * Finds and replaces a missing dimension in output shape.
- *
- * This is a near direct port of the internal Numpy function
- * `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.
- *
- * @param inputShape: Original shape of array begin reshape.
- * @param outputShape: Target shape of the array, with at most a single
- * `null` or negative number, which indicates an underdetermined dimension
- * that should be derived from `inputShape` and the known dimensions of
- * `outputShape`.
- * @returns: The output shape with `null` replaced with its computed value.
- * @throws: ValueError: If `inputShape` and `outputShape` do not match.
- */
- fixUnknownDimension(inputShape, outputShape) {
- const errorMsg = 'Total size of new array must be unchanged.';
- const finalShape = outputShape.slice();
- let known = 1;
- let unknown = null;
- for (let i = 0; i < finalShape.length; ++i) {
- const dim = finalShape[i];
- if (this.isUnknown(dim)) {
- if (unknown === null) {
- unknown = i;
- }
- else {
- throw new ValueError('Can only specifiy one unknown dimension.');
- }
- }
- else {
- known *= dim;
- }
- }
- const originalSize = arrayProd(inputShape);
- if (unknown !== null) {
- if (known === 0 || originalSize % known !== 0) {
- throw new ValueError(errorMsg);
- }
- finalShape[unknown] = originalSize / known;
- }
- else if (originalSize !== known) {
- throw new ValueError(errorMsg);
- }
- return finalShape;
- }
- computeOutputShape(inputShape) {
- let anyUnknownDims = false;
- for (let i = 0; i < inputShape.length; ++i) {
- if (this.isUnknown(inputShape[i])) {
- anyUnknownDims = true;
- break;
- }
- }
- if (anyUnknownDims) {
- return inputShape.slice(0, 1).concat(this.targetShape);
- }
- else {
- return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- const inputShape = input.shape;
- const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
- return input.reshape(outputShape);
- });
- }
- getConfig() {
- const config = {
- targetShape: this.targetShape,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Reshape$1.className = 'Reshape';
- registerClass(Reshape$1);
- class Permute extends Layer {
- constructor(args) {
- super(args);
- if (args.dims == null) {
- throw new Error('Required configuration field `dims` is missing during Permute ' +
- 'constructor call.');
- }
- if (!Array.isArray(args.dims)) {
- throw new Error('Permute constructor requires `dims` to be an Array, but received ' +
- `${args.dims} instead.`);
- }
- // Check the validity of the permutation indices.
- const expectedSortedIndices = range$1(1, args.dims.length + 1);
- if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
- throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
- ' `dims` must contain consecutive integers starting from 1.');
- }
- this.dims = args.dims;
- this.dimsIncludingBatch = [0].concat(this.dims);
- this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })];
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const outputShape = inputShape.slice();
- this.dims.forEach((dim, i) => {
- outputShape[i + 1] = inputShape[dim];
- });
- return outputShape;
- }
- call(inputs, kwargs) {
- return transpose(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
- }
- getConfig() {
- const config = {
- dims: this.dims,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Permute.className = 'Permute';
- registerClass(Permute);
- class Masking extends Layer {
- constructor(args) {
- super(args == null ? {} : args);
- this.supportsMasking = true;
- if (args != null) {
- this.maskValue = args.maskValue == null ? 0 : args.maskValue;
- }
- else {
- this.maskValue = 0;
- }
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = { maskValue: this.maskValue };
- Object.assign(config, baseConfig);
- return config;
- }
- computeMask(inputs, mask) {
- const input = getExactlyOneTensor(inputs);
- const axis = -1;
- return any(notEqual(input, this.maskValue), axis);
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- const axis = -1;
- const keepDims = true;
- const booleanMask = any(notEqual(input, this.maskValue), axis, keepDims);
- const output = input.mul(booleanMask.asType(input.dtype));
- return output;
- });
- }
- }
- /** @nocollapse */
- Masking.className = 'Masking';
- registerClass(Masking);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class Embedding extends Layer {
- constructor(args) {
- super(args);
- this.embeddings = null;
- this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
- if (args.batchInputShape == null && args.inputShape == null) {
- // Porting Note: This logic is copied from Layer's constructor, since we
- // can't do exactly what the Python constructor does for Embedding().
- // Specifically, the super constructor can not be called after the
- // mutation of the `config` argument.
- let batchSize = null;
- if (args.batchSize != null) {
- batchSize = args.batchSize;
- }
- if (args.inputLength == null) {
- // Fix super-constructor to what it would have done if
- // 'config.inputShape' were (None, )
- this.batchInputShape = [batchSize, null];
- }
- else {
- // Fix super-constructor to what it would have done if
- // 'config.inputShape' were (config.inputLength, )
- this.batchInputShape =
- [batchSize].concat(toList(args.inputLength));
- }
- }
- this.inputDim = args.inputDim;
- assertPositiveInteger(this.inputDim, 'inputDim');
- this.outputDim = args.outputDim;
- assertPositiveInteger(this.outputDim, 'outputDim');
- this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
- this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
- this.activityRegularizer = getRegularizer(args.activityRegularizer);
- this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
- this.maskZero = args.maskZero;
- this.supportsMasking = args.maskZero;
- this.inputLength = args.inputLength;
- }
- build(inputShape) {
- this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
- this.built = true;
- }
- // Override warnOnIncompatibleInputShape because an embedding layer allows
- // the input to have varying ranks.
- warnOnIncompatibleInputShape(inputShape) { }
- computeMask(inputs, mask) {
- return tidy(() => {
- if (!this.maskZero) {
- return null;
- }
- else {
- inputs = getExactlyOneTensor(inputs);
- return notEqual(inputs, zerosLike(inputs));
- }
- });
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- if (this.inputLength == null) {
- return [...inputShape, this.outputDim];
- }
- // inputLength can be an array if input is 3D or higher.
- const inLens = toList(this.inputLength);
- if (inLens.length !== inputShape.length - 1) {
- throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
- `input shape has shape ${inputShape}`);
- }
- else {
- let i = 0;
- for (let k = 0; k < inLens.length; ++k) {
- const s1 = inLens[k];
- const s2 = inputShape[k + 1];
- if ((s1 != null) && (s2 != null) && (s1 !== s2)) {
- throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
- `input shape has shape ${inputShape}`);
- }
- else if (s1 == null) {
- inLens[i] = s2;
- }
- i++;
- }
- }
- return [inputShape[0], ...inLens, this.outputDim];
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- // Embedding layer accepts only a single input.
- let input = getExactlyOneTensor(inputs);
- if (input.dtype !== 'int32') {
- input = cast$1(input, 'int32');
- }
- const output = gather$1(this.embeddings.read(), input.as1D());
- return output.reshape(getExactlyOneShape(this.computeOutputShape(input.shape)));
- });
- }
- getConfig() {
- const config = {
- inputDim: this.inputDim,
- outputDim: this.outputDim,
- embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
- embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
- activityRegularizer: serializeRegularizer(this.activityRegularizer),
- embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
- maskZero: this.maskZero,
- inputLength: this.inputLength
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Embedding.className = 'Embedding';
- registerClass(Embedding);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Generic Merge layer for element-wise merge functions.
- *
- * Used to implement `Sum`, `Average`, `Concatenate`, etc.
- */
- class Merge extends Layer {
- constructor(args) {
- super(args || {});
- this.supportsMasking = true;
- }
- /**
- * Logic for merging multiple tensors, to be overridden by subclasses.
- * @param inputs
- */
- mergeFunction(inputs) {
- throw new NotImplementedError();
- }
- /**
- * Computes the shape of the result of an elementwise operation.
- *
- * @param shape1: Shape of the first tensor.
- * @param shape2: Shape of the second tensor.
- * @returns Expected output shape when an elementwise operation is carried
- * out on 2 tensors with shapes `shape1` and `shape2`.
- * @throws ValueError: If `shape1` and `shape2` are not compatible for
- * element-wise operations.
- */
- computeElementwiseOpOutputShape(shape1, shape2) {
- if (shape1 == null || shape2 == null) {
- return null;
- }
- else if (shape1.length < shape2.length) {
- return this.computeElementwiseOpOutputShape(shape2, shape1);
- }
- else if (shape2.length === 0) {
- return shape1;
- }
- const outputShape = shape1.slice(0, shape1.length - shape2.length);
- for (let k = 0; k < shape2.length; ++k) {
- const i = shape1[shape1.length - shape2.length + k];
- const j = shape2[k];
- if (i == null || j == null || i < 0 || j < 0) {
- outputShape.push(null);
- }
- else if (i === 1) {
- outputShape.push(j);
- }
- else if (j === 1) {
- outputShape.push(i);
- }
- else {
- if (i !== j) {
- throw new ValueError('Operands could not be broadcast together with shapes ' +
- JSON.stringify(shape1) + ' ' + JSON.stringify(shape2));
- }
- outputShape.push(i);
- }
- }
- return outputShape;
- }
- build(inputShape) {
- // Used purely for shape validation.
- if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
- // Make sure that inputShape is an Array of shape.
- inputShape = [getExactlyOneShape(inputShape)];
- }
- inputShape = inputShape;
- if (inputShape.length < 2) {
- throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' +
- ` Got ${inputShape.length} input(s).`);
- }
- // Make sure that there is at most one unique batch size among the input
- // shapes.
- let batchSizes = [];
- for (const shape of inputShape) {
- if (shape != null && shape[0] !== null) {
- batchSizes.push(shape[0]);
- }
- }
- batchSizes = unique$1(batchSizes);
- if (batchSizes.length > 1) {
- throw new ValueError(`Can not merge tensors with different batch sizes. ` +
- `Got tensors with shapes: ${JSON.stringify(inputShape)}.`);
- }
- let outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
- for (let i = 1; i < inputShape.length; ++i) {
- const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
- outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
- }
- // If the inputs have different ranks, we have to reshape them to make them
- // broadcastable.
- const allRanks = inputShape.map(shape => shape.length);
- if (inputShape.indexOf(null) === -1 &&
- unique$1(allRanks).length === 1) {
- this.reshapeRequired = false;
- }
- else {
- this.reshapeRequired = true;
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- inputs = inputs;
- if (this.reshapeRequired) {
- const reshapedInputs = [];
- const inputDims = inputs.map(input => input.rank);
- if (inputDims.indexOf(null) === -1) {
- // If ranks of all inputs are available, we simply expand each of them
- // at axis=1 until all of them have the same rank.
- const maxNDim = max$1(inputDims);
- for (let x of inputs) {
- const xNDim = x.rank;
- for (let k = 0; k < maxNDim - xNDim; ++k) {
- x = expandDims$1(x, 1);
- }
- reshapedInputs.push(x);
- }
- return this.mergeFunction(reshapedInputs);
- }
- else {
- // Transpose all inputs so that batch size is the last dimension.
- // [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize]
- let transposed = false;
- for (const x of inputs) {
- const xNDim = x.rank;
- if (xNDim == null) {
- const xShape = x.shape;
- const batchSize = xShape[0];
- const newShape = xShape.slice(1).concat([batchSize]);
- let xTransposed = x.reshape([batchSize].concat(arrayProd(xShape.slice(1))));
- xTransposed = transpose(xTransposed, [1, 0]);
- xTransposed = xTransposed.reshape(newShape);
- reshapedInputs.push(xTransposed);
- transposed = true;
- }
- else if (xNDim > 1) {
- const dims = range$1(1, xNDim).concat([0]);
- reshapedInputs.push(transpose(x, dims));
- transposed = true;
- }
- else {
- // We don't transpose inputs if they are 1D vectors or scalars.
- reshapedInputs.push(x);
- }
- }
- let y = this.mergeFunction(reshapedInputs);
- const yNDim = y.rank;
- if (transposed) {
- // If inputs have been transposed, we have to transpose the output
- // too.
- if (yNDim == null) {
- const yShape = y.shape;
- const yNDim = yShape.length;
- const batchSize = yShape[yNDim - 1];
- const newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
- y = transpose(y.reshape([-1, batchSize]), [1, 0])
- .reshape(newShape);
- }
- else if (yNDim > 1) {
- const dims = [yNDim - 1].concat(range$1(0, yNDim - 1));
- y = transpose(y, dims);
- }
- }
- return y;
- }
- }
- else {
- return this.mergeFunction(inputs);
- }
- });
- }
- computeOutputShape(inputShape) {
- inputShape = inputShape;
- let outputShape;
- if (inputShape[0] == null) {
- outputShape = null;
- }
- else {
- outputShape = inputShape[0].slice(1);
- }
- for (let i = 1; i < inputShape.length; ++i) {
- const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
- outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
- }
- let batchSizes = [];
- for (const shape of inputShape) {
- if (shape != null && shape[0] !== null) {
- batchSizes.push(shape[0]);
- }
- }
- batchSizes = unique$1(batchSizes);
- if (batchSizes.length === 1) {
- outputShape = batchSizes.concat(outputShape);
- }
- else {
- outputShape = [null].concat(outputShape);
- }
- return outputShape;
- }
- computeMask(inputs, mask) {
- return tidy(() => {
- if (mask == null) {
- return null;
- }
- if (!Array.isArray(mask)) {
- throw new ValueError('`mask` should be an Array');
- }
- if (!Array.isArray(inputs)) {
- throw new ValueError('`inputs` should be an Array');
- }
- if (mask.length !== inputs.length) {
- throw new ValueError(`The Array 'inputs' and 'mask' are expected to have the same ` +
- `length, but have different lengths ` +
- `(${inputs.length} vs ${mask.length})`);
- }
- if (mask.every(m => m == null)) {
- return null;
- }
- mask = mask.map(m => m == null ? m : expandDims(m, 0));
- let output = mask[0];
- for (let i = 1; i < mask.length - 1; ++i) {
- output = logicalAnd(output, mask[i]);
- }
- return output;
- });
- }
- }
- class Add$1 extends Merge {
- constructor(args) {
- super(args);
- }
- mergeFunction(inputs) {
- return tidy(() => {
- let output = inputs[0].clone();
- for (let i = 1; i < inputs.length; ++i) {
- output = add$1(output, inputs[i]);
- }
- return output;
- });
- }
- }
- /** @nocollapse */
- Add$1.className = 'Add';
- registerClass(Add$1);
- /**
- * Calculate the element-wise sum of inputs, which all have the same shape.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Add` layer, by using no input argument
- * or a single configuration argument. The resultant `Add` layer can then
- * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const addLayer = tf.layers.add();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = addLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = tf.layers.add([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
- * tf.layers.add([input1, input2]).print();
- * // Gives [[11, 22], [33, 44]].
- *
- */
- function add$2(config) {
- if (Array.isArray(config)) {
- const layer = new Add$1({});
- return layer.apply(config);
- }
- else {
- return new Add$1(config);
- }
- }
- class Multiply$1 extends Merge {
- constructor(args) {
- super(args);
- }
- mergeFunction(inputs) {
- return tidy(() => {
- let output = inputs[0].clone();
- for (let i = 1; i < inputs.length; ++i) {
- output = mul(output, inputs[i]);
- }
- return output;
- });
- }
- }
- /** @nocollapse */
- Multiply$1.className = 'Multiply';
- registerClass(Multiply$1);
- /**
- * Calculate the element-wise product of inputs, which all have the same shape.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Multiply` layer, by using no input argument
- * or a single configuration argument. The resultant `Multiply` layer can
- * then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const multiplyLayer = tf.layers.multiply();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = multiplyLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = tf.layers.multiply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
- * tf.layers.multiply([input1, input2]).print();
- * // Gives [[10, 40], [90, 160]].
- *
- */
- function multiply(config) {
- if (Array.isArray(config)) {
- const layer = new Multiply$1({});
- return layer.apply(config);
- }
- else {
- return new Multiply$1(config);
- }
- }
- class Average extends Merge {
- constructor(args) {
- super(args);
- }
- mergeFunction(inputs) {
- return tidy(() => {
- let output = inputs[0].clone();
- for (let i = 1; i < inputs.length; ++i) {
- output = add$1(output, inputs[i]);
- }
- return mul(1 / inputs.length, output);
- });
- }
- }
- /** @nocollapse */
- Average.className = 'Average';
- registerClass(Average);
- /**
- * Calculate the element-wise arithmetic mean of inputs, which all have the same
- * shape.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Average` layer, by using no input argument
- * or a single configuration argument. The resultant `Average` layer can then
- * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const averageLayer = tf.layers.average();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = averageLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = tf.layers.average([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
- * tf.layers.average([input1, input2]).print();
- * // Gives [[5.5, 11], [16.5, 22]].
- *
- */
- function average(config) {
- if (Array.isArray(config)) {
- const layer = new Average({});
- return layer.apply(config);
- }
- else {
- return new Average(config);
- }
- }
- class Maximum$1 extends Merge {
- constructor(args) {
- super(args);
- }
- mergeFunction(inputs) {
- return tidy(() => {
- let output = inputs[0];
- for (let i = 1; i < inputs.length; ++i) {
- output = maximum(output, inputs[i]);
- }
- return output;
- });
- }
- }
- /** @nocollapse */
- Maximum$1.className = 'Maximum';
- registerClass(Maximum$1);
- /**
- * Calculate the element-wise maximum of inputs, which all have the same shape.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Maximum` layer, by using no input argument
- * or a single configuration argument. The resultant `Maximum` layer can then
- * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const maximumLayer = tf.layers.maximum();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = maximumLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = tf.layers.maximum([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
- * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
- * tf.layers.maximum([input1, input2]).print();
- * // Gives [[10, 20], [30, 40]].
- *
- */
- function maximum$1(config) {
- if (Array.isArray(config)) {
- const layer = new Maximum$1({});
- return layer.apply(config);
- }
- else {
- return new Maximum$1(config);
- }
- }
- class Minimum$1 extends Merge {
- constructor(args) {
- super(args);
- }
- mergeFunction(inputs) {
- return tidy(() => {
- let output = inputs[0];
- for (let i = 1; i < inputs.length; ++i) {
- output = minimum(output, inputs[i]);
- }
- return output;
- });
- }
- }
- /** @nocollapse */
- Minimum$1.className = 'Minimum';
- registerClass(Minimum$1);
- /**
- * Calculate the element-wise minimum of inputs, which all have the same shape.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Minimum` layer, by using no input argument
- * or a single configuration argument. The resultant `Minimum` layer can then
- * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const minimumLayer = tf.layers.minimum();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = minimumLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const output = tf.layers.minimum([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
- * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
- * tf.layers.minimum([input1, input2]).print();
- * // Gives [[1, 2], [3, 4]].
- *
- */
- function minimum$1(config) {
- if (Array.isArray(config)) {
- const layer = new Minimum$1({});
- return layer.apply(config);
- }
- else {
- return new Minimum$1(config);
- }
- }
- class Concatenate extends Merge {
- constructor(args) {
- super(args);
- this.DEFAULT_AXIS = -1;
- if (args == null) {
- args = {};
- }
- this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
- this.supportsMasking = true;
- this.reshapeRequired = false;
- }
- build(inputShape) {
- // Used purely for shape validation.]
- if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) ||
- inputShape.length === 1) {
- throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' +
- 'inputs');
- }
- inputShape = inputShape;
- let allNoneShape = true;
- for (const shape of inputShape) {
- if (shape != null) {
- allNoneShape = false;
- break;
- }
- }
- if (allNoneShape) {
- return;
- }
- const shapeSet = [];
- for (let i = 0; i < inputShape.length; ++i) {
- const shapeWithoutConcatAxis = inputShape[i].slice();
- shapeWithoutConcatAxis.splice(this.axis, 1);
- let exists = false;
- for (const shape of shapeSet) {
- if (arraysEqual(shape, shapeWithoutConcatAxis)) {
- exists = true;
- break;
- }
- }
- if (!exists) {
- shapeSet.push(shapeWithoutConcatAxis);
- }
- }
- if (shapeSet.length > 1) {
- throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' +
- 'except for the concat axis. Got input shapes: ' +
- JSON.stringify(inputShape));
- }
- }
- mergeFunction(inputs) {
- return tidy(() => {
- return concatenate(inputs, this.axis);
- });
- }
- computeOutputShape(inputShape) {
- if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
- throw new ValueError('A `Concatenate` layer should be called on a list of inputs.');
- }
- const inputShapes = inputShape;
- const outputShape = inputShapes[0].slice();
- const axis = this.axis < 0 ? outputShape.length + this.axis : this.axis;
- // Porting Note: the line above is because TypeScript doesn't support
- // negative indices.
- for (const shape of inputShapes.slice(1)) {
- if (outputShape[axis] == null || shape[axis] == null) {
- outputShape[axis] = null;
- break;
- }
- outputShape[axis] += shape[axis];
- }
- return outputShape;
- }
- computeMask(inputs, mask) {
- if (mask == null) {
- return null;
- }
- if (!Array.isArray(mask)) {
- throw new ValueError('`mask` should be an array for Concatenate');
- }
- if (!Array.isArray(inputs)) {
- throw new ValueError('`inputs` should be an array for Concatenate');
- }
- if (mask.length !== inputs.length) {
- throw new ValueError(`Mismatch in the length of mask (${mask.length}) ` +
- `and the legnth of inputs (${inputs.length})`);
- }
- return tidy(() => {
- let allNullMasks = true;
- mask.forEach(m => {
- if (m != null) {
- allNullMasks = false;
- return;
- }
- });
- if (allNullMasks) {
- return null;
- }
- const outputMasks = [];
- for (let i = 0; i < inputs.length; ++i) {
- if (mask[i] == null) {
- // Input is unmasked. Append all 1's to masks.
- outputMasks.push(onesLike(inputs[i]).asType('bool'));
- }
- else if (mask[i].rank < inputs[i].rank) {
- // Mask is smaller than the input, expand it.
- outputMasks.push(expandDims(mask[i], -1));
- }
- else {
- outputMasks.push(mask[i]);
- }
- }
- const concatenatedMasks = concat(outputMasks, this.axis);
- return all(concatenatedMasks, -1, false);
- });
- }
- getConfig() {
- const config = {
- 'axis': this.axis,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Concatenate.className = 'Concatenate';
- registerClass(Concatenate);
- /**
- * Concatenate an `Array` of inputs.
- *
- * This function can be invoked in three ways.
- *
- * 1. Construct an instance of `Concatenate` layer, by using no input argument
- * or a single configuration argument. The resultant `Concatenate` layer can
- * then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
- *
- * ```js
- * const concatLayer = tf.layers.concatenate();
- *
- * // The layer can be applied to inputs.
- * const input1 = tf.input({shape: [2, 3]});
- * const input2 = tf.input({shape: [2, 4]});
- * const output = concatLayer.apply([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 7], with the first dimension as the undetermined batch
- * // dimension and the last dimension as the result of concatenating the
- * // last dimensions of the two inputs.
- * ```
- *
- * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.SymbolicTensor`. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 3]});
- * const input2 = tf.input({shape: [2, 4]});
- * const output = tf.layers.concatenate([input1, input2]);
- * console.log(output.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension and the last dimension as the result of concatenating the
- * // last dimensions of the two inputs.
- * ```
- *
- * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
- * an `Layer` object internally and calls its `apply` method on the inputs,
- * generating a new `tf.Tensor` as the result of the computation. For
- * example:
- *
- * ```js
- * const input1 = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
- * const input2 = tf.tensor2d([[10, 20], [30, 40]], [2, 2]);
- * tf.layers.concatenate([input1, input2]).print();
- * // Gives [[1, 2, 10, 20], [3, 4, 30, 40]].
- *
- */
- function concatenate$1(config) {
- if (Array.isArray(config)) {
- const layer = new Concatenate({});
- return layer.apply(config);
- }
- else {
- return new Concatenate(config);
- }
- }
- /**
- * Interpretable potentially negative axis index.
- *
- * For example, given axis = -1, and dim = 3, this function will return 2.
- *
- * @param axis The axis index, may be a positive, zero or negative integer.
- * @param dim Total number of dimensions, a positive integer.
- * @returns A non-negative axis index equivalent to the input `axis`.
- */
- function interpretAxis(axis, dim) {
- while (axis < 0) {
- axis += dim;
- }
- return axis;
- }
- function batchDot(x, y, axes) {
- if (x.shape.length > 3 || y.shape.length > 3) {
- throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet');
- }
- assert(x.shape.length >= 2, () => `batchDot requires the rank of x to be >= 2, ` +
- `but got ${x.shape.length}`);
- assert(x.shape.length >= 2, () => `batchDot requires the rank of y to be >= 2, ` +
- `but got ${y.shape.length}`);
- if (typeof axes === 'number') {
- axes = [axes, axes];
- }
- if (x.dtype === 'complex64' || y.dtype === 'complex64') {
- throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.');
- }
- const xNDim = x.shape.length;
- const yNDim = y.shape.length;
- if (axes == null) {
- // Behave like batchMatmul by default.
- axes = [xNDim - 1, yNDim - 2];
- }
- const axesArray = axes;
- return tidy(() => {
- let diff;
- if (xNDim > yNDim) {
- diff = xNDim - yNDim;
- const diffShape = [];
- for (let i = 0; i < diff; ++i) {
- diffShape.push(1);
- }
- y = y.reshape(y.shape.concat(diffShape));
- }
- else if (yNDim > xNDim) {
- diff = yNDim - xNDim;
- const diffShape = [];
- for (let i = 0; i < diff; ++i) {
- diffShape.push(1);
- }
- x = x.reshape(x.shape.concat(diffShape));
- }
- else {
- diff = 0;
- }
- let out;
- if (x.shape.length === 2 && y.shape.length === 2) {
- if (axesArray[0] === axesArray[1]) {
- out = x.mul(y).sum(axesArray[0]);
- }
- else {
- out = x.transpose([1, 0]).mul(y).sum(axesArray[1]);
- }
- }
- else {
- const adjX = axesArray[0] !== x.shape.length - 1;
- const adjY = axesArray[1] === y.shape.length - 1;
- out = x.matMul(y, adjX, adjY);
- }
- if (diff > 0) {
- let idx;
- if (xNDim > yNDim) {
- idx = xNDim + yNDim - 3;
- }
- else {
- idx = xNDim - 1;
- }
- const squeezeAxes = [];
- for (let i = idx; i < idx + diff; ++i) {
- squeezeAxes.push(i);
- }
- out = out.squeeze(squeezeAxes);
- }
- if (out.shape.length === 1) {
- out = out.expandDims(1);
- }
- return out;
- });
- }
- class Dot extends Merge {
- constructor(args) {
- super(args);
- this.axes = args.axes;
- this.normalize = args.normalize == null ? false : args.normalize;
- this.supportsMasking = true;
- this.reshapeRequired = false;
- }
- build(inputShape) {
- assert(Array.isArray(inputShape) && inputShape.length === 2 &&
- Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
- const shape1 = inputShape[0];
- const shape2 = inputShape[1];
- if (shape1.length > 3 || shape2.length > 3) {
- throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
- }
- const axes = this.interpretAxes(shape1, shape2);
- if (shape1[axes[0]] !== shape2[axes[1]]) {
- throw new ValueError(`Dimension incompatibility: ` +
- `${shape1[axes[0]]} !== ${shape2[axes[1]]}`);
- }
- }
- mergeFunction(inputs) {
- if (inputs.length !== 2) {
- throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' +
- `but received ${inputs.length} input(s).`);
- }
- let x1 = inputs[0];
- let x2 = inputs[1];
- let axes;
- if (!Array.isArray(this.axes)) {
- axes = [
- interpretAxis(this.axes, x1.shape.length),
- interpretAxis(this.axes, x2.shape.length)
- ];
- }
- else {
- axes = this.axes.map((axis, i) => interpretAxis(axis, inputs[i].shape.length));
- }
- if (this.normalize) {
- x1 = l2Normalize(x1, axes[0]);
- x2 = l2Normalize(x2, axes[1]);
- }
- return batchDot(x1, x2, axes);
- }
- interpretAxes(shape1, shape2) {
- let axes;
- if (!Array.isArray(this.axes)) {
- // `this.axes` is a single integer.
- axes = [
- interpretAxis(this.axes, shape1.length),
- interpretAxis(this.axes, shape2.length)
- ];
- }
- else {
- // `this.axes` is an Array of integers.
- axes = this.axes;
- }
- return axes;
- }
- computeOutputShape(inputShape) {
- assert(Array.isArray(inputShape) && inputShape.length === 2 &&
- Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
- const shape1 = inputShape[0].slice();
- const shape2 = inputShape[1].slice();
- if (shape1.length > 3 || shape2.length > 3) {
- throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
- }
- const axes = this.interpretAxes(shape1, shape2);
- shape1.splice(axes[0], 1);
- shape2.splice(axes[1], 1);
- shape2.splice(0, 1);
- const outputShape = shape1.concat(shape2);
- if (outputShape.length === 1) {
- outputShape.push(1);
- }
- return outputShape;
- }
- computeMask(inputs, mask) {
- return null;
- }
- getConfig() {
- const config = {
- 'axes': this.axes,
- 'normalize': this.normalize
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- Dot.className = 'Dot';
- registerClass(Dot);
- // TODO(cais): Add functional interfaces for the merge layers.
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class GaussianNoise extends Layer {
- constructor(args) {
- super(args);
- this.supportsMasking = true;
- this.stddev = args.stddev;
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = { stddev: this.stddev };
- Object.assign(config, baseConfig);
- return config;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- const noised = () => randomNormal$1(input.shape, 0, this.stddev).add(input);
- const output = inTrainPhase(noised, () => input, kwargs['training'] || false);
- return output;
- });
- }
- }
- /** @nocollapse */
- GaussianNoise.className = 'GaussianNoise';
- registerClass(GaussianNoise);
- class GaussianDropout extends Layer {
- constructor(args) {
- super(args);
- this.supportsMasking = true;
- this.rate = args.rate;
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = { rate: this.rate };
- Object.assign(config, baseConfig);
- return config;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- const input = getExactlyOneTensor(inputs);
- if (this.rate > 0 && this.rate < 1) {
- const noised = () => {
- const stddev = Math.sqrt(this.rate / (1 - this.rate));
- return input.mul(randomNormal$1(input.shape, 1, stddev));
- };
- return inTrainPhase(noised, () => input, kwargs['training'] || false);
- }
- return input;
- });
- }
- }
- /** @nocollapse */
- GaussianDropout.className = 'GaussianDropout';
- registerClass(GaussianDropout);
- /**
- * Applies Alpha Dropout to the input.
- *
- * As it is a regularization layer, it is only active at training time.
- *
- * Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
- * to their original values, in order to ensure the self-normalizing property
- * even after this dropout.
- * Alpha Dropout fits well to Scaled Exponential Linear Units
- * by randomly setting activations to the negative saturation value.
- *
- * Arguments:
- * - `rate`: float, drop probability (as with `Dropout`).
- * The multiplicative noise will have
- * standard deviation `sqrt(rate / (1 - rate))`.
- * - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
- * shape for randomly generated keep/drop flags.
- *
- * Input shape:
- * Arbitrary. Use the keyword argument `inputShape`
- * (tuple of integers, does not include the samples axis)
- * when using this layer as the first layer in a model.
- *
- * Output shape:
- * Same shape as input.
- *
- * References:
- * - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- */
- class AlphaDropout extends Layer {
- constructor(args) {
- super(args);
- this.supportsMasking = true;
- this.rate = args.rate;
- this.noiseShape = args.noiseShape;
- }
- _getNoiseShape(inputs) {
- return this.noiseShape || getExactlyOneTensor(inputs).shape;
- }
- computeOutputShape(inputShape) {
- return inputShape;
- }
- getConfig() {
- const baseConfig = super.getConfig();
- const config = { rate: this.rate };
- Object.assign(config, baseConfig);
- return config;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- if (this.rate < 1 && this.rate > 0) {
- const noiseShape = this._getNoiseShape(inputs);
- const droppedInputs = () => {
- const input = getExactlyOneTensor(inputs);
- const alpha = 1.6732632423543772848170429916717;
- const scale = 1.0507009873554804934193349852946;
- const alphaP = -alpha * scale;
- let keptIdx = greaterEqual(randomUniform(noiseShape), this.rate);
- keptIdx = cast$1(keptIdx, 'float32'); // get default dtype.
- // Get affine transformation params.
- const a = ((1 - this.rate) * (1 + this.rate * alphaP ** 2)) ** -0.5;
- const b = -a * alphaP * this.rate;
- // Apply mask.
- const x = input.mul(keptIdx).add(keptIdx.add(-1).mul(alphaP));
- return x.mul(a).add(b);
- };
- return inTrainPhase(droppedInputs, () => getExactlyOneTensor(inputs), kwargs['training'] || false);
- }
- return inputs;
- });
- }
- }
- /** @nocollapse */
- AlphaDropout.className = 'AlphaDropout';
- registerClass(AlphaDropout);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Applies batch normalization on x given mean, var, beta and gamma.
- *
- * I.e. returns:
- * `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
- *
- * @param x Input tensor.
- * @param mean Mean of batch.
- * @param variance Variance of batch.
- * @param beta Tensor with which to center the input.
- * @param gamma Tensor by which to scale the input.
- * @param epsilon Fuzz factor.
- * @returns The result of the batch normalization.
- */
- function batchNormalization(x, mean, variance, beta, gamma, epsilon = 1e-3) {
- let out;
- if (x.rank === 2) {
- out = batchNorm2d(x, mean, variance, beta, gamma, epsilon);
- }
- else if (x.rank === 3) {
- // TODO(cais): Check rank; give proper error message.
- out = batchNorm3d(x, mean, variance, beta, gamma, epsilon);
- }
- else if (x.rank === 4) {
- out = batchNorm4d(x, mean, variance, beta, gamma, epsilon);
- }
- else {
- throw new NotImplementedError(`batchNormalization is not implemented for array of rank ${x.rank} ` +
- `yet`);
- }
- return out;
- }
- /**
- * Non-broadcasting batch normalization for use in training (not inference).
- *
- * The input is normalized to zero mean and unit variance along the
- * `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`.
- * The result of that is returned as the first element
- * of the returned `Array`. The other two elements are the mean and variance,
- * respectively.
- *
- * @param x Input tensor to be normalized.
- * @param gamma Tensor by which to scale the input.
- * @param beta Tensor by which to center the input.
- * @param reductionAxes Axes over which to normalize.
- * @param epsilon Fuzz factor.
- * @returns An `Array` of three `Tensors`:
- * [normalized tensor, mean of input, variance of input].
- */
- function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
- return tidy(() => {
- const meanAndVariance = moments(x, reductionAxes);
- const mean = meanAndVariance.mean;
- const variance = meanAndVariance.variance;
- const normed = batchNormalization(x, mean, variance, beta, gamma, epsilon);
- return [normed, mean, variance];
- });
- }
- /**
- * Broadcasting batch normalization for use in training (not inference).
- *
- * The input is normalized to zero mean and unit variance along the
- * `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`.
- * The result of that is returned as the first element
- * of the returned `Array`. The other two elements are the mean and variance,
- * respectively.
- *
- * @param x Input tensor to be normalized.
- * @param gamma Tensor by which to scale the input.
- * @param beta Tensor by which to center the input.
- * @param reductionAxes Axes over which to normalize.
- * @param epsilon Fuzz factor.
- * @returns An `Array` of three `Tensors`:
- * [normalized tensor, mean of input, variance of input].
- */
- function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
- return tidy(() => {
- const meanAndVariance = moments(x, reductionAxes);
- const mean = meanAndVariance.mean;
- const variance = meanAndVariance.variance;
- const targetShape = [];
- for (const axis of range$1(0, x.rank)) {
- if (reductionAxes.indexOf(axis) !== -1) {
- targetShape.push(1);
- }
- else {
- targetShape.push(x.shape[axis]);
- }
- }
- const broadcastMean = mean.reshape(targetShape);
- const broadcastVariance = variance.reshape(targetShape);
- const broadcastGamma = gamma == null ? null : gamma.reshape(targetShape);
- const broadcastBeta = beta == null ? null : beta.reshape(targetShape);
- const normed = batchNormalization(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, epsilon);
- return [normed, mean, variance];
- });
- }
- /**
- * Batch normalization for use in training (not inference).
- *
- * @param x Input tensor to be normalized.
- * @param gamma Tensor by which to scale the input.
- * @param beta Tensor by which to center the input.
- * @param reductionAxes Axes over which to normalize.
- * @param epsilon Fuzz factor.
- * @returns An `Array` of three `Tensors`:
- * [normalized tensor, mean of input, variance of input].
- */
- function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
- if (arraysEqual(reductionAxes.slice().sort(), range$1(0, x.rank - 1))) {
- return regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
- }
- else {
- return broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
- }
- }
- class BatchNormalization extends Layer {
- constructor(args) {
- if (args == null) {
- args = {};
- }
- super(args);
- this.supportsMasking = true;
- this.axis = args.axis == null ? -1 : args.axis;
- this.momentum = args.momentum == null ? 0.99 : args.momentum;
- this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
- this.center = args.center == null ? true : args.center;
- this.scale = args.scale == null ? true : args.scale;
- this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
- this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
- this.movingMeanInitializer =
- getInitializer(args.movingMeanInitializer || 'zeros');
- this.movingVarianceInitializer =
- getInitializer(args.movingVarianceInitializer || 'ones');
- this.betaConstraint = getConstraint(args.betaConstraint);
- this.gammaConstraint = getConstraint(args.gammaConstraint);
- this.betaRegularizer = getRegularizer(args.betaRegularizer);
- this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const axis = this.axis >= 0 ? this.axis : (this.axis + inputShape.length);
- const dim = inputShape[axis];
- if (dim == null) {
- throw new ValueError(`Axis ${axis} of input tensor should have a defined dimension but ` +
- `the layer received an input with shape ` +
- `${JSON.stringify(inputShape)}.`);
- }
- this.inputSpec =
- [new InputSpec({ ndim: inputShape.length, axes: { [axis]: dim } })];
- const shape = [dim];
- if (this.scale) {
- this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
- }
- if (this.center) {
- this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
- }
- this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false);
- this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false);
- this.built = true;
- }
- call(inputs, kwargs) {
- return tidy(() => {
- const training = kwargs['training'] == null ? false : kwargs['training'];
- const input = getExactlyOneTensor(inputs);
- const inputShape = input.shape;
- const ndim = inputShape.length;
- const reductionAxes = range$1(0, ndim);
- const axis = this.axis >= 0 ? this.axis : (this.axis + ndim);
- reductionAxes.splice(axis, 1);
- const broadcastShape = pyListRepeat(1, ndim);
- broadcastShape[axis] = inputShape[axis];
- const sortedReductionAxes = reductionAxes.slice();
- sortedReductionAxes.sort();
- const needsBroadcasting = !arraysEqual(sortedReductionAxes, range$1(0, ndim).slice(0, ndim - 1));
- const normalizeInference = () => {
- if (needsBroadcasting) {
- const broadcastMovingMean = this.movingMean.read().reshape(broadcastShape);
- const broadcastMovingVariance = this.movingVariance.read().reshape(broadcastShape);
- const broadcastBeta = this.center ? this.beta.read().reshape(broadcastShape) : null;
- const broadcastGamma = this.scale ? this.gamma.read().reshape(broadcastShape) : null;
- return batchNormalization(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, this.epsilon);
- }
- else {
- return batchNormalization(input, this.movingMean.read(), this.movingVariance.read(), this.beta == null ? null : this.beta.read(), this.gamma == null ? null : this.gamma.read(), this.epsilon);
- }
- };
- if (!training) {
- return normalizeInference();
- }
- const [normedTraining, mean, variance] = normalizeBatchInTraining(input, this.gamma.read(), this.beta.read(), reductionAxes, this.epsilon);
- const doMovingAverage = (variable, value, momentum) => {
- tidy(() => {
- const decay = 1 - momentum;
- const origValue = variable.read();
- const updateDelta = origValue.sub(value).mul(decay);
- variable.write(origValue.sub(updateDelta));
- });
- };
- // Perform updates to moving mean and moving variance for training.
- // Porting Note: In PyKeras, these updates to `movingMean` and
- // `movingAverage` are done as a deferred Graph, added to the `Layer`'s
- // `update`s using the `add_update()` method. Here we do it imperatively
- // and encapsulate the updates in a function that is invoked
- // immediately.
- const updateMovingMeanAndVariance = () => {
- doMovingAverage(this.movingMean, mean, this.momentum);
- doMovingAverage(this.movingVariance, variance, this.momentum);
- };
- updateMovingMeanAndVariance();
- return normedTraining;
- });
- }
- getConfig() {
- const config = {
- axis: this.axis,
- momentum: this.momentum,
- epsilon: this.epsilon,
- center: this.center,
- scale: this.scale,
- betaInitializer: serializeInitializer(this.betaInitializer),
- gammaInitializer: serializeInitializer(this.gammaInitializer),
- movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
- movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
- betaRegularizer: serializeRegularizer(this.betaRegularizer),
- gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
- betaConstraint: serializeConstraint(this.betaConstraint),
- gammaConstraint: serializeConstraint(this.gammaConstraint)
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- BatchNormalization.className = 'BatchNormalization';
- registerClass(BatchNormalization);
- class LayerNormalization extends Layer {
- constructor(args) {
- if (args == null) {
- args = {};
- }
- super(args);
- this.axis = args.axis == null ? -1 : args.axis;
- if (typeof this.axis === 'number') {
- if (!Number.isInteger(this.axis)) {
- throw new Error(`Expected axis to be an integer, but received ${this.axis}`);
- }
- }
- else if (Array.isArray(this.axis)) {
- for (const axis of this.axis) {
- if (!Number.isInteger(axis)) {
- throw new Error(`Expected axis to be an array of integers, ` +
- `but received ${JSON.stringify(this.axis)}`);
- }
- }
- }
- else {
- throw new Error(`Expected axis to be an integer or an array of integers, ` +
- `but received ${JSON.stringify(this.axis)}`);
- }
- this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
- this.center = args.center == null ? true : args.center;
- this.scale = args.scale == null ? true : args.scale;
- this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
- this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
- this.betaRegularizer = getRegularizer(args.betaRegularizer);
- this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
- this.supportsMasking = true;
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const nDims = inputShape.length;
- // Convert axis to array and resolve negatives.
- if (typeof this.axis === 'number') {
- this.axis = [this.axis];
- }
- for (let i = 0; i < this.axis.length; ++i) {
- if (this.axis[i] < 0) {
- this.axis[i] += nDims;
- }
- }
- // Further validate axes.
- for (const axis of this.axis) {
- if (axis < 0 || axis >= nDims) {
- throw new Error(`Invalid axis: ${axis}`);
- }
- }
- if (this.axis.length !== unique$1(this.axis).length) {
- throw new Error(`Found duplicate axes in: ${this.axis}`);
- }
- const paramShape = this.axis.map(axis => inputShape[axis]);
- const trainable = true;
- if (this.scale) {
- this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable);
- }
- else {
- this.gamma = null;
- }
- if (this.center) {
- this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable);
- }
- else {
- this.beta = null;
- }
- this.built = true;
- }
- call(inputs, kwargs) {
- const input = getExactlyOneTensor(inputs);
- const inputShape = input.shape;
- const nDims = inputShape.length;
- return tidy(() => {
- const keepDims = true;
- let { mean, variance } = moments(input, this.axis, keepDims);
- const broadcastShape = pyListRepeat(1, nDims);
- for (const dim of this.axis) {
- broadcastShape[dim] = inputShape[dim];
- }
- const broadcast = (v) => {
- if (v != null && v.shape.length !== nDims &&
- this.axis !== [nDims - 1]) {
- return v.reshape(broadcastShape);
- }
- else {
- return v;
- }
- };
- let scale = broadcast(this.gamma.read());
- let offset = broadcast(this.beta.read());
- // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling below
- // is a workaround for the limitation of core's batchNormalization?d don't
- // support broadcasting in their gradients. In addition, the tiling is
- // necessary to ensure correctness on the browser CPU backend regardless
- // of forward or backward computation. Remove this workaround once the
- // limitation is addressed. See .
- const momentsTiling = [];
- const scaleOffsetTiling = [];
- for (let i = 0; i < nDims; ++i) {
- if (this.axis.indexOf(i) !== -1) {
- momentsTiling.push(inputShape[i]);
- scaleOffsetTiling.push(1);
- }
- else {
- momentsTiling.push(1);
- scaleOffsetTiling.push(inputShape[i]);
- }
- }
- mean = mean.tile(momentsTiling);
- variance = variance.tile(momentsTiling);
- scale = scale.tile(scaleOffsetTiling);
- offset = offset.tile(scaleOffsetTiling);
- return batchNormalization(input, mean, variance, offset, scale, this.epsilon);
- });
- }
- getConfig() {
- const config = {
- axis: this.axis,
- epsilon: this.epsilon,
- center: this.center,
- scale: this.scale,
- betaInitializer: serializeInitializer(this.betaInitializer),
- gammaInitializer: serializeInitializer(this.gammaInitializer),
- betaRegularizer: serializeRegularizer(this.betaRegularizer),
- gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- LayerNormalization.className = 'LayerNormalization';
- registerClass(LayerNormalization);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Pads the middle dimension of a 3D tensor.
- *
- * @param x Input `tf.Tensor` to be padded.
- * @param padding `Array` of 2 integers, how many zeros to add at the start and
- * end of the middle dimension (i.e., dimension 1).
- * @return A padded 3D `tf.Tensor`.
- */
- function temporalPadding(x, padding) {
- return tidy(() => {
- if (x.rank !== 3) {
- throw new ValueError(`temporalPadding expects input tensor to be 3-D, but received a ` +
- `${x.rank}-D tensor.`);
- }
- if (padding == null) {
- padding = [1, 1];
- }
- if (padding.length !== 2) {
- throw new ValueError(`temporalPadding expects input padding pattern to be a length-2 ` +
- `array, but received a length-${padding.length} array.`);
- }
- const pattern = [[0, 0], padding, [0, 0]];
- return pad(x, pattern);
- });
- }
- /**
- * Pads the 2nd and 3rd dimensions of a 4D tensor.
- *
- * @param x Input `tf.Tensor` to be padded.
- * @param padding `Array` of two `Array`s, each of which is an `Array` of two
- * integers. The amount of padding at the beginning and end of the 2nd and 3rd
- * dimensions, respectively.
- * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
- * @return Padded 4D `tf.Tensor`.
- */
- function spatial2dPadding(x, padding, dataFormat) {
- return tidy(() => {
- if (x.rank !== 4) {
- throw new ValueError(`temporalPadding expects input tensor to be 4-D, but received a ` +
- `${x.rank}-D tensor.`);
- }
- if (padding == null) {
- padding = [[1, 1], [1, 1]];
- }
- if (padding.length !== 2 || padding[0].length !== 2 ||
- padding[1].length !== 2) {
- throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' +
- 'each of which is an Array of two integers.');
- }
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
- throw new ValueError(`Unknown data format: ${dataFormat}. ` +
- `Supported data formats are 'channelsLast' and 'channelsFirst.`);
- }
- let pattern;
- if (dataFormat === 'channelsFirst') {
- pattern = [[0, 0], [0, 0], padding[0], padding[1]];
- }
- else {
- pattern = [[0, 0], padding[0], padding[1], [0, 0]];
- }
- return pad(x, pattern);
- });
- }
- class ZeroPadding2D extends Layer {
- constructor(args) {
- if (args == null) {
- args = {};
- }
- super(args);
- this.dataFormat =
- args.dataFormat == null ? imageDataFormat() : args.dataFormat;
- // TODO(cais): Maybe refactor the following logic surrounding `padding`
- // into a helper method.
- if (args.padding == null) {
- this.padding = [[1, 1], [1, 1]];
- }
- else if (typeof args.padding === 'number') {
- this.padding =
- [[args.padding, args.padding], [args.padding, args.padding]];
- }
- else {
- args.padding = args.padding;
- if (args.padding.length !== 2) {
- throw new ValueError(`ZeroPadding2D expects padding to be a length-2 array, but ` +
- `received a length-${args.padding.length} array.`);
- }
- let heightPadding;
- let widthPadding;
- if (typeof args.padding[0] === 'number') {
- heightPadding = [args.padding[0], args.padding[0]];
- widthPadding = [args.padding[1], args.padding[1]];
- }
- else {
- args.padding = args.padding;
- if (args.padding[0].length !== 2) {
- throw new ValueError(`ZeroPadding2D expects height padding to be a length-2 array, ` +
- `but received a length-${args.padding[0].length} array.`);
- }
- heightPadding = args.padding[0];
- if (args.padding[1].length !== 2) {
- throw new ValueError(`ZeroPadding2D expects width padding to be a length-2 array, ` +
- `but received a length-${args.padding[1].length} array.`);
- }
- widthPadding = args.padding[1];
- }
- this.padding = [heightPadding, widthPadding];
- }
- this.inputSpec = [new InputSpec({ ndim: 4 })];
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- let rows;
- let cols;
- if (this.dataFormat === 'channelsFirst') {
- if (inputShape[2] != null && inputShape[2] >= 0) {
- rows = inputShape[2] + this.padding[0][0] + this.padding[0][1];
- }
- else {
- rows = null;
- }
- if (inputShape[3] != null && inputShape[3] >= 0) {
- cols = inputShape[3] + this.padding[1][0] + this.padding[1][1];
- }
- else {
- cols = null;
- }
- return [inputShape[0], inputShape[1], rows, cols];
- }
- else {
- if (inputShape[1] != null && inputShape[1] >= 0) {
- rows = inputShape[1] + this.padding[0][0] + this.padding[0][1];
- }
- else {
- rows = null;
- }
- if (inputShape[2] != null && inputShape[2] >= 0) {
- cols = inputShape[2] + this.padding[1][0] + this.padding[1][1];
- }
- else {
- cols = null;
- }
- return [inputShape[0], rows, cols, inputShape[3]];
- }
- }
- call(inputs, kwargs) {
- return tidy(() => spatial2dPadding(getExactlyOneTensor(inputs), this.padding, this.dataFormat));
- }
- getConfig() {
- const config = {
- padding: this.padding,
- dataFormat: this.dataFormat,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- /** @nocollapse */
- ZeroPadding2D.className = 'ZeroPadding2D';
- registerClass(ZeroPadding2D);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * 2D pooling.
- * @param x
- * @param poolSize
- * @param stridesdes strides. Defaults to [1, 1].
- * @param padding padding. Defaults to 'valid'.
- * @param dataFormat data format. Defaults to 'channelsLast'.
- * @param poolMode Mode of pooling. Defaults to 'max'.
- * @returns Result of the 2D pooling.
- */
- function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- checkPoolMode(poolMode);
- checkPaddingMode(padding);
- if (strides == null) {
- strides = [1, 1];
- }
- if (padding == null) {
- padding = 'valid';
- }
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- if (poolMode == null) {
- poolMode = 'max';
- }
- // TODO(cais): Remove the preprocessing step once deeplearn.js supports
- // dataFormat as an input argument.
- x = preprocessConv2DInput(x, dataFormat); // x is NHWC after preprocessing.
- let y;
- const paddingString = (padding === 'same') ? 'same' : 'valid';
- if (poolMode === 'max') {
- // TODO(cais): Rank check?
- y = maxPool(x, poolSize, strides, paddingString);
- }
- else { // 'avg'
- // TODO(cais): Check the dtype and rank of x and give clear error message
- // if those are incorrect.
- y = avgPool(
- // TODO(cais): Rank check?
- x, poolSize, strides, paddingString);
- }
- if (dataFormat === 'channelsFirst') {
- y = transpose(y, [0, 3, 1, 2]); // NHWC -> NCHW.
- }
- return y;
- });
- }
- /**
- * 3D pooling.
- * @param x
- * @param poolSize. Default to [1, 1, 1].
- * @param strides strides. Defaults to [1, 1, 1].
- * @param padding padding. Defaults to 'valid'.
- * @param dataFormat data format. Defaults to 'channelsLast'.
- * @param poolMode Mode of pooling. Defaults to 'max'.
- * @returns Result of the 3D pooling.
- */
- function pool3d(x, poolSize, strides, padding, dataFormat, poolMode) {
- return tidy(() => {
- checkDataFormat(dataFormat);
- checkPoolMode(poolMode);
- checkPaddingMode(padding);
- if (strides == null) {
- strides = [1, 1, 1];
- }
- if (padding == null) {
- padding = 'valid';
- }
- if (dataFormat == null) {
- dataFormat = imageDataFormat();
- }
- if (poolMode == null) {
- poolMode = 'max';
- }
- // x is NDHWC after preprocessing.
- x = preprocessConv3DInput(x, dataFormat);
- let y;
- const paddingString = (padding === 'same') ? 'same' : 'valid';
- if (poolMode === 'max') {
- y = maxPool3d(x, poolSize, strides, paddingString);
- }
- else { // 'avg'
- y = avgPool3d(x, poolSize, strides, paddingString);
- }
- if (dataFormat === 'channelsFirst') {
- y = transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
- }
- return y;
- });
- }
- /**
- * Abstract class for different pooling 1D layers.
- */
- class Pooling1D extends Layer {
- /**
- *
- * @param args Parameters for the Pooling layer.
- *
- * config.poolSize defaults to 2.
- */
- constructor(args) {
- if (args.poolSize == null) {
- args.poolSize = 2;
- }
- super(args);
- if (typeof args.poolSize === 'number') {
- this.poolSize = [args.poolSize];
- }
- else if (Array.isArray(args.poolSize) &&
- args.poolSize.length === 1 &&
- typeof args.poolSize[0] === 'number') {
- this.poolSize = args.poolSize;
- }
- else {
- throw new ValueError(`poolSize for 1D convolutional layer must be a number or an ` +
- `Array of a single number, but received ` +
- `${JSON.stringify(args.poolSize)}`);
- }
- assertPositiveInteger(this.poolSize, 'poolSize');
- if (args.strides == null) {
- this.strides = this.poolSize;
- }
- else {
- if (typeof args.strides === 'number') {
- this.strides = [args.strides];
- }
- else if (Array.isArray(args.strides) &&
- args.strides.length === 1 &&
- typeof args.strides[0] === 'number') {
- this.strides = args.strides;
- }
- else {
- throw new ValueError(`strides for 1D convolutional layer must be a number or an ` +
- `Array of a single number, but received ` +
- `${JSON.stringify(args.strides)}`);
- }
- }
- assertPositiveInteger(this.strides, 'strides');
- this.padding = args.padding == null ? 'valid' : args.padding;
- checkPaddingMode(this.padding);
- this.inputSpec = [new InputSpec({ ndim: 3 })];
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
- return [inputShape[0], length, inputShape[2]];
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- // Add dummy last dimension.
- inputs = expandDims$1(getExactlyOneTensor(inputs), 2);
- const output = this.poolingFunction(getExactlyOneTensor(inputs), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, 'channelsLast');
- // Remove dummy last dimension.
- return squeeze(output, [2]);
- });
- }
- getConfig() {
- const config = {
- poolSize: this.poolSize,
- padding: this.padding,
- strides: this.strides,
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- class MaxPooling1D extends Pooling1D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
- }
- }
- /** @nocollapse */
- MaxPooling1D.className = 'MaxPooling1D';
- registerClass(MaxPooling1D);
- class AveragePooling1D extends Pooling1D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
- }
- }
- /** @nocollapse */
- AveragePooling1D.className = 'AveragePooling1D';
- registerClass(AveragePooling1D);
- /**
- * Abstract class for different pooling 2D layers.
- */
- class Pooling2D extends Layer {
- constructor(args) {
- if (args.poolSize == null) {
- args.poolSize = [2, 2];
- }
- super(args);
- this.poolSize = Array.isArray(args.poolSize) ?
- args.poolSize :
- [args.poolSize, args.poolSize];
- if (args.strides == null) {
- this.strides = this.poolSize;
- }
- else if (Array.isArray(args.strides)) {
- if (args.strides.length !== 2) {
- throw new ValueError(`If the strides property of a 2D pooling layer is an Array, ` +
- `it is expected to have a length of 2, but received length ` +
- `${args.strides.length}.`);
- }
- this.strides = args.strides;
- }
- else {
- // `config.strides` is a number.
- this.strides = [args.strides, args.strides];
- }
- assertPositiveInteger(this.poolSize, 'poolSize');
- assertPositiveInteger(this.strides, 'strides');
- this.padding = args.padding == null ? 'valid' : args.padding;
- this.dataFormat =
- args.dataFormat == null ? 'channelsLast' : args.dataFormat;
- checkDataFormat(this.dataFormat);
- checkPaddingMode(this.padding);
- this.inputSpec = [new InputSpec({ ndim: 4 })];
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- let rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
- let cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
- rows =
- convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
- cols =
- convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
- if (this.dataFormat === 'channelsFirst') {
- return [inputShape[0], inputShape[1], rows, cols];
- }
- else {
- return [inputShape[0], rows, cols, inputShape[3]];
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
- });
- }
- getConfig() {
- const config = {
- poolSize: this.poolSize,
- padding: this.padding,
- strides: this.strides,
- dataFormat: this.dataFormat
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- class MaxPooling2D extends Pooling2D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
- }
- }
- /** @nocollapse */
- MaxPooling2D.className = 'MaxPooling2D';
- registerClass(MaxPooling2D);
- class AveragePooling2D extends Pooling2D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
- }
- }
- /** @nocollapse */
- AveragePooling2D.className = 'AveragePooling2D';
- registerClass(AveragePooling2D);
- /**
- * Abstract class for different pooling 3D layers.
- */
- class Pooling3D extends Layer {
- constructor(args) {
- if (args.poolSize == null) {
- args.poolSize = [2, 2, 2];
- }
- super(args);
- this.poolSize = Array.isArray(args.poolSize) ?
- args.poolSize :
- [args.poolSize, args.poolSize, args.poolSize];
- if (args.strides == null) {
- this.strides = this.poolSize;
- }
- else if (Array.isArray(args.strides)) {
- if (args.strides.length !== 3) {
- throw new ValueError(`If the strides property of a 3D pooling layer is an Array, ` +
- `it is expected to have a length of 3, but received length ` +
- `${args.strides.length}.`);
- }
- this.strides = args.strides;
- }
- else {
- // `config.strides` is a number.
- this.strides = [args.strides, args.strides, args.strides];
- }
- assertPositiveInteger(this.poolSize, 'poolSize');
- assertPositiveInteger(this.strides, 'strides');
- this.padding = args.padding == null ? 'valid' : args.padding;
- this.dataFormat =
- args.dataFormat == null ? 'channelsLast' : args.dataFormat;
- checkDataFormat(this.dataFormat);
- checkPaddingMode(this.padding);
- this.inputSpec = [new InputSpec({ ndim: 5 })];
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- let depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
- let rows = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
- let cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
- depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
- rows =
- convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
- cols =
- convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
- if (this.dataFormat === 'channelsFirst') {
- return [inputShape[0], inputShape[1], depths, rows, cols];
- }
- else {
- return [inputShape[0], depths, rows, cols, inputShape[4]];
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- this.invokeCallHook(inputs, kwargs);
- return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
- });
- }
- getConfig() {
- const config = {
- poolSize: this.poolSize,
- padding: this.padding,
- strides: this.strides,
- dataFormat: this.dataFormat
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- class MaxPooling3D extends Pooling3D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool3d(inputs, poolSize, strides, padding, dataFormat, 'max');
- }
- }
- /** @nocollapse */
- MaxPooling3D.className = 'MaxPooling3D';
- registerClass(MaxPooling3D);
- class AveragePooling3D extends Pooling3D {
- constructor(args) {
- super(args);
- }
- poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
- checkDataFormat(dataFormat);
- checkPaddingMode(padding);
- return pool3d(inputs, poolSize, strides, padding, dataFormat, 'avg');
- }
- }
- /** @nocollapse */
- AveragePooling3D.className = 'AveragePooling3D';
- registerClass(AveragePooling3D);
- /**
- * Abstract class for different global pooling 1D layers.
- */
- class GlobalPooling1D extends Layer {
- constructor(args) {
- super(args);
- this.inputSpec = [new InputSpec({ ndim: 3 })];
- }
- computeOutputShape(inputShape) {
- return [inputShape[0], inputShape[2]];
- }
- call(inputs, kwargs) {
- throw new NotImplementedError();
- }
- }
- class GlobalAveragePooling1D extends GlobalPooling1D {
- constructor(args) {
- super(args || {});
- }
- call(inputs, kwargs) {
- return tidy(() => {
- const input = getExactlyOneTensor(inputs);
- return mean(input, 1);
- });
- }
- }
- /** @nocollapse */
- GlobalAveragePooling1D.className = 'GlobalAveragePooling1D';
- registerClass(GlobalAveragePooling1D);
- class GlobalMaxPooling1D extends GlobalPooling1D {
- constructor(args) {
- super(args || {});
- }
- call(inputs, kwargs) {
- return tidy(() => {
- const input = getExactlyOneTensor(inputs);
- return max(input, 1);
- });
- }
- }
- /** @nocollapse */
- GlobalMaxPooling1D.className = 'GlobalMaxPooling1D';
- registerClass(GlobalMaxPooling1D);
- /**
- * Abstract class for different global pooling 2D layers.
- */
- class GlobalPooling2D extends Layer {
- constructor(args) {
- super(args);
- this.dataFormat =
- args.dataFormat == null ? 'channelsLast' : args.dataFormat;
- checkDataFormat(this.dataFormat);
- this.inputSpec = [new InputSpec({ ndim: 4 })];
- }
- computeOutputShape(inputShape) {
- inputShape = inputShape;
- if (this.dataFormat === 'channelsLast') {
- return [inputShape[0], inputShape[3]];
- }
- else {
- return [inputShape[0], inputShape[1]];
- }
- }
- call(inputs, kwargs) {
- throw new NotImplementedError();
- }
- getConfig() {
- const config = { dataFormat: this.dataFormat };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- }
- class GlobalAveragePooling2D extends GlobalPooling2D {
- call(inputs, kwargs) {
- return tidy(() => {
- const input = getExactlyOneTensor(inputs);
- if (this.dataFormat === 'channelsLast') {
- return mean(input, [1, 2]);
- }
- else {
- return mean(input, [2, 3]);
- }
- });
- }
- }
- /** @nocollapse */
- GlobalAveragePooling2D.className = 'GlobalAveragePooling2D';
- registerClass(GlobalAveragePooling2D);
- class GlobalMaxPooling2D extends GlobalPooling2D {
- call(inputs, kwargs) {
- return tidy(() => {
- const input = getExactlyOneTensor(inputs);
- if (this.dataFormat === 'channelsLast') {
- return max(input, [1, 2]);
- }
- else {
- return max(input, [2, 3]);
- }
- });
- }
- }
- /** @nocollapse */
- GlobalMaxPooling2D.className = 'GlobalMaxPooling2D';
- registerClass(GlobalMaxPooling2D);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Abstract wrapper base class.
- *
- * Wrappers take another layer and augment it in various ways.
- * Do not use this class as a layer, it is only an abstract base class.
- * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
- */
- class Wrapper extends Layer {
- constructor(args) {
- // Porting Note: In PyKeras, `self.layer` is set prior to the calling
- // `super()`. But we can't do that here due to TypeScript's restriction.
- // See: https://github.com/Microsoft/TypeScript/issues/8277
- // As a result, we have to add checks in `get trainable()` and
- // `set trainable()` below in order to prevent using `this.layer` when
- // its value is `undefined`. The super constructor does use the getter
- // and the setter of `this.layer`.
- super(args);
- this.layer = args.layer;
- }
- build(inputShape) {
- this.built = true;
- }
- // TODO(cais): Implement activityRegularizer getter.
- get trainable() {
- // Porting Note: the check of `this.layer` here is necessary due to the
- // way the `constructor` of this class is written (see Porting Note
- // above).
- if (this.layer != null) {
- return this.layer.trainable;
- }
- else {
- return false;
- }
- }
- set trainable(value) {
- // Porting Note: the check of `this.layer` here is necessary due to the
- // way the `constructor` of this class is written (see Porting Note
- // above).
- if (this.layer != null) {
- this.layer.trainable = value;
- }
- }
- get trainableWeights() {
- return this.layer.trainableWeights;
- }
- // TODO(cais): Implement setter for trainableWeights.
- get nonTrainableWeights() {
- return this.layer.nonTrainableWeights;
- }
- // TODO(cais): Implement setter for nonTrainableWeights.
- get updates() {
- // tslint:disable-next-line:no-any
- return this.layer._updates;
- }
- // TODO(cais): Implement getUpdatesFor().
- get losses() {
- return this.layer.losses;
- }
- // TODO(cais): Implement getLossesFor().
- getWeights() {
- return this.layer.getWeights();
- }
- setWeights(weights) {
- this.layer.setWeights(weights);
- }
- getConfig() {
- const config = {
- 'layer': {
- 'className': this.layer.getClassName(),
- 'config': this.layer.getConfig(),
- }
- };
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- setFastWeightInitDuringBuild(value) {
- super.setFastWeightInitDuringBuild(value);
- if (this.layer != null) {
- this.layer.setFastWeightInitDuringBuild(value);
- }
- }
- /** @nocollapse */
- static fromConfig(cls, config, customObjects = {}) {
- const layerConfig = config['layer'];
- const layer = deserialize(layerConfig, customObjects);
- delete config['layer'];
- const newConfig = { layer };
- Object.assign(newConfig, config);
- return new cls(newConfig);
- }
- }
- class TimeDistributed extends Wrapper {
- constructor(args) {
- super(args);
- this.supportsMasking = true;
- }
- build(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- if (inputShape.length < 3) {
- throw new ValueError(`TimeDistributed layer expects an input shape >= 3D, but received ` +
- `input shape ${JSON.stringify(inputShape)}`);
- }
- this.inputSpec = [{ shape: inputShape }];
- const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
- if (!this.layer.built) {
- this.layer.build(childInputShape);
- this.layer.built = true;
- }
- super.build(inputShape);
- }
- computeOutputShape(inputShape) {
- inputShape = getExactlyOneShape(inputShape);
- const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
- const childOutputShape = this.layer.computeOutputShape(childInputShape);
- const timesteps = inputShape[1];
- return [childOutputShape[0], timesteps].concat(childOutputShape.slice(1));
- }
- call(inputs, kwargs) {
- return tidy(() => {
- // TODO(cais): Add 'training' and 'useLearningPhase' to kwargs.
- inputs = getExactlyOneTensor(inputs);
- // Porting Note: In tfjs-layers, `inputs` are always concrete tensor
- // values. Hence the inputs can't have an undetermined first (batch)
- // dimension, which is why we always use the K.rnn approach here.
- const step = (inputs, states) => {
- // TODO(cais): Add useLearningPhase.
- // NOTE(cais): `layer.call` may return a length-1 array of Tensor in
- // some cases (e.g., `layer` is a `Sequential` instance), which is
- // why `getExactlyOneTensor` is used below.
- const output = getExactlyOneTensor(this.layer.call(inputs, kwargs));
- return [output, []];
- };
- const rnnOutputs = rnn(step, inputs, [], false /* goBackwards */, null /* mask */, null /* constants */, false /* unroll */, true /* needPerStepOutputs */);
- const y = rnnOutputs[1];
- // TODO(cais): Add activity regularization.
- // TODO(cais): Add useLearningPhase.
- return y;
- });
- }
- }
- /** @nocollapse */
- TimeDistributed.className = 'TimeDistributed';
- registerClass(TimeDistributed);
- function checkBidirectionalMergeMode(value) {
- checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, 'BidirectionalMergeMode', value);
- }
- const DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
- class Bidirectional extends Wrapper {
- constructor(args) {
- super(args);
- // Note: When creating `this.forwardLayer`, the original Layer object
- // (`config.layer`) ought to be cloned. This is why we call
- // `getConfig()` followed by `deserialize()`. Without this cloning,
- // the layer names saved during serialization will incorrectly contain
- // the 'forward_' prefix. In Python Keras, this is done using
- // `copy.copy` (shallow copy), which does not have a simple equivalent
- // in JavaScript. JavaScript's `Object.assign()` does not copy
- // methods.
- const layerConfig = args.layer.getConfig();
- const forwDict = {};
- forwDict['className'] = args.layer.getClassName();
- forwDict['config'] = layerConfig;
- this.forwardLayer = deserialize(forwDict);
- layerConfig['goBackwards'] =
- layerConfig['goBackwards'] === true ? false : true;
- const backDict = {};
- backDict['className'] = args.layer.getClassName();
- backDict['config'] = layerConfig;
- this.backwardLayer = deserialize(backDict);
- this.forwardLayer.name = 'forward_' + this.forwardLayer.name;
- this.backwardLayer.name = 'backward_' + this.backwardLayer.name;
- this.mergeMode = args.mergeMode === undefined ?
- DEFAULT_BIDIRECTIONAL_MERGE_MODE :
- args.mergeMode;
- checkBidirectionalMergeMode(this.mergeMode);
- if (args.weights) {
- throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
- }
- this._stateful = args.layer.stateful;
- this.returnSequences = args.layer.returnSequences;
- this.returnState = args.layer.returnState;
- this.supportsMasking = true;
- this._trainable = true;
- this.inputSpec = args.layer.inputSpec;
- this.numConstants = null;
- }
- get trainable() {
- return this._trainable;
- }
- set trainable(value) {
- // Porting Note: the check of `this.layer` here is necessary due to the
- // way the `constructor` of this class is written (see Porting Note
- // above).
- this._trainable = value;
- if (this.forwardLayer != null) {
- this.forwardLayer.trainable = value;
- }
- if (this.backwardLayer != null) {
- this.backwardLayer.trainable = value;
- }
- }
- getWeights() {
- return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
- }
- setWeights(weights) {
- const numWeights = weights.length;
- const numeightsOver2 = Math.floor(numWeights / 2);
- this.forwardLayer.setWeights(weights.slice(0, numeightsOver2));
- this.backwardLayer.setWeights(weights.slice(numeightsOver2));
- }
- computeOutputShape(inputShape) {
- let layerShapes = this.forwardLayer.computeOutputShape(inputShape);
- if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
- layerShapes = [layerShapes];
- }
- layerShapes = layerShapes;
- let outputShape;
- let outputShapes;
- let stateShape;
- if (this.returnState) {
- stateShape = layerShapes.slice(1);
- outputShape = layerShapes[0];
- }
- else {
- outputShape = layerShapes[0];
- }
- outputShape = outputShape;
- if (this.mergeMode === 'concat') {
- outputShape[outputShape.length - 1] *= 2;
- outputShapes = [outputShape];
- }
- else if (this.mergeMode == null) {
- outputShapes = [outputShape, outputShape.slice()];
- }
- else {
- outputShapes = [outputShape];
- }
- if (this.returnState) {
- if (this.mergeMode == null) {
- return outputShapes.concat(stateShape).concat(stateShape.slice());
- }
- return [outputShape].concat(stateShape).concat(stateShape.slice());
- }
- return singletonOrArray(outputShapes);
- }
- apply(inputs, kwargs) {
- let initialState = kwargs == null ? null : kwargs['initialState'];
- let constants = kwargs == null ? null : kwargs['constants'];
- if (kwargs == null) {
- kwargs = {};
- }
- const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
- inputs = standardized.inputs;
- initialState = standardized.initialState;
- constants = standardized.constants;
- if (Array.isArray(inputs)) {
- initialState = inputs.slice(1);
- inputs = inputs[0];
- }
- if ((initialState == null || initialState.length === 0) &&
- constants == null) {
- return super.apply(inputs, kwargs);
- }
- const additionalInputs = [];
- const additionalSpecs = [];
- if (initialState != null) {
- const numStates = initialState.length;
- if (numStates % 2 > 0) {
- throw new ValueError('When passing `initialState` to a Bidrectional RNN, ' +
- 'the state should be an Array containing the states of ' +
- 'the underlying RNNs.');
- }
- kwargs['initialState'] = initialState;
- additionalInputs.push(...initialState);
- const stateSpecs = initialState
- .map(state => new InputSpec({ shape: state.shape }));
- this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
- this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
- additionalSpecs.push(...stateSpecs);
- }
- if (constants != null) {
- throw new NotImplementedError('Support for constants in Bidirectional layers is not ' +
- 'implemented yet.');
- }
- const isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
- for (const tensor of additionalInputs) {
- if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
- throw new ValueError('The initial state of a Bidirectional layer cannot be ' +
- 'specified as a mix of symbolic and non-symbolic tensors');
- }
- }
- if (isSymbolicTensor) {
- // Compute the full input and specs, including the states.
- const fullInput = [inputs].concat(additionalInputs);
- const fullInputSpec = this.inputSpec.concat(additionalSpecs);
- // Perform the call temporarily and replace inputSpec.
- // Note: with initial states symbolic calls and non-symbolic calls to
- // this method differ in how the initial states are passed. For
- // symbolic calls, the initial states are passed in the first arg, as
- // an Array of SymbolicTensors; for non-symbolic calls, they are
- // passed in the second arg as a part of the kwargs. Hence the need to
- // temporarily modify inputSpec here.
- // TODO(cais): Make refactoring so that this hacky code below is no
- // longer needed.
- const originalInputSpec = this.inputSpec;
- this.inputSpec = fullInputSpec;
- const output = super.apply(fullInput, kwargs);
- this.inputSpec = originalInputSpec;
- return output;
- }
- else {
- return super.apply(inputs, kwargs);
- }
- }
- call(inputs, kwargs) {
- return tidy(() => {
- const initialState = kwargs['initialState'];
- let y;
- let yRev;
- if (initialState == null) {
- y = this.forwardLayer.call(inputs, kwargs);
- yRev = this.backwardLayer.call(inputs, kwargs);
- }
- else {
- const forwardState = initialState.slice(0, initialState.length / 2);
- const backwardState = initialState.slice(initialState.length / 2);
- y = this.forwardLayer.call(inputs, Object.assign(kwargs, { initialState: forwardState }));
- yRev = this.backwardLayer.call(inputs, Object.assign(kwargs, { initialState: backwardState }));
- }
- let states;
- if (this.returnState) {
- if (Array.isArray(y)) {
- states = y.slice(1).concat(yRev.slice(1));
- }
- else {
- }
- y = y[0];
- yRev = yRev[0];
- }
- if (this.returnSequences) {
- yRev = reverse(yRev, 1);
- }
- let output;
- if (this.mergeMode === 'concat') {
- output = concatenate([y, yRev]);
- }
- else if (this.mergeMode === 'sum') {
- output = add$1(y, yRev);
- }
- else if (this.mergeMode === 'ave') {
- output = mul(.5, add$1(y, yRev));
- }
- else if (this.mergeMode === 'mul') {
- output = mul(y, yRev);
- }
- else if (this.mergeMode == null) {
- output = [y, yRev];
- }
- // TODO(cais): Properly set learning phase.
- if (this.returnState) {
- if (this.mergeMode == null) {
- return output.concat(states);
- }
- return [output].concat(states);
- }
- return output;
- });
- }
- resetStates(states) {
- this.forwardLayer.resetStates();
- this.backwardLayer.resetStates();
- }
- build(inputShape) {
- nameScope(this.forwardLayer.name, () => {
- this.forwardLayer.build(inputShape);
- });
- nameScope(this.backwardLayer.name, () => {
- this.backwardLayer.build(inputShape);
- });
- this.built = true;
- }
- computeMask(inputs, mask) {
- if (Array.isArray(mask)) {
- mask = mask[0];
- }
- let outputMask;
- if (this.returnSequences) {
- if (this.mergeMode == null) {
- outputMask = [mask, mask];
- }
- else {
- outputMask = mask;
- }
- }
- else {
- if (this.mergeMode == null) {
- outputMask = [null, null];
- }
- else {
- outputMask = null;
- }
- }
- if (this.returnState) {
- const states = this.forwardLayer.states;
- const stateMask = states.map(state => null);
- if (Array.isArray(outputMask)) {
- return outputMask.concat(stateMask).concat(stateMask);
- }
- else {
- return [outputMask].concat(stateMask).concat(stateMask);
- }
- }
- else {
- return outputMask;
- }
- }
- get trainableWeights() {
- return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
- }
- get nonTrainableWeights() {
- return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
- }
- // TODO(cais): Implement constraints().
- setFastWeightInitDuringBuild(value) {
- super.setFastWeightInitDuringBuild(value);
- if (this.forwardLayer != null) {
- this.forwardLayer.setFastWeightInitDuringBuild(value);
- }
- if (this.backwardLayer != null) {
- this.backwardLayer.setFastWeightInitDuringBuild(value);
- }
- }
- getConfig() {
- const config = {
- 'mergeMode': this.mergeMode,
- };
- // TODO(cais): Add logic for `numConstants` once the property is added.
- const baseConfig = super.getConfig();
- Object.assign(config, baseConfig);
- return config;
- }
- /** @nocollapse */
- static fromConfig(cls, config) {
- const rnnLayer = deserialize(config['layer']);
- delete config['layer'];
- // TODO(cais): Add logic for `numConstants` once the property is added.
- if (config['numConstants'] != null) {
- throw new NotImplementedError(`Deserialization of a Bidirectional layer with numConstants ` +
- `present is not supported yet.`);
- }
- // tslint:disable-next-line:no-any
- const newConfig = config;
- newConfig['layer'] = rnnLayer;
- return new cls(newConfig);
- }
- }
- /** @nocollapse */
- Bidirectional.className = 'Bidirectional';
- registerClass(Bidirectional);
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- // TODO(cais): Add doc string to all the public static functions in this
- // class; include exectuable JavaScript code snippets where applicable
- // (b/74074458).
- // Input Layer.
- /**
- * An input layer is an entry point into a `tf.LayersModel`.
- *
- * `InputLayer` is generated automatically for `tf.Sequential`` models by
- * specifying the `inputshape` or `batchInputShape` for the first layer. It
- * should not be specified explicitly. However, it can be useful sometimes,
- * e.g., when constructing a sequential model from a subset of another
- * sequential model's layers. Like the code snippet below shows.
- *
- * ```js
- * // Define a model which simply adds two inputs.
- * const model1 = tf.sequential();
- * model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'}));
- * model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
- * model1.summary();
- * model1.predict(tf.zeros([1, 4])).print();
- *
- * // Construct another model, reusing the second layer of `model1` while
- * // not using the first layer of `model1`. Note that you cannot add the second
- * // layer of `model` directly as the first layer of the new sequential model,
- * // because doing so will lead to an error related to the fact that the layer
- * // is not an input layer. Instead, you need to create an `inputLayer` and add
- * // it to the new sequential model before adding the reused layer.
- * const model2 = tf.sequential();
- * // Use an inputShape that matches the input shape of `model1`'s second
- * // layer.
- * model2.add(tf.layers.inputLayer({inputShape: [3]}));
- * model2.add(model1.layers[1]);
- * model2.summary();
- * model2.predict(tf.zeros([1, 3])).print();
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'}
- */
- function inputLayer(args) {
- return new InputLayer(args);
- }
- // Advanced Activation Layers.
- /**
- * Exponetial Linear Unit (ELU).
- *
- * It follows:
- * `f(x) = alpha * (exp(x) - 1.) for x < 0`,
- * `f(x) = x for x >= 0`.
- *
- * Input shape:
- * Arbitrary. Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * References:
- * - [Fast and Accurate Deep Network Learning by Exponential Linear Units
- * (ELUs)](https://arxiv.org/abs/1511.07289v1)
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function elu$2(args) {
- return new ELU(args);
- }
- /**
- * Rectified Linear Unit activation function.
- *
- * Input shape:
- * Arbitrary. Use the config field `inputShape` (Array of integers, does
- * not include the sample axis) when using this layer as the first layer
- * in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function reLU(args) {
- return new ReLU(args);
- }
- /**
- * Leaky version of a rectified linear unit.
- *
- * It allows a small gradient when the unit is not active:
- * `f(x) = alpha * x for x < 0.`
- * `f(x) = x for x >= 0.`
- *
- * Input shape:
- * Arbitrary. Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function leakyReLU(args) {
- return new LeakyReLU(args);
- }
- /**
- * Parameterized version of a leaky rectified linear unit.
- *
- * It follows
- * `f(x) = alpha * x for x < 0.`
- * `f(x) = x for x >= 0.`
- * wherein `alpha` is a trainable weight.
- *
- * Input shape:
- * Arbitrary. Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function prelu$1(args) {
- return new PReLU(args);
- }
- /**
- * Softmax activation layer.
- *
- * Input shape:
- * Arbitrary. Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function softmax$1(args) {
- return new Softmax$2(args);
- }
- /**
- * Thresholded Rectified Linear Unit.
- *
- * It follows:
- * `f(x) = x for x > theta`,
- * `f(x) = 0 otherwise`.
- *
- * Input shape:
- * Arbitrary. Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- * Output shape:
- * Same shape as the input.
- *
- * References:
- * - [Zero-Bias Autoencoders and the Benefits of Co-Adapting
- * Features](http://arxiv.org/abs/1402.3337)
- *
- * @doc {
- * heading: 'Layers',
- * subheading: 'Advanced Activation',
- * namespace: 'layers'
- * }
- */
- function thresholdedReLU(args) {
- return new ThresholdedReLU(args);
- }
- // Convolutional Layers.
- /**
- * 1D convolution layer (e.g., temporal convolution).
- *
- * This layer creates a convolution kernel that is convolved
- * with the layer input over a single spatial (or temporal) dimension
- * to produce a tensor of outputs.
- *
- * If `use_bias` is True, a bias vector is created and added to the outputs.
- *
- * If `activation` is not `null`, it is applied to the outputs as well.
- *
- * When using this layer as the first layer in a model, provide an
- * `inputShape` argument `Array` or `null`.
- *
- * For example, `inputShape` would be:
- * - `[10, 128]` for sequences of 10 vectors of 128-dimensional vectors
- * - `[null, 128]` for variable-length sequences of 128-dimensional vectors.
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function conv1d$2(args) {
- return new Conv1D(args);
- }
- /**
- * 2D convolution layer (e.g. spatial convolution over images).
- *
- * This layer creates a convolution kernel that is convolved
- * with the layer input to produce a tensor of outputs.
- *
- * If `useBias` is True, a bias vector is created and added to the outputs.
- *
- * If `activation` is not `null`, it is applied to the outputs as well.
- *
- * When using this layer as the first layer in a model,
- * provide the keyword argument `inputShape`
- * (Array of integers, does not include the sample axis),
- * e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures
- * in `dataFormat='channelsLast'`.
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function conv2d$3(args) {
- return new Conv2D$1(args);
- }
- /**
- * Transposed convolutional layer (sometimes called Deconvolution).
- *
- * The need for transposed convolutions generally arises
- * from the desire to use a transformation going in the opposite direction of
- * a normal convolution, i.e., from something that has the shape of the output
- * of some convolution to something that has the shape of its input while
- * maintaining a connectivity pattern that is compatible with said
- * convolution.
- *
- * When using this layer as the first layer in a model, provide the
- * configuration `inputShape` (`Array` of integers, does not include the
- * sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in
- * `dataFormat: 'channelsLast'`.
- *
- * Input shape:
- * 4D tensor with shape:
- * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`.
- * or 4D tensor with shape
- * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast`.
- *
- * Output shape:
- * 4D tensor with shape:
- * `[batch, filters, newRows, newCols]` if `dataFormat` is
- * `'channelsFirst'`. or 4D tensor with shape:
- * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
- *
- * References:
- * - [A guide to convolution arithmetic for deep
- * learning](https://arxiv.org/abs/1603.07285v1)
- * - [Deconvolutional
- * Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf)
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function conv2dTranspose$1(args) {
- return new Conv2DTranspose(args);
- }
- /**
- * 3D convolution layer (e.g. spatial convolution over volumes).
- *
- * This layer creates a convolution kernel that is convolved
- * with the layer input to produce a tensor of outputs.
- *
- * If `useBias` is True, a bias vector is created and added to the outputs.
- *
- * If `activation` is not `null`, it is applied to the outputs as well.
- *
- * When using this layer as the first layer in a model,
- * provide the keyword argument `inputShape`
- * (Array of integers, does not include the sample axis),
- * e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes
- * in `dataFormat='channelsLast'`.
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function conv3d$2(args) {
- return new Conv3D$1(args);
- }
- /**
- * Depthwise separable 2D convolution.
- *
- * Separable convolution consists of first performing
- * a depthwise spatial convolution
- * (which acts on each input channel separately)
- * followed by a pointwise convolution which mixes together the resulting
- * output channels. The `depthMultiplier` argument controls how many
- * output channels are generated per input channel in the depthwise step.
- *
- * Intuitively, separable convolutions can be understood as
- * a way to factorize a convolution kernel into two smaller kernels,
- * or as an extreme version of an Inception block.
- *
- * Input shape:
- * 4D tensor with shape:
- * `[batch, channels, rows, cols]` if data_format='channelsFirst'
- * or 4D tensor with shape:
- * `[batch, rows, cols, channels]` if data_format='channelsLast'.
- *
- * Output shape:
- * 4D tensor with shape:
- * `[batch, filters, newRows, newCols]` if data_format='channelsFirst'
- * or 4D tensor with shape:
- * `[batch, newRows, newCols, filters]` if data_format='channelsLast'.
- * `rows` and `cols` values might have changed due to padding.
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function separableConv2d$1(args) {
- return new SeparableConv2D(args);
- }
- /**
- * Cropping layer for 2D input (e.g., image).
- *
- * This layer can crop an input
- * at the top, bottom, left and right side of an image tensor.
- *
- * Input shape:
- * 4D tensor with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, rows, cols, channels]`
- * - If `data_format` is `"channels_first"`:
- * `[batch, channels, rows, cols]`.
- *
- * Output shape:
- * 4D with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, croppedRows, croppedCols, channels]`
- * - If `dataFormat` is `"channelsFirst"`:
- * `[batch, channels, croppedRows, croppedCols]`.
- *
- * Examples
- * ```js
- *
- * const model = tf.sequential();
- * model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]],
- * inputShape: [128, 128, 3]}));
- * //now output shape is [batch, 124, 124, 3]
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function cropping2D(args) {
- return new Cropping2D(args);
- }
- /**
- * Upsampling layer for 2D inputs.
- *
- * Repeats the rows and columns of the data
- * by size[0] and size[1] respectively.
- *
- *
- * Input shape:
- * 4D tensor with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, rows, cols, channels]`
- * - If `dataFormat` is `"channelsFirst"`:
- * `[batch, channels, rows, cols]`
- *
- * Output shape:
- * 4D tensor with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, upsampledRows, upsampledCols, channels]`
- * - If `dataFormat` is `"channelsFirst"`:
- * `[batch, channels, upsampledRows, upsampledCols]`
- *
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function upSampling2d(args) {
- return new UpSampling2D(args);
- }
- // Convolutional(depthwise) Layers.
- /**
- * Depthwise separable 2D convolution.
- *
- * Depthwise Separable convolutions consists in performing just the first step
- * in a depthwise spatial convolution (which acts on each input channel
- * separately). The `depthMultplier` argument controls how many output channels
- * are generated per input channel in the depthwise step.
- *
- * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
- */
- function depthwiseConv2d$3(args) {
- return new DepthwiseConv2D(args);
- }
- // Basic Layers.
- /**
- * Applies an activation function to an output.
- *
- * This layer applies element-wise activation function. Other layers, notably
- * `dense` can also apply activation functions. Use this isolated activation
- * function to extract the values before and after the
- * activation. For instance:
- *
- * ```js
- * const input = tf.input({shape: [5]});
- * const denseLayer = tf.layers.dense({units: 1});
- * const activationLayer = tf.layers.activation({activation: 'relu6'});
- *
- * // Obtain the output symbolic tensors by applying the layers in order.
- * const denseOutput = denseLayer.apply(input);
- * const activationOutput = activationLayer.apply(denseOutput);
- *
- * // Create the model based on the inputs.
- * const model = tf.model({
- * inputs: input,
- * outputs: [denseOutput, activationOutput]
- * });
- *
- * // Collect both outputs and print separately.
- * const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5]));
- * denseOut.print();
- * activationOut.print();
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function activation(args) {
- return new Activation$1(args);
- }
- /**
- * Creates a dense (fully connected) layer.
- *
- * This layer implements the operation:
- * `output = activation(dot(input, kernel) + bias)`
- *
- * `activation` is the element-wise activation function
- * passed as the `activation` argument.
- *
- * `kernel` is a weights matrix created by the layer.
- *
- * `bias` is a bias vector created by the layer (only applicable if `useBias`
- * is `true`).
- *
- * **Input shape:**
- *
- * nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`.
- *
- * The most common situation would be
- * a 2D input with shape `(batchSize, inputDim)`.
- *
- * **Output shape:**
- *
- * nD tensor with shape: `(batchSize, ..., units)`.
- *
- * For instance, for a 2D input with shape `(batchSize, inputDim)`,
- * the output would have shape `(batchSize, units)`.
- *
- * Note: if the input to the layer has a rank greater than 2, then it is
- * flattened prior to the initial dot product with the kernel.
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function dense(args) {
- return new Dense(args);
- }
- /**
- * Applies
- * [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to
- * the input.
- *
- * Dropout consists in randomly setting a fraction `rate` of input units to 0 at
- * each update during training time, which helps prevent overfitting.
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function dropout$2(args) {
- return new Dropout(args);
- }
- /**
- * Spatial 1D version of Dropout.
- *
- * This Layer type performs the same function as the Dropout layer, but it drops
- * entire 1D feature maps instead of individual elements. For example, if an
- * input example consists of 3 timesteps and the feature map for each timestep
- * has a size of 4, a `spatialDropout1d` layer may zero out the feature maps
- * of the 1st timesteps and 2nd timesteps completely while sparing all feature
- * elements of the 3rd timestep.
- *
- * If adjacent frames (timesteps) are strongly correlated (as is normally the
- * case in early convolution layers), regular dropout will not regularize the
- * activation and will otherwise just result in merely an effective learning
- * rate decrease. In this case, `spatialDropout1d` will help promote
- * independence among feature maps and should be used instead.
- *
- * **Arguments:**
- * rate: A floating-point number >=0 and <=1. Fraction of the input elements
- * to drop.
- *
- * **Input shape:**
- * 3D tensor with shape `(samples, timesteps, channels)`.
- *
- * **Output shape:**
- * Same as the input shape.
- *
- * References:
- * - [Efficient Object Localization Using Convolutional
- * Networks](https://arxiv.org/abs/1411.4280)
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function spatialDropout1d(args) {
- return new SpatialDropout1D(args);
- }
- /**
- * Flattens the input. Does not affect the batch size.
- *
- * A `Flatten` layer flattens each batch in its inputs to 1D (making the output
- * 2D).
- *
- * For example:
- *
- * ```js
- * const input = tf.input({shape: [4, 3]});
- * const flattenLayer = tf.layers.flatten();
- * // Inspect the inferred output shape of the flatten layer, which
- * // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the
- * // flattening. (The 1st dimension is the undermined batch size.)
- * console.log(JSON.stringify(flattenLayer.apply(input).shape));
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function flatten$2(args) {
- return new Flatten(args);
- }
- /**
- * Repeats the input n times in a new dimension.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.repeatVector({n: 4, inputShape: [2]}));
- * const x = tf.tensor2d([[10, 20]]);
- * // Use the model to do inference on a data point the model hasn't see
- * model.predict(x).print();
- * // output shape is now [batch, 2, 4]
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function repeatVector(args) {
- return new RepeatVector(args);
- }
- /**
- * Reshapes an input to a certain shape.
- *
- * ```js
- * const input = tf.input({shape: [4, 3]});
- * const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]});
- * // Inspect the inferred output shape of the Reshape layer, which
- * // equals `[null, 2, 6]`. (The 1st dimension is the undermined batch size.)
- * console.log(JSON.stringify(reshapeLayer.apply(input).shape));
- * ```
- *
- * Input shape:
- * Arbitrary, although all dimensions in the input shape must be fixed.
- * Use the configuration `inputShape` when using this layer as the
- * first layer in a model.
- *
- *
- * Output shape:
- * [batchSize, targetShape[0], targetShape[1], ...,
- * targetShape[targetShape.length - 1]].
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function reshape$1(args) {
- return new Reshape$1(args);
- }
- /**
- * Permutes the dimensions of the input according to a given pattern.
- *
- * Useful for, e.g., connecting RNNs and convnets together.
- *
- * Example:
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.permute({
- * dims: [2, 1],
- * inputShape: [10, 64]
- * }));
- * console.log(model.outputShape);
- * // Now model's output shape is [null, 64, 10], where null is the
- * // unpermuted sample (batch) dimension.
- * ```
- *
- * Input shape:
- * Arbitrary. Use the configuration field `inputShape` when using this
- * layer as the first layer in a model.
- *
- * Output shape:
- * Same rank as the input shape, but with the dimensions re-ordered (i.e.,
- * permuted) according to the `dims` configuration of this layer.
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function permute(args) {
- return new Permute(args);
- }
- /**
- * Maps positive integers (indices) into dense vectors of fixed size.
- * eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
- *
- * **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`.
- *
- * **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength,
- * outputDim]`.
- *
- * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
- */
- function embedding(args) {
- return new Embedding(args);
- }
- // Merge Layers.
- /**
- * Layer that performs element-wise addition on an `Array` of inputs.
- *
- * It takes as input a list of tensors, all of the same shape, and returns a
- * single tensor (also of the same shape). The inputs are specified as an
- * `Array` when the `apply` method of the `Add` layer instance is called. For
- * example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const addLayer = tf.layers.add();
- * const sum = addLayer.apply([input1, input2]);
- * console.log(JSON.stringify(sum.shape));
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function add$3(args) {
- return new Add$1(args);
- }
- /**
- * Layer that performs element-wise averaging on an `Array` of inputs.
- *
- * It takes as input a list of tensors, all of the same shape, and returns a
- * single tensor (also of the same shape). For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const averageLayer = tf.layers.average();
- * const average = averageLayer.apply([input1, input2]);
- * console.log(JSON.stringify(average.shape));
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function average$1(args) {
- return new Average(args);
- }
- /**
- * Layer that concatenates an `Array` of inputs.
- *
- * It takes a list of tensors, all of the same shape except for the
- * concatenation axis, and returns a single tensor, the concatenation
- * of all inputs. For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 3]});
- * const concatLayer = tf.layers.concatenate();
- * const output = concatLayer.apply([input1, input2]);
- * console.log(JSON.stringify(output.shape));
- * // You get [null, 2, 5], with the first dimension as the undetermined batch
- * // dimension. The last dimension (5) is the result of concatenating the
- * // last dimensions of the inputs (2 and 3).
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function concatenate$2(args) {
- return new Concatenate(args);
- }
- /**
- * Layer that computes the element-wise maximum an `Array` of inputs.
- *
- * It takes as input a list of tensors, all of the same shape and returns a
- * single tensor (also of the same shape). For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const maxLayer = tf.layers.maximum();
- * const max = maxLayer.apply([input1, input2]);
- * console.log(JSON.stringify(max.shape));
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function maximum$2(args) {
- return new Maximum$1(args);
- }
- /**
- * Layer that computes the element-wise minimum of an `Array` of inputs.
- *
- * It takes as input a list of tensors, all of the same shape and returns a
- * single tensor (also of the same shape). For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const minLayer = tf.layers.minimum();
- * const min = minLayer.apply([input1, input2]);
- * console.log(JSON.stringify(min.shape));
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function minimum$2(args) {
- return new Minimum$1(args);
- }
- /**
- * Layer that multiplies (element-wise) an `Array` of inputs.
- *
- * It takes as input an Array of tensors, all of the same
- * shape, and returns a single tensor (also of the same shape).
- * For example:
- *
- * ```js
- * const input1 = tf.input({shape: [2, 2]});
- * const input2 = tf.input({shape: [2, 2]});
- * const input3 = tf.input({shape: [2, 2]});
- * const multiplyLayer = tf.layers.multiply();
- * const product = multiplyLayer.apply([input1, input2, input3]);
- * console.log(product.shape);
- * // You get [null, 2, 2], with the first dimension as the undetermined batch
- * // dimension.
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function multiply$1(args) {
- return new Multiply$1(args);
- }
- /**
- * Layer that computes a dot product between samples in two tensors.
- *
- * E.g., if applied to a list of two tensors `a` and `b` both of shape
- * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
- * where each entry at index `[i, 0]` will be the dot product between
- * `a[i, :]` and `b[i, :]`.
- *
- * Example:
- *
- * ```js
- * const dotLayer = tf.layers.dot({axes: -1});
- * const x1 = tf.tensor2d([[10, 20], [30, 40]]);
- * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
- *
- * // Invoke the layer's apply() method in eager (imperative) mode.
- * const y = dotLayer.apply([x1, x2]);
- * y.print();
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
- */
- function dot$2(args) {
- return new Dot(args);
- }
- // Normalization Layers.
- /**
- * Batch normalization layer (Ioffe and Szegedy, 2014).
- *
- * Normalize the activations of the previous layer at each batch,
- * i.e. applies a transformation that maintains the mean activation
- * close to 0 and the activation standard deviation close to 1.
- *
- * Input shape:
- * Arbitrary. Use the keyword argument `inputShape` (Array of integers, does
- * not include the sample axis) when calling the constructor of this class,
- * if this layer is used as a first layer in a model.
- *
- * Output shape:
- * Same shape as input.
- *
- * References:
- * - [Batch Normalization: Accelerating Deep Network Training by Reducing
- * Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
- *
- * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
- */
- function batchNormalization$1(args) {
- return new BatchNormalization(args);
- }
- /**
- * Layer-normalization layer (Ba et al., 2016).
- *
- * Normalizes the activations of the previous layer for each given example in a
- * batch independently, instead of across a batch like in `batchNormalization`.
- * In other words, this layer applies a transformation that maintanis the mean
- * activation within each example close to0 and activation variance close to 1.
- *
- * Input shape:
- * Arbitrary. Use the argument `inputShape` when using this layer as the first
- * layer in a model.
- *
- * Output shape:
- * Same as input.
- *
- * References:
- * - [Layer Normalization](https://arxiv.org/abs/1607.06450)
- *
- * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
- */
- function layerNormalization(args) {
- return new LayerNormalization(args);
- }
- // Padding Layers.
- /**
- * Zero-padding layer for 2D input (e.g., image).
- *
- * This layer can add rows and columns of zeros
- * at the top, bottom, left and right side of an image tensor.
- *
- * Input shape:
- * 4D tensor with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, rows, cols, channels]`
- * - If `data_format` is `"channels_first"`:
- * `[batch, channels, rows, cols]`.
- *
- * Output shape:
- * 4D with shape:
- * - If `dataFormat` is `"channelsLast"`:
- * `[batch, paddedRows, paddedCols, channels]`
- * - If `dataFormat` is `"channelsFirst"`:
- * `[batch, channels, paddedRows, paddedCols]`.
- *
- * @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'}
- */
- function zeroPadding2d(args) {
- return new ZeroPadding2D(args);
- }
- // Pooling Layers.
- /**
- * Average pooling operation for spatial data.
- *
- * Input shape: `[batchSize, inLength, channels]`
- *
- * Output shape: `[batchSize, pooledLength, channels]`
- *
- * `tf.avgPool1d` is an alias.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function averagePooling1d(args) {
- return new AveragePooling1D(args);
- }
- function avgPool1d(args) {
- return averagePooling1d(args);
- }
- // For backwards compatibility.
- // See https://github.com/tensorflow/tfjs/issues/152
- function avgPooling1d(args) {
- return averagePooling1d(args);
- }
- /**
- * Average pooling operation for spatial data.
- *
- * Input shape:
- * - If `dataFormat === CHANNEL_LAST`:
- * 4D tensor with shape:
- * `[batchSize, rows, cols, channels]`
- * - If `dataFormat === CHANNEL_FIRST`:
- * 4D tensor with shape:
- * `[batchSize, channels, rows, cols]`
- *
- * Output shape
- * - If `dataFormat === CHANNEL_LAST`:
- * 4D tensor with shape:
- * `[batchSize, pooleRows, pooledCols, channels]`
- * - If `dataFormat === CHANNEL_FIRST`:
- * 4D tensor with shape:
- * `[batchSize, channels, pooleRows, pooledCols]`
- *
- * `tf.avgPool2d` is an alias.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function averagePooling2d(args) {
- return new AveragePooling2D(args);
- }
- function avgPool2d(args) {
- return averagePooling2d(args);
- }
- // For backwards compatibility.
- // See https://github.com/tensorflow/tfjs/issues/152
- function avgPooling2d(args) {
- return averagePooling2d(args);
- }
- /**
- * Average pooling operation for 3D data.
- *
- * Input shape
- * - If `dataFormat === channelsLast`:
- * 5D tensor with shape:
- * `[batchSize, depths, rows, cols, channels]`
- * - If `dataFormat === channelsFirst`:
- * 4D tensor with shape:
- * `[batchSize, channels, depths, rows, cols]`
- *
- * Output shape
- * - If `dataFormat=channelsLast`:
- * 5D tensor with shape:
- * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
- * - If `dataFormat=channelsFirst`:
- * 5D tensor with shape:
- * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function averagePooling3d(args) {
- return new AveragePooling3D(args);
- }
- function avgPool3d$1(args) {
- return averagePooling3d(args);
- }
- // For backwards compatibility.
- // See https://github.com/tensorflow/tfjs/issues/152
- function avgPooling3d(args) {
- return averagePooling3d(args);
- }
- /**
- * Global average pooling operation for temporal data.
- *
- * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
- *
- * Output Shape:2D tensor with shape: `[batchSize, features]`.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function globalAveragePooling1d(args) {
- return new GlobalAveragePooling1D(args);
- }
- /**
- * Global average pooling operation for spatial data.
- *
- * Input shape:
- * - If `dataFormat` is `CHANNEL_LAST`:
- * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
- * - If `dataFormat` is `CHANNEL_FIRST`:
- * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
- *
- * Output shape:
- * 2D tensor with shape: `[batchSize, channels]`.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function globalAveragePooling2d(args) {
- return new GlobalAveragePooling2D(args);
- }
- /**
- * Global max pooling operation for temporal data.
- *
- * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
- *
- * Output Shape:2D tensor with shape: `[batchSize, features]`.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function globalMaxPooling1d(args) {
- return new GlobalMaxPooling1D(args);
- }
- /**
- * Global max pooling operation for spatial data.
- *
- * Input shape:
- * - If `dataFormat` is `CHANNEL_LAST`:
- * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
- * - If `dataFormat` is `CHANNEL_FIRST`:
- * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
- *
- * Output shape:
- * 2D tensor with shape: `[batchSize, channels]`.
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function globalMaxPooling2d(args) {
- return new GlobalMaxPooling2D(args);
- }
- /**
- * Max pooling operation for temporal data.
- *
- * Input shape: `[batchSize, inLength, channels]`
- *
- * Output shape: `[batchSize, pooledLength, channels]`
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function maxPooling1d(args) {
- return new MaxPooling1D(args);
- }
- /**
- * Max pooling operation for spatial data.
- *
- * Input shape
- * - If `dataFormat === CHANNEL_LAST`:
- * 4D tensor with shape:
- * `[batchSize, rows, cols, channels]`
- * - If `dataFormat === CHANNEL_FIRST`:
- * 4D tensor with shape:
- * `[batchSize, channels, rows, cols]`
- *
- * Output shape
- * - If `dataFormat=CHANNEL_LAST`:
- * 4D tensor with shape:
- * `[batchSize, pooleRows, pooledCols, channels]`
- * - If `dataFormat=CHANNEL_FIRST`:
- * 4D tensor with shape:
- * `[batchSize, channels, pooleRows, pooledCols]`
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function maxPooling2d(args) {
- return new MaxPooling2D(args);
- }
- /**
- * Max pooling operation for 3D data.
- *
- * Input shape
- * - If `dataFormat === channelsLast`:
- * 5D tensor with shape:
- * `[batchSize, depths, rows, cols, channels]`
- * - If `dataFormat === channelsFirst`:
- * 5D tensor with shape:
- * `[batchSize, channels, depths, rows, cols]`
- *
- * Output shape
- * - If `dataFormat=channelsLast`:
- * 5D tensor with shape:
- * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
- * - If `dataFormat=channelsFirst`:
- * 5D tensor with shape:
- * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
- *
- * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
- */
- function maxPooling3d(args) {
- return new MaxPooling3D(args);
- }
- // Recurrent Layers.
- /**
- * Gated Recurrent Unit - Cho et al. 2014.
- *
- * This is an `RNN` layer consisting of one `GRUCell`. However, unlike
- * the underlying `GRUCell`, the `apply` method of `SimpleRNN` operates
- * on a sequence of inputs. The shape of the input (not including the first,
- * batch dimension) needs to be at least 2-D, with the first dimension being
- * time steps. For example:
- *
- * ```js
- * const rnn = tf.layers.gru({units: 8, returnSequences: true});
- *
- * // Create an input with 10 time steps.
- * const input = tf.input({shape: [10, 20]});
- * const output = rnn.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the `GRUCell`'s number of units.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function gru(args) {
- return new GRU(args);
- }
- /**
- * Cell class for `GRU`.
- *
- * `GRUCell` is distinct from the `RNN` subclass `GRU` in that its
- * `apply` method takes the input data of only a single time step and returns
- * the cell's output at the time step, while `GRU` takes the input data
- * over a number of time steps. For example:
- *
- * ```js
- * const cell = tf.layers.gruCell({units: 2});
- * const input = tf.input({shape: [10]});
- * const output = cell.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10]: This is the cell's output at a single time step. The 1st
- * // dimension is the unknown batch size.
- * ```
- *
- * Instance(s) of `GRUCell` can be used to construct `RNN` layers. The
- * most typical use of this workflow is to combine a number of cells into a
- * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
- * RNN. For example:
- *
- * ```js
- * const cells = [
- * tf.layers.gruCell({units: 4}),
- * tf.layers.gruCell({units: 8}),
- * ];
- * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
- *
- * // Create an input with 10 time steps and a length-20 vector at each step.
- * const input = tf.input({shape: [10, 20]});
- * const output = rnn.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the last `gruCell`'s number of units.
- * ```
- *
- * To create an `RNN` consisting of only *one* `GRUCell`, use the
- * `tf.layers.gru`.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function gruCell(args) {
- return new GRUCell(args);
- }
- /**
- * Long-Short Term Memory layer - Hochreiter 1997.
- *
- * This is an `RNN` layer consisting of one `LSTMCell`. However, unlike
- * the underlying `LSTMCell`, the `apply` method of `LSTM` operates
- * on a sequence of inputs. The shape of the input (not including the first,
- * batch dimension) needs to be at least 2-D, with the first dimension being
- * time steps. For example:
- *
- * ```js
- * const lstm = tf.layers.lstm({units: 8, returnSequences: true});
- *
- * // Create an input with 10 time steps.
- * const input = tf.input({shape: [10, 20]});
- * const output = lstm.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the `LSTMCell`'s number of units.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function lstm(args) {
- return new LSTM(args);
- }
- /**
- * Cell class for `LSTM`.
- *
- * `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its
- * `apply` method takes the input data of only a single time step and returns
- * the cell's output at the time step, while `LSTM` takes the input data
- * over a number of time steps. For example:
- *
- * ```js
- * const cell = tf.layers.lstmCell({units: 2});
- * const input = tf.input({shape: [10]});
- * const output = cell.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10]: This is the cell's output at a single time step. The 1st
- * // dimension is the unknown batch size.
- * ```
- *
- * Instance(s) of `LSTMCell` can be used to construct `RNN` layers. The
- * most typical use of this workflow is to combine a number of cells into a
- * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
- * RNN. For example:
- *
- * ```js
- * const cells = [
- * tf.layers.lstmCell({units: 4}),
- * tf.layers.lstmCell({units: 8}),
- * ];
- * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
- *
- * // Create an input with 10 time steps and a length-20 vector at each step.
- * const input = tf.input({shape: [10, 20]});
- * const output = rnn.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the last `lstmCell`'s number of units.
- * ```
- *
- * To create an `RNN` consisting of only *one* `LSTMCell`, use the
- * `tf.layers.lstm`.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function lstmCell(args) {
- return new LSTMCell(args);
- }
- /**
- * Fully-connected RNN where the output is to be fed back to input.
- *
- * This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike
- * the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates
- * on a sequence of inputs. The shape of the input (not including the first,
- * batch dimension) needs to be at least 2-D, with the first dimension being
- * time steps. For example:
- *
- * ```js
- * const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true});
- *
- * // Create an input with 10 time steps.
- * const input = tf.input({shape: [10, 20]});
- * const output = rnn.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the `SimpleRNNCell`'s number of units.
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function simpleRNN(args) {
- return new SimpleRNN(args);
- }
- /**
- * Cell class for `SimpleRNN`.
- *
- * `SimpleRNNCell` is distinct from the `RNN` subclass `SimpleRNN` in that its
- * `apply` method takes the input data of only a single time step and returns
- * the cell's output at the time step, while `SimpleRNN` takes the input data
- * over a number of time steps. For example:
- *
- * ```js
- * const cell = tf.layers.simpleRNNCell({units: 2});
- * const input = tf.input({shape: [10]});
- * const output = cell.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10]: This is the cell's output at a single time step. The 1st
- * // dimension is the unknown batch size.
- * ```
- *
- * Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The
- * most typical use of this workflow is to combine a number of cells into a
- * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
- * RNN. For example:
- *
- * ```js
- * const cells = [
- * tf.layers.simpleRNNCell({units: 4}),
- * tf.layers.simpleRNNCell({units: 8}),
- * ];
- * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
- *
- * // Create an input with 10 time steps and a length-20 vector at each step.
- * const input = tf.input({shape: [10, 20]});
- * const output = rnn.apply(input);
- *
- * console.log(JSON.stringify(output.shape));
- * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
- * // same as the sequence length of `input`, due to `returnSequences`: `true`;
- * // 3rd dimension is the last `SimpleRNNCell`'s number of units.
- * ```
- *
- * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use the
- * `tf.layers.simpleRNN`.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function simpleRNNCell(args) {
- return new SimpleRNNCell(args);
- }
- /**
- * Convolutional LSTM layer - Xingjian Shi 2015.
- *
- * This is an `ConvRNN2D` layer consisting of one `ConvLSTM2DCell`. However,
- * unlike the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D`
- * operates on a sequence of inputs. The shape of the input (not including the
- * first, batch dimension) needs to be 4-D, with the first dimension being time
- * steps. For example:
- *
- * ```js
- * const filters = 3;
- * const kernelSize = 3;
- *
- * const batchSize = 4;
- * const sequenceLength = 2;
- * const size = 5;
- * const channels = 3;
- *
- * const inputShape = [batchSize, sequenceLength, size, size, channels];
- * const input = tf.ones(inputShape);
- *
- * const layer = tf.layers.convLstm2d({filters, kernelSize});
- *
- * const output = layer.apply(input);
- * ```
- */
- /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
- function convLstm2d(args) {
- return new ConvLSTM2D(args);
- }
- /**
- * Cell class for `ConvLSTM2D`.
- *
- * `ConvLSTM2DCell` is distinct from the `ConvRNN2D` subclass `ConvLSTM2D` in
- * that its `call` method takes the input data of only a single time step and
- * returns the cell's output at the time step, while `ConvLSTM2D` takes the
- * input data over a number of time steps. For example:
- *
- * ```js
- * const filters = 3;
- * const kernelSize = 3;
- *
- * const sequenceLength = 1;
- * const size = 5;
- * const channels = 3;
- *
- * const inputShape = [sequenceLength, size, size, channels];
- * const input = tf.ones(inputShape);
- *
- * const cell = tf.layers.convLstm2dCell({filters, kernelSize});
- *
- * cell.build(input.shape);
- *
- * const outputSize = size - kernelSize + 1;
- * const outShape = [sequenceLength, outputSize, outputSize, filters];
- *
- * const initialH = tf.zeros(outShape);
- * const initialC = tf.zeros(outShape);
- *
- * const [o, h, c] = cell.call([input, initialH, initialC], {});
- * ```
- */
- /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
- function convLstm2dCell(args) {
- return new ConvLSTM2DCell(args);
- }
- /**
- * Base class for recurrent layers.
- *
- * Input shape:
- * 3D tensor with shape `[batchSize, timeSteps, inputDim]`.
- *
- * Output shape:
- * - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first
- * tensor is the output. The remaining tensors are the states at the
- * last time step, each with shape `[batchSize, units]`.
- * - if `returnSequences`, the output will have shape
- * `[batchSize, timeSteps, units]`.
- * - else, the output will have shape `[batchSize, units]`.
- *
- * Masking:
- * This layer supports masking for input data with a variable number
- * of timesteps. To introduce masks to your data,
- * use an embedding layer with the `mask_zero` parameter
- * set to `True`.
- *
- * Notes on using statefulness in RNNs:
- * You can set RNN layers to be 'stateful', which means that the states
- * computed for the samples in one batch will be reused as initial states
- * for the samples in the next batch. This assumes a one-to-one mapping
- * between samples in different successive batches.
- *
- * To enable statefulness:
- * - specify `stateful: true` in the layer constructor.
- * - specify a fixed batch size for your model, by passing
- * if sequential model:
- * `batchInputShape=[...]` to the first layer in your model.
- * else for functional model with 1 or more Input layers:
- * `batchShape=[...]` to all the first layers in your model.
- * This is the expected shape of your inputs *including the batch size*.
- * It should be a tuple of integers, e.g. `(32, 10, 100)`.
- * - specify `shuffle=False` when calling fit().
- *
- * To reset the states of your model, call `.resetStates()` on either
- * a specific layer, or on your entire model.
- *
- * Note on specifying the initial state of RNNs
- * You can specify the initial state of RNN layers symbolically by
- * calling them with the option `initialState`. The value of
- * `initialState` should be a tensor or list of tensors representing
- * the initial state of the RNN layer.
- *
- * You can specify the initial state of RNN layers numerically by
- * calling `resetStates` with the keyword argument `states`. The value of
- * `states` should be a numpy array or list of numpy arrays representing
- * the initial state of the RNN layer.
- *
- * Note on passing external constants to RNNs
- * You can pass "external" constants to the cell using the `constants`
- * keyword argument of `RNN.call` method. This requires that the `cell.call`
- * method accepts the same keyword argument `constants`. Such constants
- * can be used to conditon the cell transformation on additional static inputs
- * (not changing over time), a.k.a an attention mechanism.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function rnn$1(args) {
- return new RNN(args);
- }
- /**
- * Wrapper allowing a stack of RNN cells to behave as a single cell.
- *
- * Used to implement efficient stacked RNNs.
- *
- * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
- */
- function stackedRNNCells(args) {
- return new StackedRNNCells(args);
- }
- // Wrapper Layers.
- /** @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */
- function bidirectional(args) {
- return new Bidirectional(args);
- }
- /**
- * This wrapper applies a layer to every temporal slice of an input.
- *
- * The input should be at least 3D, and the dimension of the index `1` will be
- * considered to be the temporal dimension.
- *
- * Consider a batch of 32 samples, where each sample is a sequence of 10 vectors
- * of 16 dimensions. The batch input shape of the layer is then `[32, 10,
- * 16]`, and the `inputShape`, not including the sample dimension, is
- * `[10, 16]`.
- *
- * You can then use `TimeDistributed` to apply a `Dense` layer to each of the 10
- * timesteps, independently:
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.timeDistributed({
- * layer: tf.layers.dense({units: 8}),
- * inputShape: [10, 16],
- * }));
- *
- * // Now model.outputShape = [null, 10, 8].
- * // The output will then have shape `[32, 10, 8]`.
- *
- * // In subsequent layers, there is no need for `inputShape`:
- * model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})}));
- * console.log(JSON.stringify(model.outputs[0].shape));
- * // Now model.outputShape = [null, 10, 32].
- * ```
- *
- * The output will then have shape `[32, 10, 32]`.
- *
- * `TimeDistributed` can be used with arbitrary layers, not just `Dense`, for
- * instance a `Conv2D` layer.
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.timeDistributed({
- * layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}),
- * inputShape: [10, 299, 299, 3],
- * }));
- * console.log(JSON.stringify(model.outputs[0].shape));
- * ```
- *
- * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
- */
- function timeDistributed(args) {
- return new TimeDistributed(args);
- }
- // Aliases for pooling.
- const globalMaxPool1d = globalMaxPooling1d;
- const globalMaxPool2d = globalMaxPooling2d;
- const maxPool1d = maxPooling1d;
- const maxPool2d = maxPooling2d;
- /**
- * Apply additive zero-centered Gaussian noise.
- *
- * As it is a regularization layer, it is only active at training time.
- *
- * This is useful to mitigate overfitting
- * (you could see it as a form of random data augmentation).
- * Gaussian Noise (GS) is a natural choice as corruption process
- * for real valued inputs.
- *
- * # Arguments
- * stddev: float, standard deviation of the noise distribution.
- *
- * # Input shape
- * Arbitrary. Use the keyword argument `input_shape`
- * (tuple of integers, does not include the samples axis)
- * when using this layer as the first layer in a model.
- *
- * # Output shape
- * Same shape as input.
- *
- * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
- */
- function gaussianNoise(args) {
- return new GaussianNoise(args);
- }
- /**
- * Apply multiplicative 1-centered Gaussian noise.
- *
- * As it is a regularization layer, it is only active at training time.
- *
- * Arguments:
- * - `rate`: float, drop probability (as with `Dropout`).
- * The multiplicative noise will have
- * standard deviation `sqrt(rate / (1 - rate))`.
- *
- * Input shape:
- * Arbitrary. Use the keyword argument `inputShape`
- * (tuple of integers, does not include the samples axis)
- * when using this layer as the first layer in a model.
- *
- * Output shape:
- * Same shape as input.
- *
- * References:
- * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
- * http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
- *
- * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
- */
- function gaussianDropout(args) {
- return new GaussianDropout(args);
- }
- /**
- * Applies Alpha Dropout to the input.
- *
- * As it is a regularization layer, it is only active at training time.
- *
- * Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
- * to their original values, in order to ensure the self-normalizing property
- * even after this dropout.
- * Alpha Dropout fits well to Scaled Exponential Linear Units
- * by randomly setting activations to the negative saturation value.
- *
- * Arguments:
- * - `rate`: float, drop probability (as with `Dropout`).
- * The multiplicative noise will have
- * standard deviation `sqrt(rate / (1 - rate))`.
- * - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
- * shape for randomly generated keep/drop flags.
- *
- * Input shape:
- * Arbitrary. Use the keyword argument `inputShape`
- * (tuple of integers, does not include the samples axis)
- * when using this layer as the first layer in a model.
- *
- * Output shape:
- * Same shape as input.
- *
- * References:
- * - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- *
- * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
- */
- function alphaDropout(args) {
- return new AlphaDropout(args);
- }
- /**
- * Masks a sequence by using a mask value to skip timesteps.
- *
- * If all features for a given sample timestep are equal to `mask_value`,
- * then the sample timestep will be masked (skipped) in all downstream layers
- * (as long as they support masking).
- *
- * If any downstream layer does not support masking yet receives such
- * an input mask, an exception will be raised.
- *
- * Arguments:
- * - `maskValue`: Either None or mask value to skip.
- *
- * Input shape:
- * Arbitrary. Use the keyword argument `inputShape`
- * (tuple of integers, does not include the samples axis)
- * when using this layer as the first layer in a model.
- *
- * Output shape:
- * Same shape as input.
- *
- * @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'}
- */
- function masking(args) {
- return new Masking(args);
- }
-
- var exports_layers = /*#__PURE__*/Object.freeze({
- __proto__: null,
- inputLayer: inputLayer,
- elu: elu$2,
- reLU: reLU,
- leakyReLU: leakyReLU,
- prelu: prelu$1,
- softmax: softmax$1,
- thresholdedReLU: thresholdedReLU,
- conv1d: conv1d$2,
- conv2d: conv2d$3,
- conv2dTranspose: conv2dTranspose$1,
- conv3d: conv3d$2,
- separableConv2d: separableConv2d$1,
- cropping2D: cropping2D,
- upSampling2d: upSampling2d,
- depthwiseConv2d: depthwiseConv2d$3,
- activation: activation,
- dense: dense,
- dropout: dropout$2,
- spatialDropout1d: spatialDropout1d,
- flatten: flatten$2,
- repeatVector: repeatVector,
- reshape: reshape$1,
- permute: permute,
- embedding: embedding,
- add: add$3,
- average: average$1,
- concatenate: concatenate$2,
- maximum: maximum$2,
- minimum: minimum$2,
- multiply: multiply$1,
- dot: dot$2,
- batchNormalization: batchNormalization$1,
- layerNormalization: layerNormalization,
- zeroPadding2d: zeroPadding2d,
- averagePooling1d: averagePooling1d,
- avgPool1d: avgPool1d,
- avgPooling1d: avgPooling1d,
- averagePooling2d: averagePooling2d,
- avgPool2d: avgPool2d,
- avgPooling2d: avgPooling2d,
- averagePooling3d: averagePooling3d,
- avgPool3d: avgPool3d$1,
- avgPooling3d: avgPooling3d,
- globalAveragePooling1d: globalAveragePooling1d,
- globalAveragePooling2d: globalAveragePooling2d,
- globalMaxPooling1d: globalMaxPooling1d,
- globalMaxPooling2d: globalMaxPooling2d,
- maxPooling1d: maxPooling1d,
- maxPooling2d: maxPooling2d,
- maxPooling3d: maxPooling3d,
- gru: gru,
- gruCell: gruCell,
- lstm: lstm,
- lstmCell: lstmCell,
- simpleRNN: simpleRNN,
- simpleRNNCell: simpleRNNCell,
- convLstm2d: convLstm2d,
- convLstm2dCell: convLstm2dCell,
- rnn: rnn$1,
- stackedRNNCells: stackedRNNCells,
- bidirectional: bidirectional,
- timeDistributed: timeDistributed,
- globalMaxPool1d: globalMaxPool1d,
- globalMaxPool2d: globalMaxPool2d,
- maxPool1d: maxPool1d,
- maxPool2d: maxPool2d,
- Layer: Layer,
- RNN: RNN,
- RNNCell: RNNCell,
- input: input,
- gaussianNoise: gaussianNoise,
- gaussianDropout: gaussianDropout,
- alphaDropout: alphaDropout,
- masking: masking
- });
-
- /**
- * Binary accuracy metric function.
- *
- * `yTrue` and `yPred` can have 0-1 values. Example:
- * ```js
- * const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]);
- * const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]);
- * const accuracy = tf.metrics.binaryAccuracy(x, y);
- * accuracy.print();
- * ```
- *
- * `yTrue` and `yPred` can also have floating-number values between 0 and 1, in
- * which case the values will be thresholded at 0.5 to yield 0-1 values (i.e.,
- * a value >= 0.5 and <= 1.0 is interpreted as 1.
- * )
- * Example:
- * ```js
- * const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]);
- * const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]);
- * const accuracy = tf.metrics.binaryAccuracy(x, y);
- * accuracy.print();
- * ```
- *
- * @param yTrue Binary Tensor of truth.
- * @param yPred Binary Tensor of prediction.
- * @return Accuracy Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function binaryAccuracy$1(yTrue, yPred) {
- return binaryAccuracy(yTrue, yPred);
- }
- /**
- * Binary crossentropy metric function.
- *
- * Example:
- * ```js
- * const x = tf.tensor2d([[0], [1], [1], [1]]);
- * const y = tf.tensor2d([[0], [0], [0.5], [1]]);
- * const crossentropy = tf.metrics.binaryCrossentropy(x, y);
- * crossentropy.print();
- * ```
- *
- * @param yTrue Binary Tensor of truth.
- * @param yPred Binary Tensor of prediction, probabilities for the `1` case.
- * @return Accuracy Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function binaryCrossentropy$2(yTrue, yPred) {
- return binaryCrossentropy$1(yTrue, yPred);
- }
- /**
- * Sparse categorical accuracy metric function.
- *
- * Example:
- * ```js
- *
- * const yTrue = tf.tensor1d([1, 1, 2, 2, 0]);
- * const yPred = tf.tensor2d(
- * [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]);
- * const crossentropy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred);
- * crossentropy.print();
- * ```
- *
- * @param yTrue True labels: indices.
- * @param yPred Predicted probabilities or logits.
- * @returns Accuracy tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function sparseCategoricalAccuracy$1(yTrue, yPred) {
- return sparseCategoricalAccuracy(yTrue, yPred);
- }
- /**
- * Categorical accuracy metric function.
- *
- * Example:
- * ```js
- * const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
- * const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
- * const accuracy = tf.metrics.categoricalAccuracy(x, y);
- * accuracy.print();
- * ```
- *
- * @param yTrue Binary Tensor of truth: one-hot encoding of categories.
- * @param yPred Binary Tensor of prediction: probabilities or logits for the
- * same categories as in `yTrue`.
- * @return Accuracy Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function categoricalAccuracy$1(yTrue, yPred) {
- return categoricalAccuracy(yTrue, yPred);
- }
- /**
- * Categorical crossentropy between an output tensor and a target tensor.
- *
- * @param target A tensor of the same shape as `output`.
- * @param output A tensor resulting from a softmax (unless `fromLogits` is
- * `true`, in which case `output` is expected to be the logits).
- * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
- * a tensor of logits.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function categoricalCrossentropy$2(yTrue, yPred) {
- return categoricalCrossentropy$1(yTrue, yPred);
- }
- /**
- * Computes the precision of the predictions with respect to the labels.
- *
- * Example:
- * ```js
- * const x = tf.tensor2d(
- * [
- * [0, 0, 0, 1],
- * [0, 1, 0, 0],
- * [0, 0, 0, 1],
- * [1, 0, 0, 0],
- * [0, 0, 1, 0]
- * ]
- * );
- *
- * const y = tf.tensor2d(
- * [
- * [0, 0, 1, 0],
- * [0, 1, 0, 0],
- * [0, 0, 0, 1],
- * [0, 1, 0, 0],
- * [0, 1, 0, 0]
- * ]
- * );
- *
- * const precision = tf.metrics.precision(x, y);
- * precision.print();
- * ```
- *
- * @param yTrue The ground truth values. Expected to be contain only 0-1 values.
- * @param yPred The predicted values. Expected to be contain only 0-1 values.
- * @return Precision Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function precision$1(yTrue, yPred) {
- return precision(yTrue, yPred);
- }
- /**
- * Computes the recall of the predictions with respect to the labels.
- *
- * Example:
- * ```js
- * const x = tf.tensor2d(
- * [
- * [0, 0, 0, 1],
- * [0, 1, 0, 0],
- * [0, 0, 0, 1],
- * [1, 0, 0, 0],
- * [0, 0, 1, 0]
- * ]
- * );
- *
- * const y = tf.tensor2d(
- * [
- * [0, 0, 1, 0],
- * [0, 1, 0, 0],
- * [0, 0, 0, 1],
- * [0, 1, 0, 0],
- * [0, 1, 0, 0]
- * ]
- * );
- *
- * const recall = tf.metrics.recall(x, y);
- * recall.print();
- * ```
- *
- * @param yTrue The ground truth values. Expected to be contain only 0-1 values.
- * @param yPred The predicted values. Expected to be contain only 0-1 values.
- * @return Recall Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function recall$1(yTrue, yPred) {
- return recall(yTrue, yPred);
- }
- /**
- * Loss or metric function: Cosine proximity.
- *
- * Mathematically, cosine proximity is defined as:
- * `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`,
- * wherein `l2Normalize()` normalizes the L2 norm of the input to 1 and `*`
- * represents element-wise multiplication.
- *
- * ```js
- * const yTrue = tf.tensor2d([[1, 0], [1, 0]]);
- * const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]);
- * const proximity = tf.metrics.cosineProximity(yTrue, yPred);
- * proximity.print();
- * ```
- *
- * @param yTrue Truth Tensor.
- * @param yPred Prediction Tensor.
- * @return Cosine proximity Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function cosineProximity$1(yTrue, yPred) {
- return cosineProximity(yTrue, yPred);
- }
- /**
- * Loss or metric function: Mean absolute error.
- *
- * Mathematically, mean absolute error is defined as:
- * `mean(abs(yPred - yTrue))`,
- * wherein the `mean` is applied over feature dimensions.
- *
- * ```js
- * const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]);
- * const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]);
- * const mse = tf.metrics.meanAbsoluteError(yTrue, yPred);
- * mse.print();
- * ```
- *
- * @param yTrue Truth Tensor.
- * @param yPred Prediction Tensor.
- * @return Mean absolute error Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function meanAbsoluteError$1(yTrue, yPred) {
- return meanAbsoluteError(yTrue, yPred);
- }
- /**
- * Loss or metric function: Mean absolute percentage error.
- *
- * ```js
- * const yTrue = tf.tensor2d([[0, 1], [10, 20]]);
- * const yPred = tf.tensor2d([[0, 1], [11, 24]]);
- * const mse = tf.metrics.meanAbsolutePercentageError(yTrue, yPred);
- * mse.print();
- * ```
- *
- * Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`.
- *
- * @param yTrue Truth Tensor.
- * @param yPred Prediction Tensor.
- * @return Mean absolute percentage error Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function meanAbsolutePercentageError$1(yTrue, yPred) {
- return meanAbsolutePercentageError(yTrue, yPred);
- }
- function MAPE$2(yTrue, yPred) {
- return meanAbsolutePercentageError(yTrue, yPred);
- }
- function mape$2(yTrue, yPred) {
- return meanAbsolutePercentageError(yTrue, yPred);
- }
- /**
- * Loss or metric function: Mean squared error.
- *
- * ```js
- * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
- * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
- * const mse = tf.metrics.meanSquaredError(yTrue, yPred);
- * mse.print();
- * ```
- *
- * Aliases: `tf.metrics.MSE`, `tf.metrics.mse`.
- *
- * @param yTrue Truth Tensor.
- * @param yPred Prediction Tensor.
- * @return Mean squared error Tensor.
- *
- * @doc {heading: 'Metrics', namespace: 'metrics'}
- */
- function meanSquaredError$2(yTrue, yPred) {
- return meanSquaredError$1(yTrue, yPred);
- }
- function MSE$2(yTrue, yPred) {
- return meanSquaredError$1(yTrue, yPred);
- }
- function mse$2(yTrue, yPred) {
- return meanSquaredError$1(yTrue, yPred);
- }
-
- var exports_metrics = /*#__PURE__*/Object.freeze({
- __proto__: null,
- binaryAccuracy: binaryAccuracy$1,
- binaryCrossentropy: binaryCrossentropy$2,
- sparseCategoricalAccuracy: sparseCategoricalAccuracy$1,
- categoricalAccuracy: categoricalAccuracy$1,
- categoricalCrossentropy: categoricalCrossentropy$2,
- precision: precision$1,
- recall: recall$1,
- cosineProximity: cosineProximity$1,
- meanAbsoluteError: meanAbsoluteError$1,
- meanAbsolutePercentageError: meanAbsolutePercentageError$1,
- MAPE: MAPE$2,
- mape: mape$2,
- meanSquaredError: meanSquaredError$2,
- MSE: MSE$2,
- mse: mse$2
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
-
- var exports_models = /*#__PURE__*/Object.freeze({
- __proto__: null,
- modelFromJSON: modelFromJSON
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- /**
- * Regularizer for L1 and L2 regularization.
- *
- * Adds a term to the loss to penalize large weights:
- * loss += sum(l1 * abs(x)) + sum(l2 * x^2)
- *
- * @doc {heading: 'Regularizers', namespace: 'regularizers'}
- */
- function l1l2(config) {
- return new L1L2(config);
- }
- /**
- * Regularizer for L1 regularization.
- *
- * Adds a term to the loss to penalize large weights:
- * loss += sum(l1 * abs(x))
- * @param args l1 config.
- *
- * @doc {heading: 'Regularizers', namespace: 'regularizers'}
- */
- function l1$1(config) {
- return l1(config);
- }
- /**
- * Regularizer for L2 regularization.
- *
- * Adds a term to the loss to penalize large weights:
- * loss += sum(l2 * x^2)
- * @param args l2 config.
- *
- * @doc {heading: 'Regularizers', namespace: 'regularizers'}
- */
- function l2$1(config) {
- return l2(config);
- }
-
- var exports_regularizers = /*#__PURE__*/Object.freeze({
- __proto__: null,
- l1l2: l1l2,
- l1: l1$1,
- l2: l2$1
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
- class Callback extends BaseCallback {
- constructor() {
- super(...arguments);
- /** Instance of `keras.models.Model`. Reference of the model being trained. */
- this.model = null;
- }
- setModel(model) {
- if (!(model instanceof LayersModel)) {
- throw new Error('model must be a LayersModel, not some other Container');
- }
- this.model = model;
- }
- }
- function less$1(currVal, prevVal) {
- return currVal < prevVal;
- }
- function greater$1(currVal, prevVal) {
- return currVal > prevVal;
- }
- /**
- * A Callback that stops training when a monitored quantity has stopped
- * improving.
- */
- class EarlyStopping extends Callback {
- constructor(args) {
- super();
- if (args == null) {
- args = {};
- }
- if (args.restoreBestWeights) {
- throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
- }
- this.monitor = args.monitor || 'val_loss';
- this.minDelta = Math.abs(args.minDelta || 0);
- this.patience = args.patience || 0;
- this.verbose = args.verbose || 0;
- this.mode = args.mode || 'auto';
- this.baseline = args.baseline;
- if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {
- console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
- `Falling back to mode 'auto'.`);
- this.mode = 'auto';
- }
- if (this.mode === 'min') {
- this.monitorFunc = less$1;
- }
- else if (this.mode === 'max') {
- this.monitorFunc = greater$1;
- }
- else {
- // For mode === 'auto'.
- if (this.monitor.indexOf('acc') !== -1) {
- this.monitorFunc = greater$1;
- }
- else {
- this.monitorFunc = less$1;
- }
- }
- if (this.monitorFunc === less$1) {
- this.minDelta *= -1;
- }
- }
- async onTrainBegin(logs) {
- this.wait = 0;
- this.stoppedEpoch = 0;
- if (this.baseline != null) {
- this.best = this.baseline;
- }
- else {
- this.best = this.monitorFunc === less$1 ? Infinity : -Infinity;
- }
- }
- async onEpochEnd(epoch, logs) {
- await resolveScalarsInLogs(logs);
- const current = this.getMonitorValue(logs);
- if (current == null) {
- return;
- }
- if (this.monitorFunc(current - this.minDelta, this.best)) {
- this.best = current;
- this.wait = 0;
- // TODO(cais): Logic for restoreBestWeights.
- }
- else {
- this.wait++;
- if (this.wait >= this.patience) {
- this.stoppedEpoch = epoch;
- this.model.stopTraining = true;
- }
- // TODO(cais): Logic for restoreBestWeights.
- }
- }
- async onTrainEnd(logs) {
- if (this.stoppedEpoch > 0 && this.verbose) {
- console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
- }
- }
- getMonitorValue(logs) {
- if (logs == null) {
- logs = {};
- }
- const monitorValue = logs[this.monitor];
- if (monitorValue == null) {
- console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
- `Available metrics are: ${Object.keys(logs)}`);
- }
- return monitorValue;
- }
- }
- /**
- * Factory function for a Callback that stops training when a monitored
- * quantity has stopped improving.
- *
- * Early stopping is a type of regularization, and protects model against
- * overfitting.
- *
- * The following example based on fake data illustrates how this callback
- * can be used during `tf.LayersModel.fit()`:
- *
- * ```js
- * const model = tf.sequential();
- * model.add(tf.layers.dense({
- * units: 3,
- * activation: 'softmax',
- * kernelInitializer: 'ones',
- * inputShape: [2]
- * }));
- * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
- * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
- * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
- * model.compile(
- * {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
- *
- * // Without the EarlyStopping callback, the val_acc value would be:
- * // 0.5, 0.5, 0.5, 0.5, ...
- * // With val_acc being monitored, training should stop after the 2nd epoch.
- * const history = await model.fit(xs, ys, {
- * epochs: 10,
- * validationData: [xsVal, ysVal],
- * callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
- * });
- *
- * // Expect to see a length-2 array.
- * console.log(history.history.val_acc);
- * ```
- *
- * @doc {
- * heading: 'Callbacks',
- * namespace: 'callbacks'
- * }
- */
- function earlyStopping(args) {
- return new EarlyStopping(args);
- }
- const callbacks = { earlyStopping };
-
- /**
- * @license
- * Copyright 2018 Google LLC
- *
- * Use of this source code is governed by an MIT-style
- * license that can be found in the LICENSE file or at
- * https://opensource.org/licenses/MIT.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /** DataType enum. */
- var DataType;
- (function (DataType) {
- DataType[DataType["DT_INVALID"] = 0] = "DT_INVALID";
- DataType[DataType["DT_FLOAT"] = 1] = "DT_FLOAT";
- DataType[DataType["DT_DOUBLE"] = 2] = "DT_DOUBLE";
- DataType[DataType["DT_INT32"] = 3] = "DT_INT32";
- DataType[DataType["DT_UINT8"] = 4] = "DT_UINT8";
- DataType[DataType["DT_INT16"] = 5] = "DT_INT16";
- DataType[DataType["DT_INT8"] = 6] = "DT_INT8";
- DataType[DataType["DT_STRING"] = 7] = "DT_STRING";
- DataType[DataType["DT_COMPLEX64"] = 8] = "DT_COMPLEX64";
- DataType[DataType["DT_INT64"] = 9] = "DT_INT64";
- DataType[DataType["DT_BOOL"] = 10] = "DT_BOOL";
- DataType[DataType["DT_QINT8"] = 11] = "DT_QINT8";
- DataType[DataType["DT_QUINT8"] = 12] = "DT_QUINT8";
- DataType[DataType["DT_QINT32"] = 13] = "DT_QINT32";
- DataType[DataType["DT_BFLOAT16"] = 14] = "DT_BFLOAT16";
- DataType[DataType["DT_FLOAT_REF"] = 101] = "DT_FLOAT_REF";
- DataType[DataType["DT_DOUBLE_REF"] = 102] = "DT_DOUBLE_REF";
- DataType[DataType["DT_INT32_REF"] = 103] = "DT_INT32_REF";
- DataType[DataType["DT_UINT8_REF"] = 104] = "DT_UINT8_REF";
- DataType[DataType["DT_INT16_REF"] = 105] = "DT_INT16_REF";
- DataType[DataType["DT_INT8_REF"] = 106] = "DT_INT8_REF";
- DataType[DataType["DT_STRING_REF"] = 107] = "DT_STRING_REF";
- DataType[DataType["DT_COMPLEX64_REF"] = 108] = "DT_COMPLEX64_REF";
- DataType[DataType["DT_INT64_REF"] = 109] = "DT_INT64_REF";
- DataType[DataType["DT_BOOL_REF"] = 110] = "DT_BOOL_REF";
- DataType[DataType["DT_QINT8_REF"] = 111] = "DT_QINT8_REF";
- DataType[DataType["DT_QUINT8_REF"] = 112] = "DT_QUINT8_REF";
- DataType[DataType["DT_QINT32_REF"] = 113] = "DT_QINT32_REF";
- DataType[DataType["DT_BFLOAT16_REF"] = 114] = "DT_BFLOAT16_REF";
- })(DataType || (DataType = {}));
- var SaverDef;
- (function (SaverDef) {
- /** CheckpointFormatVersion enum. */
- let CheckpointFormatVersion;
- (function (CheckpointFormatVersion) {
- CheckpointFormatVersion[CheckpointFormatVersion["LEGACY"] = 0] = "LEGACY";
- CheckpointFormatVersion[CheckpointFormatVersion["V1"] = 1] = "V1";
- CheckpointFormatVersion[CheckpointFormatVersion["V2"] = 2] = "V2";
- })(CheckpointFormatVersion = SaverDef.CheckpointFormatVersion || (SaverDef.CheckpointFormatVersion = {}));
- })(SaverDef || (SaverDef = {}));
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const CUSTOM_OPS = {};
- /**
- * Register an Op for graph model executor. This allow you to register
- * TensorFlow custom op or override existing op.
- *
- * Here is an example of registering a new MatMul Op.
- * ```js
- * const customMatmul = (node) =>
- * tf.matMul(
- * node.inputs[0], node.inputs[1],
- * node.attrs['transpose_a'], node.attrs['transpose_b']);
- *
- * tf.registerOp('MatMul', customMatmul);
- * ```
- * The inputs and attrs of the node object is based on the TensorFlow op
- * registry.
- *
- * @param name The Tensorflow Op name.
- * @param opFunc An op function which is called with the current graph node
- * during execution and needs to return a tensor or a list of tensors. The node
- * has the following attributes:
- * - attr: A map from attribute name to its value
- * - inputs: A list of input tensors
- *
- * @doc {heading: 'Models', subheading: 'Op Registry'}
- */
- function registerOp(name, opFunc) {
- const opMapper = {
- tfOpName: name,
- category: 'custom',
- inputs: [],
- attrs: [],
- customExecutor: opFunc
- };
- CUSTOM_OPS[name] = opMapper;
- }
- /**
- * Retrieve the OpMapper object for the registered op.
- *
- * @param name The Tensorflow Op name.
- *
- * @doc {heading: 'Models', subheading: 'Op Registry'}
- */
- function getRegisteredOp(name) {
- return CUSTOM_OPS[name];
- }
- /**
- * Deregister the Op for graph model executor.
- *
- * @param name The Tensorflow Op name.
- *
- * @doc {heading: 'Models', subheading: 'Op Registry'}
- */
- function deregisterOp(name) {
- delete CUSTOM_OPS[name];
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function getParamValue(paramName, node, tensorMap, context, resourceManager) {
- const inputParam = node.inputParams[paramName];
- if (inputParam && inputParam.inputIndexStart !== undefined) {
- const start = inputParam.inputIndexStart;
- const end = inputParam.inputIndexEnd === 0 ?
- undefined :
- (inputParam.inputIndexEnd === undefined ? start + 1 :
- inputParam.inputIndexEnd);
- if (inputParam.type === 'tensor') {
- return getTensor(node.inputNames[inputParam.inputIndexStart], tensorMap, context, resourceManager);
- }
- if (inputParam.type === 'tensors') {
- const inputs = node.inputNames.slice(start, end);
- return inputs.map(name => getTensor(name, tensorMap, context, resourceManager));
- }
- const tensor = getTensor(node.inputNames.slice(start)[0], tensorMap, context, resourceManager);
- const data = tensor.dataSync();
- return inputParam.type === 'number' ?
- data[0] :
- toNestedArray(tensor.shape, data);
- }
- const attrParam = node.attrParams[paramName];
- return attrParam && attrParam.value;
- }
- /**
- * Retrieve the tensor from tensorsMap based on input name.
- * @param name Node input name
- * @param tensorsMap Tensors map keyed by the node
- * @param context contains tensors and information for running the current node.
- * @param resourceManager Optional. Contains global resources of the model.
- */
- function getTensor(name, tensorsMap, context, resourceManager) {
- const [nodeName, index] = parseNodeName(name);
- if (resourceManager != null) {
- const tensor = resourceManager.getHashTableHandleByName(nodeName);
- if (tensor != null) {
- return tensor;
- }
- }
- const contextId = context.currentContextIds.find(contextId => {
- return !!tensorsMap[getNodeNameWithContextId(nodeName, contextId)];
- });
- return contextId !== undefined ?
- tensorsMap[getNodeNameWithContextId(nodeName, contextId)][index] :
- undefined;
- }
- /**
- * Retrieve the tensors based on input name for current context.
- * @param name Node input name
- * @param tensorsMap Tensors map keyed by the node
- */
- function getTensorsForCurrentContenxt(name, tensorsMap, context) {
- return tensorsMap[getNodeNameWithContextId(name, context.currentContextId)];
- }
- /**
- * Returns the node name and index from the Node input name.
- * @param inputName The input name of the node, in format of
- * node_name:output_index, i.e. MatMul:0, if the output_index is not set, it is
- * default to 0.
- */
- function getNodeNameAndIndex(inputName, context) {
- const [nodeName, index] = parseNodeName(inputName);
- return [
- getNodeNameWithContextId(nodeName, context && context.currentContextId),
- index
- ];
- }
- function getNodeNameWithContextId(name, contextId) {
- return !!contextId ? `${name}-${contextId}` : name;
- }
- function parseNodeName(name) {
- const parts = name.split(':');
- if (parts.length === 1) {
- return [name, 0];
- }
- const nodeName = parts[0];
- return [nodeName, Number(parts[parts.length - 1])];
- }
- function split$2(arr, size) {
- const res = [];
- for (let i = 0; i < arr.length; i += size) {
- res.push(arr.slice(i, i + size));
- }
- return res;
- }
- function getPadding(node, tensorMap, context) {
- let pad = getParamValue('pad', node, tensorMap, context);
- if (pad === 'explicit') {
- // This is 1d array, we need to convert it to 2d array
- pad = getParamValue('explicitPaddings', node, tensorMap, context);
- const explicitPadding = [[0, 0], [0, 0], [0, 0], [0, 0]];
- for (let i = 0; i < 4; i++) {
- explicitPadding[i][0] = pad[i * 2];
- explicitPadding[i][1] = pad[i * 2 + 1];
- }
- return explicitPadding;
- }
- return pad;
- }
- /**
- * Reuse the tensor if it is marked as keep, otherwise clone the tensor to
- * avoid disposal. This is important for TensorArray and TensorList ops, since
- * internally they use a tensor as the id for TensorArray and TensorList, and
- * to simplify lookup, they also use Tensor.id as the key to the internal map.
- * These id tensors have been marked as kept in the backend, we need avoid clone
- * them in order to create new Tensor.id.
- * @param tensor
- */
- function cloneTensor(tensor) {
- return tensor.kept ? tensor : clone(tensor);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json = [
- {
- 'tfOpName': 'Add',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'AddV2',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'AddN',
- 'category': 'arithmetic',
- 'inputs': [{ 'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors' }]
- },
- {
- 'tfOpName': 'BiasAdd',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sub',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'RealDiv',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Div',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'DivNoNan',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'FloorDiv',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Mul',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Maximum',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' }
- ]
- },
- {
- 'tfOpName': 'Minimum',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' }
- ]
- },
- {
- 'tfOpName': 'Pow',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'SquaredDifference',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Mod',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'FloorMod',
- 'category': 'arithmetic',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- }
- ];
-
- var arithmetic = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$1 = [
- {
- 'tfOpName': 'Abs',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Acos',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Asin',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Atan',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Atan2',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'y', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Ceil',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'ClipByValue',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'clip_value_min', 'name': 'clipValueMin', 'type': 'number' },
- { 'tfName': 'clip_value_max', 'name': 'clipValueMax', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'Complex',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'real', 'type': 'tensor' },
- { 'start': 1, 'name': 'imag', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'ComplexAbs',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Cos',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Cosh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Elu',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Exp',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Floor',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Log',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Imag',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'outputType',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'Neg',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Real',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'outputType',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'Prelu',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'alpha', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Relu',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Relu6',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'clipValueMin',
- 'name': 'clipValueMin',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'clipValueMax',
- 'name': 'clipValueMax',
- 'type': 'number',
- 'defaultValue': 6
- }
- ]
- },
- {
- 'tfOpName': 'Selu',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sigmoid',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sin',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sinh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sqrt',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Rsqrt',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Square',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Tan',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Tanh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Sign',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Round',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Expm1',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Log1p',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Reciprocal',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Softplus',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Asinh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Acosh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Atanh',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Erf',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Prod',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axes', 'type': 'number[]' },
- ],
- 'attrs': [
- {
- 'tfName': 'keep_dims',
- 'name': 'keepDims',
- 'type': 'bool',
- 'notSupported': true
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'LeakyRelu',
- 'category': 'basic_math',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'alpha',
- 'name': 'alpha',
- 'type': 'number',
- 'defaultValue': 0.2
- },
- {
- 'tfName': 'T',
- 'name': 'dtype',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- }
- ];
-
- var basicMath = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$1
- });
-
- const json$2 = [
- {
- 'tfOpName': 'LoopCond',
- 'category': 'control',
- 'inputs': [{ 'start': 0, 'name': 'pred', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'Switch',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'data', 'type': 'tensor' },
- { 'start': 1, 'name': 'pred', 'type': 'tensor' }
- ]
- },
- {
- 'tfOpName': 'Merge',
- 'category': 'control',
- 'inputs': [{ 'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors' }]
- },
- {
- 'tfOpName': 'Enter',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true },
- { 'tfName': 'frame_name', 'name': 'frameName', 'type': 'string' },
- { 'tfName': 'is_constant', 'name': 'isConstant', 'type': 'bool' }
- ]
- },
- {
- 'tfOpName': 'Exit',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'NextIteration',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'TensorArrayV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'size', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' },
- { 'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape' },
- { 'tfName': 'dynamic_size', 'name': 'dynamicSize', 'type': 'bool' },
- { 'tfName': 'clear_after_read', 'name': 'clearAfterRead', 'type': 'bool' },
- {
- 'tfName': 'identical_element_shapes',
- 'name': 'identicalElementShapes',
- 'type': 'bool'
- },
- { 'tfName': 'tensor_array_name', 'name': 'name', 'type': 'string' }
- ]
- },
- {
- 'tfOpName': 'TensorArrayWriteV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'index', 'type': 'number' },
- { 'start': 2, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 3, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'TensorArrayReadV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'index', 'type': 'number' },
- { 'start': 2, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [{
- 'tfName': 'dtype',
- 'name': 'dtype',
- 'type': 'dtype',
- 'notSupported': true
- }]
- },
- {
- 'tfOpName': 'TensorArrayGatherV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'number[]' },
- { 'start': 2, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' },
- { 'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape' }
- ]
- },
- {
- 'tfOpName': 'TensorArrayScatterV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'number[]' },
- { 'start': 2, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 3, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorArrayConcatV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' }, {
- 'tfName': 'element_shape_except0',
- 'name': 'elementShapeExcept0',
- 'type': 'shape',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'TensorArraySplitV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 2, 'name': 'lengths', 'type': 'number[]' },
- { 'start': 3, 'name': 'flowIn', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorArraySizeV3',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' },
- { 'start': 1, 'name': 'flowIn', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'TensorArrayCloseV3',
- 'category': 'control',
- 'inputs': [{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'StatelessIf',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'cond', 'type': 'tensor' },
- { 'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors' }
- ],
- 'attrs': [
- { 'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func' },
- { 'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func' }
- ]
- },
- {
- 'tfOpName': 'If',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'cond', 'type': 'tensor' },
- { 'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors' }
- ],
- 'attrs': [
- { 'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func' },
- { 'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func' }
- ]
- },
- {
- 'tfOpName': 'StatelessWhile',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'cond', 'name': 'cond', 'type': 'func' },
- { 'tfName': 'body', 'name': 'body', 'type': 'func' }
- ]
- },
- {
- 'tfOpName': 'While',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'cond', 'name': 'cond', 'type': 'func' },
- { 'tfName': 'body', 'name': 'body', 'type': 'func' }
- ]
- },
- {
- 'tfOpName': 'TensorListScatter',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'number[]' },
- { 'start': 2, 'name': 'elementShape', 'type': 'shape' }
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListScatterV2',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'number[]' },
- { 'start': 2, 'name': 'elementShape', 'type': 'shape' },
- { 'start': 3, 'name': 'numElements', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListGather',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'number[]' },
- { 'start': 2, 'name': 'elementShape', 'type': 'shape' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListGetItem',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'index', 'type': 'number' },
- { 'start': 2, 'name': 'elementShape', 'type': 'shape' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListSetItem',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'index', 'type': 'number' },
- { 'start': 2, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListReserve',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'elementShape', 'type': 'shape' },
- { 'start': 1, 'name': 'numElements', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListFromTensor',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 1, 'name': 'elementShape', 'type': 'shape' }
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListStack',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'elementShape', 'type': 'shape' },
- ],
- 'attrs': [
- { 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' },
- { 'tfName': 'num_elements', 'name': 'numElements', 'type': 'dtype' }
- ]
- },
- {
- 'tfOpName': 'TensorListSplit',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- { 'start': 1, 'name': 'elementShape', 'type': 'shape' },
- { 'start': 2, 'name': 'lengths', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListConcat',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape' },
- { 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }
- ]
- },
- {
- 'tfOpName': 'TensorListPopBack',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'elementShape', 'type': 'shape' },
- ],
- 'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TensorListPushBack',
- 'category': 'control',
- 'inputs': [
- { 'start': 0, 'name': 'tensorListId', 'type': 'tensor' },
- { 'start': 1, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }
- ]
- }
- ];
-
- var control = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$2
- });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$3 = [
- {
- 'tfOpName': 'AvgPool',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- },
- { 'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'MaxPool',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- },
- { 'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'MaxPoolWithArgmax',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' },
- { 'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]' }, {
- 'tfName': 'include_batch_in_index',
- 'name': 'includeBatchInIndex',
- 'type': 'bool'
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'AvgPool3D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- },
- { 'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'MaxPool3D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- },
- { 'tfName': 'ksize', 'name': 'kernelSize', 'type': 'number[]' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Conv1D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'stride', 'name': 'stride', 'type': 'number' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NWC'
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'dilation',
- 'name': 'dilation',
- 'type': 'number',
- 'defaultValue': 1
- }
- ]
- },
- {
- 'tfOpName': 'Conv2D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true },
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' },
- { 'tfName': 'useCudnnOnGpu', 'name': 'useCudnnOnGpu', 'type': 'bool' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- {
- 'tfName': 'explicit_paddings',
- 'name': 'explicitPaddings',
- 'type': 'number[]',
- 'defaultValue': []
- },
- { 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': '_FusedConv2D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- { 'start': 2, end: 0, 'name': 'args', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'num_args', 'name': 'numArgs', 'type': 'number' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true },
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' },
- {
- 'tfName': 'explicit_paddings',
- 'name': 'explicitPaddings',
- 'type': 'number[]',
- 'defaultValue': []
- },
- {
- 'tfName': 'use_cudnn_on_gpu',
- 'name': 'useCudnnOnGpu',
- 'type': 'bool',
- 'defaultValue': true
- },
- {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- {
- 'tfName': 'dilations',
- 'name': 'dilations',
- 'type': 'number[]',
- 'defaultValue': [1, 1, 1, 1]
- },
- {
- 'tfName': 'fused_ops',
- 'name': 'fusedOps',
- 'type': 'string[]',
- 'defaultValue': []
- },
- {
- 'tfName': 'epsilon',
- 'name': 'epsilon',
- 'type': 'number',
- 'defaultValue': 0.0001
- },
- ]
- },
- {
- 'tfOpName': 'Conv2DBackpropInput',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 2, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- { 'start': 0, 'name': 'outputShape', 'type': 'number[]' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' },
- {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- },
- {
- 'tfName': 'explicit_paddings',
- 'name': 'explicitPaddings',
- 'type': 'number[]',
- 'defaultValue': []
- },
- ]
- },
- {
- 'tfOpName': 'DepthwiseConv2d',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'input', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- {
- 'tfName': 'explicit_paddings',
- 'name': 'explicitPaddings',
- 'type': 'number[]',
- 'defaultValue': []
- },
- { 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'DepthwiseConv2dNative',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'input', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- {
- 'tfName': 'explicit_paddings',
- 'name': 'explicitPaddings',
- 'type': 'number[]',
- 'defaultValue': []
- },
- { 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'FusedDepthwiseConv2dNative',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- { 'start': 2, end: 0, 'name': 'args', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'num_args', 'name': 'numArgs', 'type': 'number' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true },
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- {
- 'tfName': 'dilations',
- 'name': 'dilations',
- 'type': 'number[]',
- 'defaultValue': [1, 1, 1, 1]
- },
- {
- 'tfName': 'fused_ops',
- 'name': 'fusedOps',
- 'type': 'string[]',
- 'defaultValue': []
- }
- ]
- },
- {
- 'tfOpName': 'Conv3D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'defaultValue': 'NHWC'
- },
- { 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' }
- ],
- },
- {
- 'tfOpName': 'Dilation2D',
- 'category': 'convolution',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'filter', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' },
- { 'tfName': 'rates', 'name': 'dilations', 'type': 'number[]' },
- { 'tfName': 'padding', 'name': 'pad', 'type': 'string' }
- ]
- }
- ];
-
- var convolution = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$3
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$4 = [
- {
- 'tfOpName': 'Fill',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'shape', 'type': 'number[]' },
- { 'start': 1, 'name': 'value', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'LinSpace',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'start', 'type': 'number' },
- { 'start': 1, 'name': 'stop', 'type': 'number' },
- { 'start': 2, 'name': 'num', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'OneHot',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'indices', 'type': 'tensor' },
- { 'start': 1, 'name': 'depth', 'type': 'number' },
- { 'start': 2, 'name': 'onValue', 'type': 'number', 'defaultValue': 1 },
- { 'start': 3, 'name': 'offValue', 'type': 'number', 'defaultValue': 0 },
- ],
- 'attrs': [
- {
- 'tfName': 'axis',
- 'name': 'axis',
- 'type': 'number',
- 'notSupported': true
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Ones',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'shape', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'OnesLike',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [{ 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'RandomUniform',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'shape', 'type': 'number[]' },
- ],
- 'attrs': [
- {
- 'tfName': 'minval',
- 'name': 'minval',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'maxval',
- 'name': 'maxval',
- 'type': 'number',
- 'defaultValue': 1
- },
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' },
- { 'tfName': 'seed', 'name': 'seed', 'type': 'number', 'defaultValue': 0 }, {
- 'tfName': 'seed2',
- 'name': 'seed2',
- 'type': 'number',
- 'defaultValue': 0,
- 'notSupported': true
- },
- { 'tfName': 'T', 'name': 'T', 'type': 'number', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Range',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'start', 'type': 'number' },
- { 'start': 1, 'name': 'stop', 'type': 'number' },
- { 'start': 2, 'name': 'step', 'type': 'number', 'defaultValue': 0 },
- ],
- 'attrs': [{ 'tfName': 'Tidx', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'TruncatedNormal',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'shape', 'type': 'number[]' },
- ],
- 'attrs': [
- {
- 'tfName': 'means',
- 'name': 'mean',
- 'type': 'number',
- 'defaultValue': 0.0
- },
- {
- 'tfName': 'stddev',
- 'name': 'stdDev',
- 'type': 'number',
- 'defaultValue': 1.0
- },
- { 'tfName': 'seed', 'name': 'seed', 'type': 'number' }, {
- 'tfName': 'seed2',
- 'name': 'seed2',
- 'type': 'number',
- 'defaultValue': 0,
- 'notSupported': true
- },
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' },
- { 'tfName': 'T', 'name': 'T', 'type': 'number', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Zeros',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'shape', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'ZerosLike',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' }]
- },
- {
- 'tfOpName': 'Multinomial',
- 'category': 'creation',
- 'inputs': [
- { 'start': 0, 'name': 'logits', 'type': 'tensor' },
- { 'start': 1, 'name': 'numSamples', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'seed', 'name': 'seed', 'type': 'number' },
- { 'tfName': 'seed2', 'name': 'seed2', 'type': 'number' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype' },
- { 'tfName': 'output_dtype', 'name': 'output_dtype', 'type': 'dtype' }
- ]
- }
- ];
-
- var creation = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$4
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$5 = [
- {
- 'tfOpName': 'NonMaxSuppressionV2',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'boxes', 'type': 'tensor' },
- { 'start': 1, 'name': 'scores', 'type': 'tensor' },
- { 'start': 2, 'name': 'maxOutputSize', 'type': 'number' },
- { 'start': 3, 'name': 'iouThreshold', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'NonMaxSuppressionV3',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'boxes', 'type': 'tensor' },
- { 'start': 1, 'name': 'scores', 'type': 'tensor' },
- { 'start': 2, 'name': 'maxOutputSize', 'type': 'number' },
- { 'start': 3, 'name': 'iouThreshold', 'type': 'number' },
- { 'start': 4, 'name': 'scoreThreshold', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'NonMaxSuppressionV4',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'boxes', 'type': 'tensor' },
- { 'start': 1, 'name': 'scores', 'type': 'tensor' },
- { 'start': 2, 'name': 'maxOutputSize', 'type': 'number' },
- { 'start': 3, 'name': 'iouThreshold', 'type': 'number' },
- { 'start': 4, 'name': 'scoreThreshold', 'type': 'number' }
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'T_threshold',
- 'name': 'threshold',
- 'type': 'dtype',
- 'notSupported': true
- },
- {
- 'tfName': 'pad_to_max_output_size',
- 'name': 'padToMaxOutputSize',
- 'type': 'bool'
- }
- ]
- },
- {
- 'tfOpName': 'NonMaxSuppressionV5',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'boxes', 'type': 'tensor' },
- { 'start': 1, 'name': 'scores', 'type': 'tensor' },
- { 'start': 2, 'name': 'maxOutputSize', 'type': 'number' },
- { 'start': 3, 'name': 'iouThreshold', 'type': 'number' },
- { 'start': 4, 'name': 'scoreThreshold', 'type': 'number' },
- { 'start': 5, 'name': 'softNmsSigma', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'Where',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'condition', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'ListDiff',
- 'category': 'dynamic',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'y', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'T',
- 'name': 'dtype',
- 'type': 'dtype',
- 'notSupported': true
- }]
- }
- ];
-
- var dynamic = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$5
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$6 = [
- {
- 'tfOpName': 'TopKV2',
- 'category': 'evaluation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'k', 'type': 'number' },
- ],
- 'attrs': [{ 'tfName': 'sorted', 'name': 'sorted', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Unique',
- 'category': 'evaluation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- },
- {
- 'tfOpName': 'UniqueV2',
- 'category': 'evaluation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number' },
- ],
- },
- ];
-
- var evaluation = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$6
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$7 = [
- {
- 'tfOpName': 'PlaceholderWithDefault',
- 'category': 'graph',
- 'inputs': [
- { 'start': 0, 'name': 'default', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'shape', 'name': 'shape', 'type': 'shape' },
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' }
- ]
- },
- {
- 'tfOpName': 'Placeholder',
- 'category': 'graph',
- 'attrs': [
- { 'tfName': 'shape', 'name': 'shape', 'type': 'shape' },
- { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype' }
- ]
- },
- { 'tfOpName': 'Const', 'category': 'graph' }, {
- 'tfOpName': 'Identity',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'IdentityN',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'end': 0, 'name': 'x', 'type': 'tensors' }]
- },
- {
- 'tfOpName': 'Snapshot',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'Rank',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'Size',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'Shape',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'ShapeN',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'end': 0, 'name': 'x', 'type': 'tensors' }]
- },
- {
- 'tfOpName': 'Print',
- 'category': 'graph',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'data', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'message', 'name': 'message', 'type': 'string' }, {
- 'tfName': 'first_n',
- 'name': 'firstN',
- 'type': 'number',
- 'notSupported': true
- },
- {
- 'tfName': 'summarize',
- 'name': 'summarize',
- 'type': 'number',
- 'defaultValue': 3
- }
- ]
- },
- { 'tfOpName': 'NoOp', 'category': 'graph', 'inputs': [] }, {
- 'tfOpName': 'StopGradient',
- 'category': 'graph',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'FakeQuantWithMinMaxVars',
- 'category': 'graph',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'min', 'name': 'min', 'type': 'number' },
- { 'tfName': 'max', 'name': 'max', 'type': 'number' }
- ]
- }
- ];
-
- var graph = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$7
- });
-
- const json$8 = [
- {
- 'tfOpName': 'HashTable',
- 'category': 'hash_table',
- 'inputs': [],
- 'attrs': [
- { 'tfName': 'shared_name', 'name': 'sharedName', 'type': 'string' },
- {
- 'tfName': 'use_node_name_sharing',
- 'name': 'useNodeNameSharing',
- 'type': 'bool'
- },
- { 'tfName': 'key_dtype', 'name': 'keyDType', 'type': 'dtype' },
- { 'tfName': 'value_dtype', 'name': 'valueDType', 'type': 'dtype' },
- ]
- },
- {
- 'tfOpName': 'HashTableV2',
- 'category': 'hash_table',
- 'inputs': [],
- 'attrs': [
- { 'tfName': 'shared_name', 'name': 'sharedName', 'type': 'string' },
- {
- 'tfName': 'use_node_name_sharing',
- 'name': 'useNodeNameSharing',
- 'type': 'bool'
- },
- { 'tfName': 'key_dtype', 'name': 'keyDType', 'type': 'dtype' },
- { 'tfName': 'value_dtype', 'name': 'valueDType', 'type': 'dtype' },
- ]
- },
- {
- 'tfOpName': 'LookupTableImport',
- 'category': 'hash_table',
- 'inputs': [
- { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
- { 'start': 1, 'name': 'keys', 'type': 'tensor' },
- { 'start': 2, 'name': 'values', 'type': 'tensor' }
- ],
- 'attrs': [
- { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'tOut',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'LookupTableImportV2',
- 'category': 'hash_table',
- 'inputs': [
- { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
- { 'start': 1, 'name': 'keys', 'type': 'tensor' },
- { 'start': 2, 'name': 'values', 'type': 'tensor' }
- ],
- 'attrs': [
- { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'tOut',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'LookupTableFind',
- 'category': 'hash_table',
- 'inputs': [
- { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
- { 'start': 1, 'name': 'keys', 'type': 'tensor' },
- { 'start': 2, 'name': 'defaultValue', 'type': 'tensor' }
- ],
- 'attrs': [
- { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'tOut',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'LookupTableFindV2',
- 'category': 'hash_table',
- 'inputs': [
- { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
- { 'start': 1, 'name': 'keys', 'type': 'tensor' },
- { 'start': 2, 'name': 'defaultValue', 'type': 'tensor' }
- ],
- 'attrs': [
- { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true }, {
- 'tfName': 'Tout',
- 'name': 'tOut',
- 'type': 'dtype',
- 'notSupported': true
- }
- ]
- }
- ];
-
- var hashTable = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$8
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$9 = [
- {
- 'tfOpName': 'ResizeBilinear',
- 'category': 'image',
- 'inputs': [
- { 'start': 0, 'name': 'images', 'type': 'tensor' },
- { 'start': 1, 'name': 'size', 'type': 'number[]' },
- ],
- 'attrs': [
- { 'tfName': 'align_corners', 'name': 'alignCorners', 'type': 'bool' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'ResizeNearestNeighbor',
- 'category': 'image',
- 'inputs': [
- { 'start': 0, 'name': 'images', 'type': 'tensor' },
- { 'start': 1, 'name': 'size', 'type': 'number[]' },
- ],
- 'attrs': [
- { 'tfName': 'align_corners', 'name': 'alignCorners', 'type': 'bool' },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'CropAndResize',
- 'category': 'image',
- 'inputs': [
- { 'start': 0, 'name': 'image', 'type': 'tensor' },
- { 'start': 1, 'name': 'boxes', 'type': 'tensor' },
- { 'start': 2, 'name': 'boxInd', 'type': 'tensor' },
- { 'start': 3, 'name': 'cropSize', 'type': 'number[]' },
- ],
- 'attrs': [
- { 'tfName': 'method', 'name': 'method', 'type': 'string' }, {
- 'tfName': 'extrapolation_value',
- 'name': 'extrapolationValue',
- 'type': 'number'
- }
- ]
- }
- ];
-
- var image$1 = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$9
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$a = [
- {
- 'tfOpName': 'Equal',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'NotEqual',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Greater',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'GreaterEqual',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Less',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'LessEqual',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'LogicalAnd',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'LogicalNot',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'LogicalOr',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Select',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'condition', 'type': 'tensor' },
- { 'start': 1, 'name': 'a', 'type': 'tensor' },
- { 'start': 2, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'SelectV2',
- 'category': 'logical',
- 'inputs': [
- { 'start': 0, 'name': 'condition', 'type': 'tensor' },
- { 'start': 1, 'name': 'a', 'type': 'tensor' },
- { 'start': 2, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'T',
- 'name': 'dtype',
- 'type': 'dtype',
- 'notSupported': true
- }]
- }
- ];
-
- var logical = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$a
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$b = [
- {
- 'tfOpName': '_FusedMatMul',
- 'category': 'matrices',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- { 'start': 2, end: 0, 'name': 'args', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'num_args', 'name': 'numArgs', 'type': 'number' }, {
- 'tfName': 'fused_ops',
- 'name': 'fusedOps',
- 'type': 'string[]',
- 'defaultValue': []
- },
- {
- 'tfName': 'epsilon',
- 'name': 'epsilon',
- 'type': 'number',
- 'defaultValue': 0.0001
- },
- {
- 'tfName': 'transpose_a',
- 'name': 'transposeA',
- 'type': 'bool',
- 'defaultValue': false
- },
- {
- 'tfName': 'transpose_b',
- 'name': 'transposeB',
- 'type': 'bool',
- 'defaultValue': false
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'MatMul',
- 'category': 'matrices',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'transpose_a',
- 'name': 'transposeA',
- 'type': 'bool',
- 'defaultValue': false
- },
- {
- 'tfName': 'transpose_b',
- 'name': 'transposeB',
- 'type': 'bool',
- 'defaultValue': false
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'BatchMatMul',
- 'category': 'matrices',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'adj_x',
- 'name': 'transposeA',
- 'type': 'bool',
- 'defaultValue': false
- },
- {
- 'tfName': 'adj_y',
- 'name': 'transposeB',
- 'type': 'bool',
- 'defaultValue': false
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'BatchMatMulV2',
- 'category': 'matrices',
- 'inputs': [
- { 'start': 0, 'name': 'a', 'type': 'tensor' },
- { 'start': 1, 'name': 'b', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'adj_x',
- 'name': 'transposeA',
- 'type': 'bool',
- 'defaultValue': false
- },
- {
- 'tfName': 'adj_y',
- 'name': 'transposeB',
- 'type': 'bool',
- 'defaultValue': false
- },
- { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'Transpose',
- 'category': 'matrices',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'perm', 'type': 'number[]' },
- ],
- 'attrs': [{
- 'tfName': 'T',
- 'name': 'dtype',
- 'type': 'dtype',
- 'notSupported': true
- }]
- }
- ];
-
- var matrices = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$b
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$c = [
- {
- 'tfOpName': 'FusedBatchNorm',
- 'category': 'normalization',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'scale', 'type': 'tensor' },
- { 'start': 2, 'name': 'offset', 'type': 'tensor' },
- { 'start': 3, 'name': 'mean', 'type': 'tensor' },
- { 'start': 4, 'name': 'variance', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'epsilon',
- 'name': 'epsilon',
- 'type': 'number',
- 'defaultValue': 0.001
- },
- {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'FusedBatchNormV2',
- 'category': 'normalization',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'scale', 'type': 'tensor' },
- { 'start': 2, 'name': 'offset', 'type': 'tensor' },
- { 'start': 3, 'name': 'mean', 'type': 'tensor' },
- { 'start': 4, 'name': 'variance', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'epsilon',
- 'name': 'epsilon',
- 'type': 'number',
- 'defaultValue': 0.001
- },
- {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'FusedBatchNormV3',
- 'category': 'normalization',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'scale', 'type': 'tensor' },
- { 'start': 2, 'name': 'offset', 'type': 'tensor' },
- { 'start': 3, 'name': 'mean', 'type': 'tensor' },
- { 'start': 4, 'name': 'variance', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'epsilon',
- 'name': 'epsilon',
- 'type': 'number',
- 'defaultValue': 0.001
- },
- {
- 'tfName': 'data_format',
- 'name': 'dataFormat',
- 'type': 'string',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'LRN',
- 'category': 'normalization',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'depth_radius',
- 'name': 'radius',
- 'type': 'number',
- 'defaultValue': 5
- },
- { 'tfName': 'bias', 'name': 'bias', 'type': 'number', 'defaultValue': 1.0 },
- {
- 'tfName': 'alpha',
- 'name': 'alpha',
- 'type': 'number',
- 'defaultValue': 1.0
- },
- {
- 'tfName': 'beta',
- 'name': 'beta',
- 'type': 'number',
- 'defaultValue': 0.5
- }
- ]
- },
- {
- 'tfOpName': 'Softmax',
- 'category': 'normalization',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'LogSoftmax',
- 'category': 'normalization',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'SparseToDense',
- 'category': 'normalization',
- 'inputs': [
- { 'start': 0, 'name': 'sparseIndices', 'type': 'tensor' },
- { 'start': 1, 'name': 'outputShape', 'type': 'number[]' },
- { 'start': 2, 'name': 'sparseValues', 'type': 'tensor' },
- { 'start': 3, 'name': 'defaultValue', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'validate_indices',
- 'name': 'validateIndices',
- 'type': 'bool',
- 'defaultValue': true,
- 'notSupported': true
- }]
- }
- ];
-
- var normalization = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$c
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$d = [
- {
- 'tfOpName': 'Max',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Mean',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Min',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Sum',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'All',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Any',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'ArgMax',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'ArgMin',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'Prod',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }]
- },
- {
- 'tfOpName': 'Cumsum',
- 'category': 'reduction',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number' },
- ],
- 'attrs': [
- { 'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool' },
- { 'tfName': 'reverse', 'name': 'reverse', 'type': 'bool' }
- ]
- }
- ];
-
- var reduction = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$d
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$e = [
- {
- 'tfOpName': 'ConcatV2',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'end': -1, 'name': 'tensors', 'type': 'tensors' },
- { 'start': -1, 'name': 'axis', 'type': 'number' }
- ],
- 'attrs': [{ 'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2 }]
- },
- {
- 'tfOpName': 'Concat',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 1, 'end': 0, 'name': 'tensors', 'type': 'tensors' },
- { 'start': 0, 'name': 'axis', 'type': 'number' }
- ],
- 'attrs': [{ 'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2 }]
- },
- {
- 'tfOpName': 'GatherV2',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'tensor' },
- { 'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
- ]
- },
- {
- 'tfOpName': 'Gather',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0 }, {
- 'tfName': 'validate_indices',
- 'name': 'validateIndices',
- 'type': 'bool',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'Reverse',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'dims', 'type': 'bool', 'notSupported': true }
- ]
- },
- {
- 'tfOpName': 'ReverseV2',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'Slice',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'begin', 'type': 'number[]' },
- { 'start': 2, 'name': 'size', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'StridedSlice',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'begin', 'type': 'number[]' },
- { 'start': 2, 'name': 'end', 'type': 'number[]' },
- { 'start': 3, 'name': 'strides', 'type': 'number[]' },
- ],
- 'attrs': [
- {
- 'tfName': 'begin_mask',
- 'name': 'beginMask',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'end_mask',
- 'name': 'endMask',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'new_axis_mask',
- 'name': 'newAxisMask',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'ellipsis_mask',
- 'name': 'ellipsisMask',
- 'type': 'number',
- 'defaultValue': 0
- },
- {
- 'tfName': 'shrink_axis_mask',
- 'name': 'shrinkAxisMask',
- 'type': 'number',
- 'defaultValue': 0
- }
- ]
- },
- {
- 'tfOpName': 'Pack',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors' },
- ],
- 'attrs': [
- { 'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
- ]
- },
- {
- 'tfOpName': 'Unpack',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0 }, {
- 'tfName': 'num',
- 'name': 'num',
- 'type': 'number',
- 'defaultValue': 0,
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'Tile',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'reps', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'Split',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'axis', 'type': 'number', 'defaultValue': 0 },
- { 'start': 1, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'num_split',
- 'name': 'numOrSizeSplits',
- 'type': 'number',
- 'defaultValue': 1
- }]
- },
- {
- 'tfOpName': 'SplitV',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'numOrSizeSplits', 'type': 'number[]' },
- { 'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
- ]
- },
- {
- 'tfOpName': 'ScatterNd',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'indices', 'type': 'tensor' },
- { 'start': 1, 'name': 'values', 'type': 'tensor' },
- { 'start': 2, 'name': 'shape', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'GatherNd',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'indices', 'type': 'tensor' }
- ]
- },
- {
- 'tfOpName': 'SparseToDense',
- 'category': 'slice_join',
- 'inputs': [
- { 'start': 0, 'name': 'sparseIndices', 'type': 'tensor' },
- { 'start': 1, 'name': 'outputShape', 'type': 'number[]' },
- { 'start': 2, 'name': 'sparseValues', 'type': 'tensor' },
- { 'start': 3, 'name': 'defaultValue', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'validate_indices',
- 'name': 'validateIndices',
- 'type': 'bool',
- 'defaultValue': false,
- 'notSupported': true
- }]
- }
- ];
-
- var sliceJoin = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$e
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$f = [
- {
- 'tfOpName': 'FFT',
- 'category': 'spectral',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'IFFT',
- 'category': 'spectral',
- 'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }]
- },
- {
- 'tfOpName': 'RFFT',
- 'category': 'spectral',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' }, {
- 'start': 1,
- 'name': 'fft_length',
- 'type': 'number',
- 'notSupported': true
- }
- ]
- },
- {
- 'tfOpName': 'IRFFT',
- 'category': 'spectral',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' }, {
- 'start': 1,
- 'name': 'fft_length',
- 'type': 'number',
- 'notSupported': true
- }
- ]
- }
- ];
-
- var spectral$1 = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$f
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const json$g = [
- {
- 'tfOpName': 'Cast',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- {
- 'tfName': 'SrcT',
- 'name': 'sdtype',
- 'type': 'dtype',
- 'notSupported': true
- },
- { 'tfName': 'DstT', 'name': 'dtype', 'type': 'dtype' }
- ]
- },
- {
- 'tfOpName': 'ExpandDims',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'axis', 'type': 'number' }
- ]
- },
- {
- 'tfOpName': 'MirrorPad',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'padding', 'type': 'number[]' },
- ],
- 'attrs': [{ 'tfName': 'mode', 'name': 'mode', 'type': 'string' }]
- },
- {
- 'tfOpName': 'Pad',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'padding', 'type': 'number[]' },
- ],
- 'attrs': [{
- 'tfName': 'constant_value',
- 'name': 'constantValue',
- 'type': 'number',
- 'defaultValue': 0
- }]
- },
- {
- 'tfOpName': 'PadV2',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'padding', 'type': 'number[]' }, {
- 'start': 2,
- 'name': 'constantValue',
- 'type': 'number',
- 'defaultValue': 0
- }
- ]
- },
- {
- 'tfOpName': 'Reshape',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'shape', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'Squeeze',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [{
- 'tfName': 'axis',
- 'tfDeprecatedName': 'squeeze_dims',
- 'name': 'axis',
- 'type': 'number[]'
- }]
- },
- {
- 'tfOpName': 'SpaceToBatchND',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'blockShape', 'type': 'number[]' },
- { 'start': 2, 'name': 'paddings', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'BatchToSpaceND',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'blockShape', 'type': 'number[]' },
- { 'start': 2, 'name': 'crops', 'type': 'number[]' }
- ]
- },
- {
- 'tfOpName': 'DepthToSpace',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- ],
- 'attrs': [
- { 'tfName': 'block_size', 'name': 'blockSize', 'type': 'number' },
- { 'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string' }
- ]
- },
- {
- 'tfOpName': 'BroadcastTo',
- 'category': 'transformation',
- 'inputs': [
- { 'start': 0, 'name': 'x', 'type': 'tensor' },
- { 'start': 1, 'name': 'shape', 'type': 'number[]' },
- ],
- 'attrs': []
- }
- ];
-
- var transformation = /*#__PURE__*/Object.freeze({
- __proto__: null,
- json: json$g
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class OperationMapper {
- // Singleton instance for the mapper
- static get Instance() {
- return this._instance || (this._instance = new this());
- }
- // Loads the op mapping from the JSON file.
- constructor() {
- const ops = [
- arithmetic, basicMath, control, convolution, creation, dynamic,
- evaluation, logical, image$1, graph, matrices, normalization, reduction,
- sliceJoin, spectral$1, transformation, hashTable
- ];
- const mappersJson = [].concat(...ops.map(op => op.json));
- this.opMappers = mappersJson.reduce((map, mapper) => {
- map[mapper.tfOpName] = mapper;
- return map;
- }, {});
- }
- // Converts the model inference graph from Tensorflow GraphDef to local
- // representation for TensorFlow.js API
- transformGraph(graph, signature = {}) {
- const tfNodes = graph.node;
- const placeholders = [];
- const weights = [];
- const initNodes = [];
- const nodes = tfNodes.reduce((map, node) => {
- map[node.name] = this.mapNode(node);
- if (node.op.startsWith('Placeholder')) {
- placeholders.push(map[node.name]);
- }
- else if (node.op === 'Const') {
- weights.push(map[node.name]);
- }
- else if (node.input == null || node.input.length === 0) {
- initNodes.push(map[node.name]);
- }
- return map;
- }, {});
- let inputs = [];
- const outputs = [];
- let inputNodeNameToKey = {};
- let outputNodeNameToKey = {};
- if (signature != null) {
- inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
- outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
- }
- const allNodes = Object.keys(nodes);
- allNodes.forEach(key => {
- const node = nodes[key];
- node.inputNames.forEach(name => {
- const [nodeName,] = getNodeNameAndIndex(name);
- node.inputs.push(nodes[nodeName]);
- nodes[nodeName].children.push(node);
- });
- });
- // if signature has not outputs set, add any node that does not have
- // outputs.
- if (Object.keys(outputNodeNameToKey).length === 0) {
- allNodes.forEach(key => {
- const node = nodes[key];
- if (node.children.length === 0) {
- outputs.push(node);
- }
- });
- }
- else {
- Object.keys(outputNodeNameToKey).forEach(name => {
- const [nodeName,] = getNodeNameAndIndex(name);
- const node = nodes[nodeName];
- if (node != null) {
- node.signatureKey = outputNodeNameToKey[name];
- outputs.push(node);
- }
- });
- }
- if (Object.keys(inputNodeNameToKey).length > 0) {
- Object.keys(inputNodeNameToKey).forEach(name => {
- const [nodeName,] = getNodeNameAndIndex(name);
- const node = nodes[nodeName];
- if (node) {
- node.signatureKey = inputNodeNameToKey[name];
- inputs.push(node);
- }
- });
- }
- else {
- inputs = placeholders;
- }
- let functions = {};
- if (graph.library != null && graph.library.function != null) {
- functions = graph.library.function.reduce((functions, func) => {
- functions[func.signature.name] = this.mapFunction(func);
- return functions;
- }, {});
- }
- const result = { nodes, inputs, outputs, weights, placeholders, signature, functions };
- if (initNodes.length > 0) {
- result.initNodes = initNodes;
- }
- return result;
- }
- mapSignatureEntries(entries) {
- return Object.keys(entries || {})
- .reduce((prev, curr) => {
- prev[entries[curr].name] = curr;
- return prev;
- }, {});
- }
- mapNode(node) {
- // Unsupported ops will cause an error at run-time (not parse time), since
- // they may not be used by the actual execution subgraph.
- const mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
- if (node.attr == null) {
- node.attr = {};
- }
- const newNode = {
- name: node.name,
- op: node.op,
- category: mapper.category,
- inputNames: (node.input ||
- []).map(input => input.startsWith('^') ? input.substr(1) : input),
- inputs: [],
- children: [],
- inputParams: {},
- attrParams: {},
- rawAttrs: node.attr
- };
- if (mapper.inputs != null) {
- newNode.inputParams =
- mapper.inputs.reduce((map, param) => {
- map[param.name] = {
- type: param.type,
- inputIndexStart: param.start,
- inputIndexEnd: param.end
- };
- return map;
- }, {});
- }
- if (mapper.attrs != null) {
- newNode.attrParams =
- mapper.attrs.reduce((map, param) => {
- const type = param.type;
- let value = undefined;
- switch (param.type) {
- case 'string':
- value = getStringParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getStringParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'string[]':
- value = getStringArrayParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getStringArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'number':
- value = getNumberParam(node.attr, param.tfName, (param.defaultValue || 0));
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getNumberParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'number[]':
- value = getNumericArrayParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getNumericArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'bool':
- value = getBoolParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getBoolParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'bool[]':
- value = getBoolArrayParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getBoolArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'shape':
- value = getTensorShapeParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getTensorShapeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'shape[]':
- value = getTensorShapeArrayParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getTensorShapeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'dtype':
- value = getDtypeParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getDtypeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'dtype[]':
- value = getDtypeArrayParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getDtypeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'func':
- value = getFuncParam(node.attr, param.tfName, param.defaultValue);
- if (value === undefined && !!param.tfDeprecatedName) {
- value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue);
- }
- break;
- case 'tensor':
- case 'tensors':
- break;
- default:
- throw new Error(`Unsupported param type: ${param.type} for op: ${node.op}`);
- }
- map[param.name] = { value, type };
- return map;
- }, {});
- }
- return newNode;
- }
- // map the TFunctionDef to TFJS graph object
- mapFunction(functionDef) {
- const tfNodes = functionDef.nodeDef;
- const placeholders = [];
- const weights = [];
- let nodes = {};
- if (tfNodes != null) {
- nodes = tfNodes.reduce((map, node) => {
- map[node.name] = this.mapNode(node);
- if (node.op === 'Const') {
- weights.push(map[node.name]);
- }
- return map;
- }, {});
- }
- const inputs = [];
- const outputs = [];
- functionDef.signature.inputArg.forEach(arg => {
- const [nodeName,] = getNodeNameAndIndex(arg.name);
- const node = {
- name: nodeName,
- op: 'Placeholder',
- inputs: [],
- inputNames: [],
- category: 'graph',
- inputParams: {},
- attrParams: { dtype: { value: parseDtypeParam(arg.type), type: 'dtype' } },
- children: []
- };
- node.signatureKey = arg.name;
- inputs.push(node);
- nodes[nodeName] = node;
- });
- const allNodes = Object.keys(nodes);
- allNodes.forEach(key => {
- const node = nodes[key];
- node.inputNames.forEach(name => {
- const [nodeName,] = getNodeNameAndIndex(name);
- node.inputs.push(nodes[nodeName]);
- nodes[nodeName].children.push(node);
- });
- });
- const returnNodeMap = functionDef.ret;
- functionDef.signature.outputArg.forEach(output => {
- const [nodeName, index] = getNodeNameAndIndex(returnNodeMap[output.name]);
- const node = nodes[nodeName];
- if (node != null) {
- node.defaultOutput = index;
- outputs.push(node);
- }
- });
- const signature = this.mapArgsToSignature(functionDef);
- return { nodes, inputs, outputs, weights, placeholders, signature };
- }
- mapArgsToSignature(functionDef) {
- return {
- methodName: functionDef.signature.name,
- inputs: functionDef.signature.inputArg.reduce((map, arg) => {
- map[arg.name] = this.mapArgToTensorInfo(arg);
- return map;
- }, {}),
- outputs: functionDef.signature.outputArg.reduce((map, arg) => {
- map[arg.name] = this.mapArgToTensorInfo(arg, functionDef.ret);
- return map;
- }, {}),
- };
- }
- mapArgToTensorInfo(arg, nameMap) {
- let name = arg.name;
- if (nameMap != null) {
- name = nameMap[name];
- }
- return { name, dtype: arg.type };
- }
- }
- function decodeBase64(text) {
- const global = env().global;
- if (typeof global.atob !== 'undefined') {
- return global.atob(text);
- }
- else if (typeof Buffer !== 'undefined') {
- return new Buffer(text, 'base64').toString();
- }
- else {
- throw new Error('Unable to decode base64 in this environment. ' +
- 'Missing built-in atob() or Buffer()');
- }
- }
- function parseStringParam(s, keepCase) {
- const value = Array.isArray(s) ? String.fromCharCode.apply(null, s) : decodeBase64(s);
- return keepCase ? value : value.toLowerCase();
- }
- function getStringParam(attrs, name, def, keepCase = false) {
- const param = attrs[name];
- if (param != null) {
- return parseStringParam(param.s, keepCase);
- }
- return def;
- }
- function getBoolParam(attrs, name, def) {
- const param = attrs[name];
- return param ? param.b : def;
- }
- function getNumberParam(attrs, name, def) {
- const param = attrs[name] || {};
- const value = param['i'] != null ? param['i'] : (param['f'] != null ? param['f'] : def);
- return (typeof value === 'number') ? value : parseInt(value, 10);
- }
- function parseDtypeParam(value) {
- if (typeof (value) === 'string') {
- // tslint:disable-next-line:no-any
- value = DataType[value];
- }
- switch (value) {
- case DataType.DT_FLOAT:
- return 'float32';
- case DataType.DT_INT32:
- case DataType.DT_INT64:
- case DataType.DT_INT8:
- case DataType.DT_UINT8:
- return 'int32';
- case DataType.DT_BOOL:
- return 'bool';
- case DataType.DT_DOUBLE:
- return 'float32';
- case DataType.DT_STRING:
- return 'string';
- default:
- // Unknown dtype error will happen at runtime (instead of parse time),
- // since these nodes might not be used by the actual subgraph execution.
- return null;
- }
- }
- function getFuncParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.func) {
- return param.func.name;
- }
- return def;
- }
- function getDtypeParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.type) {
- return parseDtypeParam(param.type);
- }
- return def;
- }
- function getDtypeArrayParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.list && param.list.type) {
- return param.list.type.map(v => parseDtypeParam(v));
- }
- return def;
- }
- function parseTensorShapeParam(shape) {
- if (shape.unknownRank) {
- return undefined;
- }
- if (shape.dim != null) {
- return shape.dim.map(dim => (typeof dim.size === 'number') ? dim.size : parseInt(dim.size, 10));
- }
- return [];
- }
- function getTensorShapeParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.shape) {
- return parseTensorShapeParam(param.shape);
- }
- return def;
- }
- function getNumericArrayParam(attrs, name, def) {
- const param = attrs[name];
- if (param) {
- return ((param.list.f && param.list.f.length ? param.list.f :
- param.list.i) ||
- [])
- .map(v => (typeof v === 'number') ? v : parseInt(v, 10));
- }
- return def;
- }
- function getStringArrayParam(attrs, name, def, keepCase = false) {
- const param = attrs[name];
- if (param && param.list && param.list.s) {
- return param.list.s.map((v) => {
- return parseStringParam(v, keepCase);
- });
- }
- return def;
- }
- function getTensorShapeArrayParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.list && param.list.shape) {
- return param.list.shape.map((v) => {
- return parseTensorShapeParam(v);
- });
- }
- return def;
- }
- function getBoolArrayParam(attrs, name, def) {
- const param = attrs[name];
- if (param && param.list && param.list.b) {
- return param.list.b;
- }
- return def;
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Helper class for lookup inputs and params for nodes in the model graph.
- */
- class NodeValueImpl {
- constructor(node, tensorMap, context) {
- this.node = node;
- this.tensorMap = tensorMap;
- this.context = context;
- this.inputs = [];
- this.attrs = {};
- this.inputs = node.inputNames.map(name => this.getInput(name));
- if (node.rawAttrs != null) {
- this.attrs = Object.keys(node.rawAttrs)
- .reduce((attrs, key) => {
- attrs[key] = this.getAttr(key);
- return attrs;
- }, {});
- }
- }
- /**
- * Return the value of the attribute or input param.
- * @param name String: name of attribute or input param.
- */
- getInput(name) {
- return getTensor(name, this.tensorMap, this.context);
- }
- /**
- * Return the value of the attribute or input param.
- * @param name String: name of attribute or input param.
- */
- getAttr(name, defaultValue) {
- const value = this.node.rawAttrs[name];
- if (value.tensor != null) {
- return getTensor(name, this.tensorMap, this.context);
- }
- if (value.i != null || value.f != null) {
- return getNumberParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.s != null) {
- return getStringParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.b != null) {
- return getBoolParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.shape != null) {
- return getTensorShapeParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.type != null) {
- return getDtypeParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.list != null) {
- if (value.list.i != null || value.list.f != null) {
- return getNumericArrayParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.list.s != null) {
- return getStringArrayParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.list.shape != null) {
- return getTensorShapeArrayParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.list.b != null) {
- return getBoolArrayParam(this.node.rawAttrs, name, defaultValue);
- }
- if (value.list.type != null) {
- return getDtypeArrayParam(this.node.rawAttrs, name, defaultValue);
- }
- }
- return defaultValue;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp = (node, tensorMap, context) => {
- switch (node.op) {
- case 'BiasAdd':
- case 'AddV2':
- case 'Add': {
- return [add$1(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'AddN': {
- return [addN(getParamValue('tensors', node, tensorMap, context))];
- }
- case 'FloorMod':
- case 'Mod':
- return [mod(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- case 'Mul':
- return [mul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- case 'RealDiv':
- case 'Div': {
- return [div(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'DivNoNan': {
- return [divNoNan(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'FloorDiv': {
- return [floorDiv(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Sub': {
- return [sub(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Minimum': {
- return [minimum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Maximum': {
- return [maximum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Pow': {
- return [pow(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'SquaredDifference': {
- return [squaredDifference(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY = 'arithmetic';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$1 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Abs':
- case 'ComplexAbs':
- return [abs(getParamValue('x', node, tensorMap, context))];
- case 'Acos':
- return [acos(getParamValue('x', node, tensorMap, context))];
- case 'Acosh':
- return [acosh(getParamValue('x', node, tensorMap, context))];
- case 'Asin':
- return [asin(getParamValue('x', node, tensorMap, context))];
- case 'Asinh':
- return [asinh(getParamValue('x', node, tensorMap, context))];
- case 'Atan':
- return [atan(getParamValue('x', node, tensorMap, context))];
- case 'Atan2':
- return [atan2(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context))];
- case 'Atanh':
- return [atanh(getParamValue('x', node, tensorMap, context))];
- case 'Ceil':
- return [ceil(getParamValue('x', node, tensorMap, context))];
- case 'Complex':
- return [complex(getParamValue('real', node, tensorMap, context), getParamValue('imag', node, tensorMap, context))];
- case 'Cos':
- return [cos(getParamValue('x', node, tensorMap, context))];
- case 'Cosh':
- return [cosh(getParamValue('x', node, tensorMap, context))];
- case 'Elu':
- return [elu(getParamValue('x', node, tensorMap, context))];
- case 'Erf':
- return [erf(getParamValue('x', node, tensorMap, context))];
- case 'Exp':
- return [exp(getParamValue('x', node, tensorMap, context))];
- case 'Expm1': {
- return [expm1(getParamValue('x', node, tensorMap, context))];
- }
- case 'Floor':
- return [floor(getParamValue('x', node, tensorMap, context))];
- case 'Log':
- return [log(getParamValue('x', node, tensorMap, context))];
- case 'Log1p': {
- return [log1p(getParamValue('x', node, tensorMap, context))];
- }
- case 'Imag':
- return [imag(getParamValue('x', node, tensorMap, context))];
- case 'Neg':
- return [neg(getParamValue('x', node, tensorMap, context))];
- case 'Reciprocal': {
- return [reciprocal(getParamValue('x', node, tensorMap, context))];
- }
- case 'Real':
- return [real(getParamValue('x', node, tensorMap, context))];
- case 'Relu':
- return [relu(getParamValue('x', node, tensorMap, context))];
- case 'Round': {
- return [round(getParamValue('x', node, tensorMap, context))];
- }
- case 'Selu':
- return [selu(getParamValue('x', node, tensorMap, context))];
- case 'Sigmoid':
- return [sigmoid(getParamValue('x', node, tensorMap, context))];
- case 'Sin':
- return [sin(getParamValue('x', node, tensorMap, context))];
- case 'Sign': {
- return [sign(getParamValue('x', node, tensorMap, context))];
- }
- case 'Sinh': {
- return [sinh(getParamValue('x', node, tensorMap, context))];
- }
- case 'Softplus': {
- return [softplus(getParamValue('x', node, tensorMap, context))];
- }
- case 'Sqrt': {
- return [sqrt(getParamValue('x', node, tensorMap, context))];
- }
- case 'Square': {
- return [square(getParamValue('x', node, tensorMap, context))];
- }
- case 'Tanh': {
- return [tanh$1(getParamValue('x', node, tensorMap, context))];
- }
- case 'Tan':
- return [tan(getParamValue('x', node, tensorMap, context))];
- case 'Relu6':
- case 'ClipByValue':
- return [clipByValue(getParamValue('x', node, tensorMap, context), getParamValue('clipValueMin', node, tensorMap, context), getParamValue('clipValueMax', node, tensorMap, context))];
- case 'Rsqrt':
- return [rsqrt(getTensor(node.inputNames[0], tensorMap, context))];
- case 'Prod':
- return [prod(getParamValue('x', node, tensorMap, context), getParamValue('axes', node, tensorMap, context))];
- case 'LeakyRelu':
- return [leakyRelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
- case 'Prelu':
- return [prelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$1 = 'basic_math';
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix = '') {
- assert(shapesEqualAllowUndefinedSize(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
- }
- function shapesEqualAllowUndefinedSize(n1, n2) {
- if (n1.length !== n2.length) {
- return false;
- }
- for (let i = 0; i < n1.length; i++) {
- if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) {
- return false;
- }
- }
- return true;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * The TensorArray object keeps an array of Tensors. It
- * allows reading from the array and writing to the array.
- */
- class TensorArray {
- constructor(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
- this.name = name;
- this.dtype = dtype;
- this.maxSize = maxSize;
- this.elementShape = elementShape;
- this.identicalElementShapes = identicalElementShapes;
- this.dynamicSize = dynamicSize;
- this.clearAfterRead = clearAfterRead;
- this.tensors = [];
- this.closed_ = false;
- this.idTensor = scalar(0);
- keep(this.idTensor);
- }
- get id() {
- return this.idTensor.id;
- }
- get closed() {
- return this.closed_;
- }
- /**
- * Dispose the tensors and idTensor and mark the TensoryArray as closed.
- */
- clearAndClose(keepIds) {
- this.tensors.forEach(tensor => {
- if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
- tensor.tensor.dispose();
- }
- });
- this.tensors = [];
- this.closed_ = true;
- this.idTensor.dispose();
- }
- size() {
- return this.tensors.length;
- }
- /**
- * Read the value at location index in the TensorArray.
- * @param index Number the index to read from.
- */
- read(index) {
- if (this.closed_) {
- throw new Error(`TensorArray ${this.name} has already been closed.`);
- }
- if (index < 0 || index >= this.size()) {
- throw new Error(`Tried to read from index ${index}, but array size is: ${this.size()}`);
- }
- const tensorWithState = this.tensors[index];
- if (tensorWithState.cleared) {
- throw new Error(`TensorArray ${this.name}: Could not read index ${index} twice because it was cleared after a previous read ` +
- `(perhaps try setting clear_after_read = false?).`);
- }
- if (this.clearAfterRead) {
- tensorWithState.cleared = true;
- }
- tensorWithState.read = true;
- return tensorWithState.tensor;
- }
- /**
- * Helper method to read multiple tensors from the specified indices.
- */
- readMany(indices) {
- return indices.map(index => this.read(index));
- }
- /**
- * Write value into the index of the TensorArray.
- * @param index number the index to write to.
- * @param tensor
- */
- write(index, tensor) {
- if (this.closed_) {
- throw new Error(`TensorArray ${this.name} has already been closed.`);
- }
- if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
- throw new Error(`Tried to write to index ${index}, but array is not resizeable and size is: ${this.maxSize}`);
- }
- const t = this.tensors[index] || {};
- if (tensor.dtype !== this.dtype) {
- throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index},
- because the value dtype is ${tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
- }
- // Set the shape for the first time write to unknow shape tensor array
- if (this.size() === 0 &&
- (this.elementShape == null || this.elementShape.length === 0)) {
- this.elementShape = tensor.shape;
- }
- assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);
- if (t.read) {
- throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`);
- }
- if (t.written) {
- throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`);
- }
- t.tensor = tensor;
- keep(tensor);
- t.written = true;
- this.tensors[index] = t;
- }
- /**
- * Helper method to write multiple tensors to the specified indices.
- */
- writeMany(indices, tensors) {
- if (indices.length !== tensors.length) {
- throw new Error(`TensorArray ${this.name}: could not write multiple tensors,` +
- `because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`);
- }
- indices.forEach((i, index) => this.write(i, tensors[index]));
- }
- /**
- * Return selected values in the TensorArray as a packed Tensor. All of
- * selected values must have been written and their shapes must all match.
- * @param [indices] number[] Optional. Taking values in [0, max_value). If the
- * TensorArray is not dynamic, max_value=size(). If not specified returns
- * all tensors in the original order.
- * @param [dtype]
- */
- gather(indices, dtype) {
- if (!!dtype && dtype !== this.dtype) {
- throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`);
- }
- if (!indices) {
- indices = [];
- for (let i = 0; i < this.size(); i++) {
- indices.push(i);
- }
- }
- else {
- indices = indices.slice(0, this.size());
- }
- if (indices.length === 0) {
- return tensor([], [0].concat(this.elementShape));
- }
- // Read all the PersistentTensors into a vector to keep track of
- // their memory.
- const tensors = this.readMany(indices);
- assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
- return stack(tensors, 0);
- }
- /**
- * Return the values in the TensorArray as a concatenated Tensor.
- */
- concat(dtype) {
- if (!!dtype && dtype !== this.dtype) {
- throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`);
- }
- if (this.size() === 0) {
- return tensor([], [0].concat(this.elementShape));
- }
- const indices = [];
- for (let i = 0; i < this.size(); i++) {
- indices.push(i);
- }
- // Collect all the tensors from the tensors array.
- const tensors = this.readMany(indices);
- assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
- return concat(tensors, 0);
- }
- /**
- * Scatter the values of a Tensor in specific indices of a TensorArray.
- * @param indices nummber[] values in [0, max_value). If the
- * TensorArray is not dynamic, max_value=size().
- * @param tensor Tensor input tensor.
- */
- scatter(indices, tensor) {
- if (tensor.dtype !== this.dtype) {
- throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
- }
- if (indices.length !== tensor.shape[0]) {
- throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
- }
- const maxIndex = Math.max(...indices);
- if (!this.dynamicSize && maxIndex >= this.maxSize) {
- throw new Error(`Max index must be < array size (${maxIndex} vs. ${this.maxSize})`);
- }
- this.writeMany(indices, unstack(tensor, 0));
- }
- /**
- * Split the values of a Tensor into the TensorArray.
- * @param length number[] with the lengths to use when splitting value along
- * its first dimension.
- * @param tensor Tensor, the tensor to split.
- */
- split(length, tensor) {
- if (tensor.dtype !== this.dtype) {
- throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
- }
- let totalLength = 0;
- const cumulativeLengths = length.map(len => {
- totalLength += len;
- return totalLength;
- });
- if (totalLength !== tensor.shape[0]) {
- throw new Error(`Expected sum of lengths to be equal to
- tensor.shape[0], but sum of lengths is
- ${totalLength}, and tensor's shape is: ${tensor.shape}`);
- }
- if (!this.dynamicSize && length.length !== this.maxSize) {
- throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), ` +
- 'and the TensorArray is not marked as dynamically resizeable');
- }
- const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
- const tensors = [];
- tidy(() => {
- tensor = reshape(tensor, [1, totalLength, elementPerRow]);
- for (let i = 0; i < length.length; ++i) {
- const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
- const indices = [0, previousLength, 0];
- const sizes = [1, length[i], elementPerRow];
- tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape);
- }
- return tensors;
- });
- const indices = [];
- for (let i = 0; i < length.length; i++) {
- indices[i] = i;
- }
- this.writeMany(indices, tensors);
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * TensorList stores a container of `tf.Tensor` objects, which are accessible
- * via tensors field.
- *
- * In order to get a copy of the underlying list, use the copy method:
- * ```
- * TensorList b = a.copy();
- * b.tensors().pushBack(t); // This does not modify a.tensors().
- * ```
- *
- * Note that this is not a deep copy: the memory locations of the underlying
- * tensors will still point to the same locations of the corresponding tensors
- * in the original.
- */
- class TensorList {
- /**
- *
- * @param tensors list of tensors
- * @param elementShape shape of each tensor
- * @param elementDtype data type of each tensor
- * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
- * meaning that the size of `tensors` is unbounded.
- */
- constructor(tensors, elementShape, elementDtype, maxNumElements = -1) {
- this.tensors = tensors;
- this.elementShape = elementShape;
- this.elementDtype = elementDtype;
- if (tensors != null) {
- tensors.forEach(tensor => {
- if (elementDtype !== tensor.dtype) {
- throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor.dtype}`);
- }
- assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
- keep(tensor);
- });
- }
- this.idTensor = scalar(0);
- this.maxNumElements = maxNumElements;
- keep(this.idTensor);
- }
- get id() {
- return this.idTensor.id;
- }
- /**
- * Get a new TensorList containing a copy of the underlying tensor container.
- */
- copy() {
- return new TensorList([...this.tensors], this.elementShape, this.elementDtype);
- }
- /**
- * Dispose the tensors and idTensor and clear the tensor list.
- */
- clearAndClose(keepIds) {
- this.tensors.forEach(tensor => {
- if (keepIds == null || !keepIds.has(tensor.id)) {
- tensor.dispose();
- }
- });
- this.tensors.length = 0;
- this.idTensor.dispose();
- }
- /**
- * The size of the tensors in the tensor list.
- */
- size() {
- return this.tensors.length;
- }
- /**
- * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
- * tf.Tensor.
- * @param elementShape shape of each tensor
- * @param elementDtype data type of each tensor
- * @param numElements the number of elements to stack
- */
- stack(elementShape, elementDtype, numElements = -1) {
- if (elementDtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
- }
- if (numElements !== -1 && this.tensors.length !== numElements) {
- throw new Error(`Operation expected a list with ${numElements} elements but got a list with ${this.tensors.length} elements.`);
- }
- assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
- return tidy(() => {
- const reshapedTensors = this.tensors.map(tensor => reshape(tensor, elementShape));
- return stack(reshapedTensors, 0);
- });
- }
- /**
- * Pop a tensor from the end of the list.
- * @param elementShape shape of the tensor
- * @param elementDtype data type of the tensor
- */
- popBack(elementShape, elementDtype) {
- if (elementDtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
- }
- if (this.size() === 0) {
- throw new Error('Trying to pop from an empty list.');
- }
- const tensor = this.tensors.pop();
- assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, 'TensorList shape mismatch: ');
- return reshape(tensor, elementShape);
- }
- /**
- * Push a tensor to the end of the list.
- * @param tensor Tensor to be pushed.
- */
- pushBack(tensor) {
- if (tensor.dtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
- }
- assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
- if (this.maxNumElements === this.size()) {
- throw new Error(`Trying to push element into a full list.`);
- }
- keep(tensor);
- this.tensors.push(tensor);
- }
- /**
- * Update the size of the list.
- * @param size the new size of the list.
- */
- resize(size) {
- if (size < 0) {
- throw new Error(`TensorListResize expects size to be non-negative. Got: ${size}`);
- }
- if (this.maxNumElements !== -1 && size > this.maxNumElements) {
- throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`);
- }
- this.tensors.length = size;
- }
- /**
- * Retrieve the element at the provided index
- * @param elementShape shape of the tensor
- * @param elementDtype dtype of the tensor
- * @param elementIndex index of the tensor
- */
- getItem(elementIndex, elementShape, elementDtype) {
- if (elementDtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
- }
- if (elementIndex < 0 || elementIndex > this.tensors.length) {
- throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`);
- }
- if (this.tensors[elementIndex] == null) {
- throw new Error(`element at index ${elementIndex} is null.`);
- }
- assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
- return this.tensors[elementIndex];
- }
- /**
- * Set the tensor at the index
- * @param elementIndex index of the tensor
- * @param tensor the tensor to be inserted into the list
- */
- setItem(elementIndex, tensor) {
- if (tensor.dtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
- }
- if (elementIndex < 0 ||
- this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
- throw new Error(`Trying to set element ${elementIndex} in a list with max ${this.maxNumElements} elements.`);
- }
- assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
- keep(tensor);
- this.tensors[elementIndex] = tensor;
- }
- /**
- * Return selected values in the TensorList as a stacked Tensor. All of
- * selected values must have been written and their shapes must all match.
- * @param indices indices of tensors to gather
- * @param elementDtype output tensor dtype
- * @param elementShape output tensor element shape
- */
- gather(indices, elementDtype, elementShape) {
- if (elementDtype !== this.elementDtype) {
- throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
- }
- assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
- // When indices is greater than the size of the list, indices beyond the
- // size of the list are ignored.
- indices = indices.slice(0, this.size());
- if (indices.length === 0) {
- return tensor([], [0].concat(this.elementShape));
- }
- return tidy(() => {
- const tensors = indices.map(i => reshape(this.tensors[i], elementShape));
- return stack(tensors, 0);
- });
- }
- /**
- * Return the values in the TensorList as a concatenated Tensor.
- * @param elementDtype output tensor dtype
- * @param elementShape output tensor element shape
- */
- concat(elementDtype, elementShape) {
- if (!!elementDtype && elementDtype !== this.elementDtype) {
- throw new Error(`TensorList dtype is ${this.elementDtype} but concat requested dtype ${elementDtype}`);
- }
- assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
- if (this.size() === 0) {
- return tensor([], [0].concat(this.elementShape));
- }
- return tidy(() => {
- const tensors = this.tensors.map(t => reshape(t, elementShape));
- return concat(tensors, 0);
- });
- }
- }
- /**
- * Creates a TensorList which, when stacked, has the value of tensor.
- * @param tensor from tensor
- * @param elementShape output tensor element shape
- */
- function fromTensor(tensor, elementShape, elementDtype) {
- const dtype = tensor.dtype;
- if (tensor.shape.length < 1) {
- throw new Error(`Tensor must be at least a vector, but saw shape: ${tensor.shape}`);
- }
- if (tensor.dtype !== elementDtype) {
- throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`);
- }
- const outputShape = tensor.shape.slice(1);
- assertShapesMatchAllowUndefinedSize(outputShape, elementShape, 'TensorList shape mismatch: ');
- const tensorList = unstack(tensor);
- return new TensorList(tensorList, elementShape, dtype);
- }
- /**
- * Return a TensorList of the given size with empty elements.
- * @param elementShape the shape of the future elements of the list
- * @param elementDtype the desired type of elements in the list
- * @param numElements the number of elements to reserve
- */
- function reserve(elementShape, elementDtype, numElements) {
- return new TensorList([], elementShape, elementDtype, numElements);
- }
- /**
- * Put tensors at specific indices of a stacked tensor into a TensorList.
- * @param indices list of indices on how to scatter the tensor.
- * @param tensor input tensor.
- * @param elementShape the shape of the future elements of the list
- * @param numElements the number of elements to scatter
- */
- function scatter(tensor, indices, elementShape, numElements) {
- if (indices.length !== tensor.shape[0]) {
- throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
- }
- const maxIndex = Math.max(...indices);
- if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
- throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`);
- }
- const list = new TensorList([], elementShape, tensor.dtype, numElements);
- const tensors = unstack(tensor, 0);
- indices.forEach((value, index) => {
- list.setItem(value, tensors[index]);
- });
- return list;
- }
- /**
- * Split the values of a Tensor into a TensorList.
- * @param length the lengths to use when splitting value along
- * its first dimension.
- * @param tensor the tensor to split.
- * @param elementShape the shape of the future elements of the list
- */
- function split$3(tensor, length, elementShape) {
- let totalLength = 0;
- const cumulativeLengths = length.map(len => {
- totalLength += len;
- return totalLength;
- });
- if (totalLength !== tensor.shape[0]) {
- throw new Error(`Expected sum of lengths to be equal to
- tensor.shape[0], but sum of lengths is
- ${totalLength}, and tensor's shape is: ${tensor.shape}`);
- }
- const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
- const tensors = tidy(() => {
- const tensors = [];
- tensor = reshape(tensor, [1, totalLength, elementPerRow]);
- for (let i = 0; i < length.length; ++i) {
- const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
- const indices = [0, previousLength, 0];
- const sizes = [1, length[i], elementPerRow];
- tensors[i] = reshape(slice(tensor, indices, sizes), elementShape);
- }
- tensor.dispose();
- return tensors;
- });
- const list = new TensorList([], elementShape, tensor.dtype, length.length);
- for (let i = 0; i < tensors.length; i++) {
- list.setItem(i, tensors[i]);
- }
- return list;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$2 = async (node, tensorMap, context) => {
- switch (node.op) {
- case 'If':
- case 'StatelessIf': {
- const thenFunc = getParamValue('thenBranch', node, tensorMap, context);
- const elseFunc = getParamValue('elseBranch', node, tensorMap, context);
- const cond = getParamValue('cond', node, tensorMap, context);
- const args = getParamValue('args', node, tensorMap, context);
- const condValue = await cond.data();
- if (condValue[0]) {
- return context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
- }
- else {
- return context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
- }
- }
- case 'While':
- case 'StatelessWhile': {
- const bodyFunc = getParamValue('body', node, tensorMap, context);
- const condFunc = getParamValue('cond', node, tensorMap, context);
- const args = getParamValue('args', node, tensorMap, context);
- // Calculate the condition of the loop
- const condResult = (await context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
- const argIds = args.map(tensor => tensor.id);
- let condValue = await condResult[0].data();
- // Dispose the intermediate tensors for condition function
- condResult.forEach(tensor => {
- if (!tensor.kept && argIds.indexOf(tensor.id) === -1) {
- tensor.dispose();
- }
- });
- let result = args;
- while (condValue[0]) {
- // Record the previous result for intermediate tensor tracking
- const origResult = result;
- // Execution the body of the loop
- result = await context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
- const resultIds = result.map(tensor => tensor.id);
- // Dispose the intermediate tensor for body function that is not global
- // kept, not input/output of the body function
- origResult.forEach(tensor => {
- if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
- resultIds.indexOf(tensor.id) === -1) {
- tensor.dispose();
- }
- });
- // Recalcuate the condition of the loop using the latest results.
- const condResult = (await context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap));
- condValue = await condResult[0].data();
- // Dispose the intermediate tensors for condition function
- condResult.forEach(tensor => {
- if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
- resultIds.indexOf(tensor.id) === -1) {
- tensor.dispose();
- }
- });
- }
- return result;
- }
- case 'LoopCond': {
- const pred = getParamValue('pred', node, tensorMap, context);
- return [cloneTensor(pred)];
- }
- case 'Switch': {
- const pred = getParamValue('pred', node, tensorMap, context);
- let data = getParamValue('data', node, tensorMap, context);
- if (!data.kept) {
- data = cloneTensor(data);
- }
- // Outputs nodes :0 => false, :1 => true
- return (await pred.data())[0] ? [undefined, data] : [data, undefined];
- }
- case 'Merge': {
- const inputName = node.inputNames.find(name => getTensor(name, tensorMap, context) !== undefined);
- if (inputName) {
- const data = getTensor(inputName, tensorMap, context);
- return [cloneTensor(data)];
- }
- return undefined;
- }
- case 'Enter': {
- const frameId = getParamValue('frameName', node, tensorMap, context);
- const data = getParamValue('tensor', node, tensorMap, context);
- context.enterFrame(frameId);
- return [cloneTensor(data)];
- }
- case 'Exit': {
- const data = getParamValue('tensor', node, tensorMap, context);
- context.exitFrame();
- return [cloneTensor(data)];
- }
- case 'NextIteration': {
- const data = getParamValue('tensor', node, tensorMap, context);
- context.nextIteration();
- return [cloneTensor(data)];
- }
- case 'TensorArrayV3': {
- const size = getParamValue('size', node, tensorMap, context);
- const dtype = getParamValue('dtype', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const dynamicSize = getParamValue('dynamicSize', node, tensorMap, context);
- const clearAfterRead = getParamValue('clearAfterRead', node, tensorMap, context);
- const identicalElementShapes = getParamValue('identicalElementShapes', node, tensorMap, context);
- const name = getParamValue('name', node, tensorMap, context);
- const tensorArray = new TensorArray(name, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
- context.addTensorArray(tensorArray);
- return [tensorArray.idTensor, scalar(1.0)];
- }
- case 'TensorArrayWriteV3': {
- const id = getParamValue('tensorArrayId', node, tensorMap, context);
- const index = getParamValue('index', node, tensorMap, context);
- const writeTensor = getParamValue('tensor', node, tensorMap, context);
- const writeTensorArray = context.getTensorArray(id.id);
- writeTensorArray.write(index, writeTensor);
- return [writeTensorArray.idTensor];
- }
- case 'TensorArrayReadV3': {
- const readId = getParamValue('tensorArrayId', node, tensorMap, context);
- const readIndex = getParamValue('index', node, tensorMap, context);
- const readTensorArray = context.getTensorArray(readId.id);
- return [readTensorArray.read(readIndex)];
- }
- case 'TensorArrayGatherV3': {
- const gatherId = getParamValue('tensorArrayId', node, tensorMap, context);
- const gatherIndices = getParamValue('indices', node, tensorMap, context);
- const gatherDtype = getParamValue('dtype', node, tensorMap, context);
- const gatherTensorArray = context.getTensorArray(gatherId.id);
- return [gatherTensorArray.gather(gatherIndices, gatherDtype)];
- }
- case 'TensorArrayScatterV3': {
- const scatterId = getParamValue('tensorArrayId', node, tensorMap, context);
- const scatterIndices = getParamValue('indices', node, tensorMap, context);
- const scatterTensor = getParamValue('tensor', node, tensorMap, context);
- const scatterTensorArray = context.getTensorArray(scatterId.id);
- scatterTensorArray.scatter(scatterIndices, scatterTensor);
- return [scatterTensorArray.idTensor];
- }
- case 'TensorArrayConcatV3': {
- const concatId = getParamValue('tensorArrayId', node, tensorMap, context);
- const concatTensorArray = context.getTensorArray(concatId.id);
- const concatDtype = getParamValue('dtype', node, tensorMap, context);
- return [concatTensorArray.concat(concatDtype)];
- }
- case 'TensorArraySplitV3': {
- const splitId = getParamValue('tensorArrayId', node, tensorMap, context);
- const splitTensor = getParamValue('tensor', node, tensorMap, context);
- const lengths = getParamValue('lengths', node, tensorMap, context);
- const splitTensorArray = context.getTensorArray(splitId.id);
- splitTensorArray.split(lengths, splitTensor);
- return [splitTensorArray.idTensor];
- }
- case 'TensorArraySizeV3': {
- const sizeId = getParamValue('tensorArrayId', node, tensorMap, context);
- const sizeTensorArray = context.getTensorArray(sizeId.id);
- return [scalar(sizeTensorArray.size(), 'int32')];
- }
- case 'TensorArrayCloseV3': {
- const closeId = getParamValue('tensorArrayId', node, tensorMap, context);
- const closeTensorArray = context.getTensorArray(closeId.id);
- closeTensorArray.clearAndClose();
- return [closeTensorArray.idTensor];
- }
- case 'TensorListSetItem': {
- const idTensor = getParamValue('tensorListId', node, tensorMap, context);
- const index = getParamValue('index', node, tensorMap, context);
- const writeTensor = getParamValue('tensor', node, tensorMap, context);
- const tensorList = context.getTensorList(idTensor.id);
- tensorList.setItem(index, writeTensor);
- return [tensorList.idTensor];
- }
- case 'TensorListGetItem': {
- const idTensor = getParamValue('tensorListId', node, tensorMap, context);
- const readIndex = getParamValue('index', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDType = getParamValue('elementDType', node, tensorMap, context);
- const tensorList = context.getTensorList(idTensor.id);
- return [tensorList.getItem(readIndex, elementShape, elementDType)];
- }
- case 'TensorListScatterV2':
- case 'TensorListScatter': {
- const scatterIndices = getParamValue('indices', node, tensorMap, context);
- const scatterTensor = getParamValue('tensor', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const numElements = getParamValue('numElements', node, tensorMap, context);
- const tensorList = scatter(scatterTensor, scatterIndices, elementShape, numElements);
- context.addTensorList(tensorList);
- return [tensorList.idTensor];
- }
- case 'TensorListReserve': {
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDtype = getParamValue('elementDType', node, tensorMap, context);
- const numElements = getParamValue('numElements', node, tensorMap, context);
- const tensorList = reserve(elementShape, elementDtype, numElements);
- context.addTensorList(tensorList);
- return [tensorList.idTensor];
- }
- case 'TensorListGather': {
- const gatherId = getParamValue('tensorListId', node, tensorMap, context);
- const gatherIndices = getParamValue('indices', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDtype = getParamValue('elementDType', node, tensorMap, context);
- const tensorList = context.getTensorList(gatherId.id);
- return [tensorList.gather(gatherIndices, elementDtype, elementShape)];
- }
- case 'TensorListStack': {
- const idTensor = getParamValue('tensorListId', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDtype = getParamValue('elementDType', node, tensorMap, context);
- const numElements = getParamValue('numElements', node, tensorMap, context);
- const tensorList = context.getTensorList(idTensor.id);
- return [tensorList.stack(elementShape, elementDtype, numElements)];
- }
- case 'TensorListFromTensor': {
- const tensor = getParamValue('tensor', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDtype = getParamValue('elementDType', node, tensorMap, context);
- const tensorList = fromTensor(tensor, elementShape, elementDtype);
- context.addTensorList(tensorList);
- return [tensorList.idTensor];
- }
- case 'TensorListConcat': {
- const concatId = getParamValue('tensorListId', node, tensorMap, context);
- const tensorList = context.getTensorList(concatId.id);
- const concatDtype = getParamValue('dtype', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- return [tensorList.concat(concatDtype, elementShape)];
- }
- case 'TensorListPushBack': {
- const idTensor = getParamValue('tensorListId', node, tensorMap, context);
- const writeTensor = getParamValue('tensor', node, tensorMap, context);
- const tensorList = context.getTensorList(idTensor.id);
- tensorList.pushBack(writeTensor);
- return [tensorList.idTensor];
- }
- case 'TensorListPopBack': {
- const idTensor = getParamValue('tensorListId', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const elementDType = getParamValue('elementDType', node, tensorMap, context);
- const tensorList = context.getTensorList(idTensor.id);
- return [tensorList.popBack(elementShape, elementDType)];
- }
- case 'TensorListSplit': {
- const splitTensor = getParamValue('tensor', node, tensorMap, context);
- const elementShape = getParamValue('elementShape', node, tensorMap, context);
- const lengths = getParamValue('lengths', node, tensorMap, context);
- const tensorList = split$3(splitTensor, lengths, elementShape);
- context.addTensorList(tensorList);
- return [tensorList.idTensor];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$2 = 'control';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fusedConvAndDepthWiseParams(node, tensorMap, context) {
- const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
- const isBiasAdd = extraOp === 'biasadd';
- const isPrelu = activationFunc === 'prelu';
- const isBatchNorm = extraOp === 'fusedbatchnorm';
- const numArgs = getParamValue('numArgs', node, tensorMap, context);
- if (isBiasAdd) {
- if (isPrelu && numArgs !== 2) {
- throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' +
- 'must have two extra arguments: bias and alpha.');
- }
- if (!isPrelu && numArgs !== 1) {
- throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' +
- 'one extra argument: bias.');
- }
- }
- if (isBatchNorm) {
- throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported.');
- }
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getPadding(node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
- .toUpperCase();
- const dilations = getParamValue('dilations', node, tensorMap, context);
- const [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
- return {
- stride,
- pad,
- dataFormat,
- dilations,
- biasArg,
- preluArg,
- activationFunc
- };
- }
- const executeOp$3 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Conv1D': {
- const stride = getParamValue('stride', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
- .toUpperCase();
- const dilation = getParamValue('dilation', node, tensorMap, context);
- return [conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
- }
- case 'Conv2D': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getPadding(node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
- .toUpperCase();
- const dilations = getParamValue('dilations', node, tensorMap, context);
- return [conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
- }
- case '_FusedConv2D': {
- const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc } = fusedConvAndDepthWiseParams(node, tensorMap, context);
- return [conv2d$1({
- x: getParamValue('x', node, tensorMap, context),
- filter: getParamValue('filter', node, tensorMap, context),
- strides: [stride[1], stride[2]],
- pad: pad,
- dataFormat: dataFormat,
- dilations: [dilations[1], dilations[2]],
- bias: biasArg,
- activation: activationFunc,
- preluActivationWeights: preluArg
- })];
- }
- case 'FusedDepthwiseConv2dNative': {
- const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc } = fusedConvAndDepthWiseParams(node, tensorMap, context);
- return [depthwiseConv2d$1({
- x: getParamValue('x', node, tensorMap, context),
- filter: getParamValue('filter', node, tensorMap, context),
- strides: [stride[1], stride[2]],
- pad: pad,
- dataFormat: dataFormat,
- dilations: [dilations[1], dilations[2]],
- bias: biasArg,
- activation: activationFunc,
- preluActivationWeights: preluArg
- })];
- }
- case 'Conv2DBackpropInput':
- case 'Conv2dTranspose': {
- const shape = getParamValue('outputShape', node, tensorMap, context);
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getPadding(node, tensorMap, context);
- return [conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [stride[1], stride[2]], pad)];
- }
- case 'DepthwiseConv2dNative':
- case 'DepthwiseConv2d': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getPadding(node, tensorMap, context);
- const dilations = getParamValue('dilations', node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
- .toUpperCase();
- return [depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
- }
- case 'Conv3D': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
- .toUpperCase();
- const dilations = getParamValue('dilations', node, tensorMap, context);
- return [conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2], stride[3]], pad, dataFormat, [dilations[1], dilations[2], dilations[3]])];
- }
- case 'AvgPool': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
- return [avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
- }
- case 'MaxPool': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
- return [maxPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
- }
- case 'MaxPoolWithArgmax': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
- const includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
- const { result, indexes } = maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad, includeBatchInIndex);
- return [result, indexes];
- }
- case 'AvgPool3D': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
- return [avgPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
- }
- case 'MaxPool3D': {
- const stride = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
- return [maxPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
- }
- case 'Dilation2D': {
- const strides = getParamValue('strides', node, tensorMap, context);
- const pad = getParamValue('pad', node, tensorMap, context);
- const dilations = getParamValue('dilations', node, tensorMap, context);
- // strides: [1, stride_height, stride_width, 1].
- const strideHeight = strides[1];
- const strideWidth = strides[2];
- // dilations: [1, dilation_height, dilation_width, 1].
- const dilationHeight = dilations[1];
- const dilationWidth = dilations[2];
- return [dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], pad, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$3 = 'convolution';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$4 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Fill': {
- const shape = getParamValue('shape', node, tensorMap, context);
- const dtype = getParamValue('dtype', node, tensorMap, context);
- const value = getParamValue('value', node, tensorMap, context);
- return [fill(shape, value, dtype)];
- }
- case 'LinSpace': {
- const start = getParamValue('start', node, tensorMap, context);
- const stop = getParamValue('stop', node, tensorMap, context);
- const num = getParamValue('num', node, tensorMap, context);
- return [linspace(start, stop, num)];
- }
- case 'Multinomial': {
- const logits = getParamValue('logits', node, tensorMap, context);
- const numSamples = getParamValue('numSamples', node, tensorMap, context);
- const seed = getParamValue('seed', node, tensorMap, context);
- return [multinomial(logits, numSamples, seed)];
- }
- case 'OneHot': {
- const indices = getParamValue('indices', node, tensorMap, context);
- const depth = getParamValue('depth', node, tensorMap, context);
- const onValue = getParamValue('onValue', node, tensorMap, context);
- const offValue = getParamValue('offValue', node, tensorMap, context);
- return [oneHot(indices, depth, onValue, offValue)];
- }
- case 'Ones': {
- return [ones$1(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
- }
- case 'OnesLike': {
- return [onesLike(getParamValue('x', node, tensorMap, context))];
- }
- case 'RandomUniform': {
- return [randomUniform(
- // tslint:disable-next-line:no-any
- getParamValue('shape', node, tensorMap, context), getParamValue('minval', node, tensorMap, context), getParamValue('maxval', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
- }
- case 'Range': {
- const start = getParamValue('start', node, tensorMap, context);
- const stop = getParamValue('stop', node, tensorMap, context);
- const step = getParamValue('step', node, tensorMap, context);
- return [range(start, stop, step, getParamValue('dtype', node, tensorMap, context))];
- }
- case 'TruncatedNormal': {
- const shape = getParamValue('shape', node, tensorMap, context);
- const mean = getParamValue('mean', node, tensorMap, context);
- const stdDev = getParamValue('stdDev', node, tensorMap, context);
- const seed = getParamValue('seed', node, tensorMap, context);
- return [truncatedNormal(shape, mean, stdDev, getParamValue('dtype', node, tensorMap, context), seed)];
- }
- case 'Zeros': {
- return [zeros(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
- }
- case 'ZerosLike': {
- return [zerosLike(getParamValue('x', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$4 = 'creation';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function nmsParams(node, tensorMap, context) {
- const boxes = getParamValue('boxes', node, tensorMap, context);
- const scores = getParamValue('scores', node, tensorMap, context);
- const maxOutputSize = getParamValue('maxOutputSize', node, tensorMap, context);
- const iouThreshold = getParamValue('iouThreshold', node, tensorMap, context);
- const scoreThreshold = getParamValue('scoreThreshold', node, tensorMap, context);
- const softNmsSigma = getParamValue('softNmsSigma', node, tensorMap, context);
- return {
- boxes,
- scores,
- maxOutputSize,
- iouThreshold,
- scoreThreshold,
- softNmsSigma
- };
- }
- const executeOp$5 = async (node, tensorMap, context) => {
- switch (node.op) {
- case 'NonMaxSuppressionV5': {
- const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = nmsParams(node, tensorMap, context);
- const result = await image.nonMaxSuppressionWithScoreAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
- return [result.selectedIndices, result.selectedScores];
- }
- case 'NonMaxSuppressionV4': {
- const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
- const padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context);
- const result = await image.nonMaxSuppressionPaddedAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
- return [result.selectedIndices, result.validOutputs];
- }
- case 'NonMaxSuppressionV3':
- case 'NonMaxSuppressionV2': {
- const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
- return [await image.nonMaxSuppressionAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)];
- }
- case 'Where': {
- const condition = cast(getParamValue('condition', node, tensorMap, context), 'bool');
- const result = [await whereAsync(condition)];
- condition.dispose();
- return result;
- }
- case 'ListDiff': {
- return setdiff1dAsync(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context));
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$5 = 'dynamic';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$6 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'TopKV2': {
- const x = getParamValue('x', node, tensorMap, context);
- const k = getParamValue('k', node, tensorMap, context);
- const sorted = getParamValue('sorted', node, tensorMap, context);
- const result = topk(x, k, sorted);
- return [result.values, result.indices];
- }
- case 'Unique': {
- const x = getParamValue('x', node, tensorMap, context);
- const result = unique(x);
- return [result.values, result.indices];
- }
- case 'UniqueV2': {
- const x = getParamValue('x', node, tensorMap, context);
- const axis = getParamValue('axis', node, tensorMap, context);
- const result = unique(x, axis);
- return [result.values, result.indices];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$6 = 'evaluation';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$7 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Const': {
- return tensorMap[node.name];
- }
- case 'PlaceholderWithDefault':
- const def = getParamValue('default', node, tensorMap, context);
- return [getTensor(node.name, tensorMap, context) || def];
- case 'Placeholder':
- return [getTensor(node.name, tensorMap, context)];
- case 'Identity':
- case 'StopGradient':
- case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
- const data = getParamValue('x', node, tensorMap, context);
- return [cloneTensor(data)];
- }
- case 'IdentityN':
- return getParamValue('x', node, tensorMap, context)
- .map((t) => cloneTensor(t));
- case 'Snapshot':
- const snapshot = getParamValue('x', node, tensorMap, context);
- return [cloneTensor(snapshot)];
- case 'Shape':
- return [tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
- case 'ShapeN':
- return getParamValue('x', node, tensorMap, context)
- .map((t) => tensor1d(t.shape));
- case 'Size':
- return [scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
- case 'Rank':
- return [scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
- case 'NoOp':
- return [scalar(1)];
- case 'Print':
- const input = getParamValue('x', node, tensorMap, context);
- const data = getParamValue('data', node, tensorMap, context);
- const message = getParamValue('message', node, tensorMap, context);
- const summarize = getParamValue('summarize', node, tensorMap, context);
- console.warn('The graph has a tf.print() operation,' +
- 'usually used for debugging, which slows down performance.');
- console.log(message);
- for (let i = 0; i < data.length; i++) {
- console.log(Array.prototype.slice.call(data[i].dataSync())
- .slice(0, summarize));
- }
- return [input];
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$7 = 'graph';
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Hashtable contains a set of tensors, which can be accessed by key.
- */
- class HashTable {
- /**
- * Constructor of HashTable. Creates a hash table.
- *
- * @param keyDType `dtype` of the table keys.
- * @param valueDType `dtype` of the table values.
- */
- constructor(keyDType, valueDType) {
- this.keyDType = keyDType;
- this.valueDType = valueDType;
- this.handle = scalar(0);
- // tslint:disable-next-line: no-any
- this.tensorMap = new Map();
- keep(this.handle);
- }
- get id() {
- return this.handle.id;
- }
- /**
- * Dispose the tensors and handle and clear the hashtable.
- */
- clearAndClose() {
- this.tensorMap.forEach(value => value.dispose());
- this.tensorMap.clear();
- this.handle.dispose();
- }
- /**
- * The number of items in the hash table.
- */
- size() {
- return this.tensorMap.size;
- }
- /**
- * Replaces the contents of the table with the specified keys and values.
- * @param keys Keys to store in the hashtable.
- * @param values Values to store in the hashtable.
- */
- async import(keys, values) {
- this.checkKeyAndValueTensor(keys, values);
- // We only store the primitive values of the keys, this allows lookup
- // to be O(1).
- const $keys = await keys.data();
- // Clear the hashTable before inserting new values.
- this.tensorMap.forEach(value => value.dispose());
- this.tensorMap.clear();
- return tidy(() => {
- const $values = unstack(values);
- const keysLength = $keys.length;
- const valuesLength = $values.length;
- assert(keysLength === valuesLength, () => `The number of elements doesn't match, keys has ` +
- `${keysLength} elements, the values has ${valuesLength} ` +
- `elements.`);
- for (let i = 0; i < keysLength; i++) {
- const key = $keys[i];
- const value = $values[i];
- keep(value);
- this.tensorMap.set(key, value);
- }
- return this.handle;
- });
- }
- /**
- * Looks up keys in a hash table, outputs the corresponding values.
- *
- * Performs batch lookups, for every element in the key tensor, `find`
- * stacks the corresponding value into the return tensor.
- *
- * If an element is not present in the table, the given `defaultValue` is
- * used.
- *
- * @param keys Keys to look up. Must have the same type as the keys of the
- * table.
- * @param defaultValue The scalar `defaultValue` is the value output for keys
- * not present in the table. It must also be of the same type as the
- * table values.
- */
- async find(keys, defaultValue) {
- this.checkKeyAndValueTensor(keys, defaultValue);
- const $keys = await keys.data();
- return tidy(() => {
- const result = [];
- for (let i = 0; i < $keys.length; i++) {
- const key = $keys[i];
- const value = this.findWithDefault(key, defaultValue);
- result.push(value);
- }
- return stack(result);
- });
- }
- // tslint:disable-next-line: no-any
- findWithDefault(key, defaultValue) {
- const result = this.tensorMap.get(key);
- return result != null ? result : defaultValue;
- }
- checkKeyAndValueTensor(key, value) {
- if (key.dtype !== this.keyDType) {
- throw new Error(`Expect key dtype ${this.keyDType}, but got ` +
- `${key.dtype}`);
- }
- if (value.dtype !== this.valueDType) {
- throw new Error(`Expect value dtype ${this.valueDType}, but got ` +
- `${value.dtype}`);
- }
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$8 = async (node, tensorMap, context, resourceManager) => {
- switch (node.op) {
- case 'HashTable':
- case 'HashTableV2': {
- const keyDType = getParamValue('keyDType', node, tensorMap, context);
- const valueDType = getParamValue('valueDType', node, tensorMap, context);
- const hashTable = new HashTable(keyDType, valueDType);
- resourceManager.addHashTable(node.name, hashTable);
- return [hashTable.handle];
- }
- case 'LookupTableImport':
- case 'LookupTableImportV2': {
- const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
- const keys = getParamValue('keys', node, tensorMap, context);
- const values = getParamValue('values', node, tensorMap, context);
- const hashTable = resourceManager.getHashTableById(handle.id);
- return [await hashTable.import(keys, values)];
- }
- case 'LookupTableFind':
- case 'LookupTableFindV2': {
- const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
- const keys = getParamValue('keys', node, tensorMap, context);
- const defaultValue = getParamValue('defaultValue', node, tensorMap, context);
- const hashTable = resourceManager.getHashTableById(handle.id);
- return [await hashTable.find(keys, defaultValue)];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$8 = 'hash_table';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$9 = (node, tensorMap, context) => {
- switch (node.op) {
- case 'ResizeBilinear': {
- const images = getParamValue('images', node, tensorMap, context);
- const size = getParamValue('size', node, tensorMap, context);
- const alignCorners = getParamValue('alignCorners', node, tensorMap, context);
- return [image.resizeBilinear(images, [size[0], size[1]], alignCorners)];
- }
- case 'ResizeNearestNeighbor': {
- const images = getParamValue('images', node, tensorMap, context);
- const size = getParamValue('size', node, tensorMap, context);
- const alignCorners = getParamValue('alignCorners', node, tensorMap, context);
- return [image.resizeNearestNeighbor(images, [size[0], size[1]], alignCorners)];
- }
- case 'CropAndResize': {
- const image$1 = getParamValue('image', node, tensorMap, context);
- const boxes = getParamValue('boxes', node, tensorMap, context);
- const boxInd = getParamValue('boxInd', node, tensorMap, context);
- const cropSize = getParamValue('cropSize', node, tensorMap, context);
- const method = getParamValue('method', node, tensorMap, context);
- const extrapolationValue = getParamValue('extrapolationValue', node, tensorMap, context);
- return [image.cropAndResize(image$1, boxes, boxInd, cropSize, method, extrapolationValue)];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$9 = 'image';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$a = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Equal': {
- return [equal(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'NotEqual': {
- return [notEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Greater': {
- return [greater(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'GreaterEqual': {
- return [greaterEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Less': {
- return [less(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'LessEqual': {
- return [lessEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'LogicalAnd': {
- return [logicalAnd(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'LogicalNot': {
- return [logicalNot(getParamValue('a', node, tensorMap, context))];
- }
- case 'LogicalOr': {
- return [logicalOr(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- case 'Select':
- case 'SelectV2': {
- return [where(getParamValue('condition', node, tensorMap, context), getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$a = 'logical';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$b = (node, tensorMap, context) => {
- switch (node.op) {
- case 'BatchMatMul':
- case 'BatchMatMulV2':
- case 'MatMul':
- return [matMul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context), getParamValue('transposeA', node, tensorMap, context), getParamValue('transposeB', node, tensorMap, context))];
- case 'Transpose':
- return [transpose(getParamValue('x', node, tensorMap, context), getParamValue('perm', node, tensorMap, context))];
- case '_FusedMatMul':
- const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
- const isBiasAdd = extraOp === 'biasadd';
- const isPrelu = activationFunc === 'prelu';
- const numArgs = getParamValue('numArgs', node, tensorMap, context);
- if (isBiasAdd) {
- if (isPrelu && numArgs !== 2) {
- throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' +
- 'extra arguments: bias and alpha.');
- }
- if (!isPrelu && numArgs !== 1) {
- throw new Error('Fused MatMul with BiasAdd must have one extra argument: bias.');
- }
- }
- const [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
- return [matMul$1({
- a: getParamValue('a', node, tensorMap, context),
- b: getParamValue('b', node, tensorMap, context),
- transposeA: getParamValue('transposeA', node, tensorMap, context),
- transposeB: getParamValue('transposeB', node, tensorMap, context),
- bias: biasArg,
- activation: activationFunc,
- preluActivationWeights: preluArg
- })];
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$b = 'matrices';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$c = (node, tensorMap, context) => {
- switch (node.op) {
- case 'FusedBatchNorm':
- case 'FusedBatchNormV2': {
- return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
- }
- case 'FusedBatchNormV3': {
- return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
- }
- case 'LRN': {
- return [localResponseNormalization(getParamValue('x', node, tensorMap, context), getParamValue('radius', node, tensorMap, context), getParamValue('bias', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context), getParamValue('beta', node, tensorMap, context))];
- }
- case 'Softmax': {
- return [softmax(getParamValue('x', node, tensorMap, context))];
- }
- case 'LogSoftmax': {
- return [logSoftmax(getParamValue('x', node, tensorMap, context))];
- }
- case 'SparseToDense': {
- return [sparseToDense(getParamValue('sparseIndices', node, tensorMap, context), getParamValue('outputShape', node, tensorMap, context), getParamValue('sparseValues', node, tensorMap, context), getParamValue('defaultValue', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$c = 'normalization';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$d = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Max': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [max(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'Mean': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [mean(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'Min': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [min(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'Sum': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [sum$1(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'All': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [all(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'Any': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [any(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'ArgMax': {
- const axis = getParamValue('axis', node, tensorMap, context);
- return [argMax(getParamValue('x', node, tensorMap, context), axis)];
- }
- case 'ArgMin': {
- const axis = getParamValue('axis', node, tensorMap, context);
- return [argMin(getParamValue('x', node, tensorMap, context), axis)];
- }
- case 'Prod': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const keepDims = getParamValue('keepDims', node, tensorMap, context);
- return [prod(getParamValue('x', node, tensorMap, context), axis, keepDims)];
- }
- case 'Cumsum': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const exclusive = getParamValue('exclusive', node, tensorMap, context);
- const reverse = getParamValue('reverse', node, tensorMap, context);
- return [cumsum(getParamValue('x', node, tensorMap, context), axis, exclusive, reverse)];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$d = 'reduction';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$e = (node, tensorMap, context) => {
- switch (node.op) {
- case 'ConcatV2':
- case 'Concat': {
- const n = getParamValue('n', node, tensorMap, context);
- const axis = getParamValue('axis', node, tensorMap, context);
- let inputs = getParamValue('tensors', node, tensorMap, context);
- inputs = inputs.slice(0, n);
- return [concat(inputs, axis)];
- }
- case 'GatherV2':
- case 'Gather': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const input = getParamValue('x', node, tensorMap, context);
- const indices = getParamValue('indices', node, tensorMap, context);
- return [gather(input, cast(indices, 'int32'), axis)];
- }
- case 'ReverseV2':
- case 'Reverse': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const input = getParamValue('x', node, tensorMap, context);
- return [reverse(input, axis)];
- }
- case 'Slice': {
- // tslint:disable-next-line:no-any
- const begin = getParamValue('begin', node, tensorMap, context);
- // tslint:disable-next-line:no-any
- const size = getParamValue('size', node, tensorMap, context);
- return [slice(getParamValue('x', node, tensorMap, context), begin, size)];
- }
- case 'StridedSlice': {
- const begin = getParamValue('begin', node, tensorMap, context);
- const end = getParamValue('end', node, tensorMap, context);
- const strides = getParamValue('strides', node, tensorMap, context);
- const beginMask = getParamValue('beginMask', node, tensorMap, context);
- const endMask = getParamValue('endMask', node, tensorMap, context);
- const ellipsisMask = getParamValue('ellipsisMask', node, tensorMap, context);
- const newAxisMask = getParamValue('newAxisMask', node, tensorMap, context);
- const shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context);
- const tensor = getParamValue('x', node, tensorMap, context);
- return [stridedSlice(tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)];
- }
- case 'Pack': {
- return tidy(() => {
- const axis = getParamValue('axis', node, tensorMap, context);
- const tensors = getParamValue('tensors', node, tensorMap, context);
- // Reshape the tensors to the first tensor's shape if they don't
- // match.
- const shape = tensors[0].shape;
- const squeezedShape = squeeze(tensors[0]).shape;
- const mapped = tensors.map(tensor => {
- const sameShape = arraysEqual(tensor.shape, shape);
- if (!sameShape &&
- !arraysEqual(squeeze(tensor).shape, squeezedShape)) {
- throw new Error('the input tensors shape does not match');
- }
- return sameShape ? tensor : reshape(tensor, shape);
- });
- return [stack(mapped, axis)];
- });
- }
- case 'Unpack': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const tensor = getParamValue('tensor', node, tensorMap, context);
- return unstack(tensor, axis);
- }
- case 'Tile': {
- const reps = getParamValue('reps', node, tensorMap, context);
- return [tile(getParamValue('x', node, tensorMap, context), reps)];
- }
- case 'Split':
- case 'SplitV': {
- const axis = getParamValue('axis', node, tensorMap, context);
- const numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context);
- const tensor = getParamValue('x', node, tensorMap, context);
- return split(tensor, numOrSizeSplits, axis);
- }
- case 'ScatterNd': {
- const indices = getParamValue('indices', node, tensorMap, context);
- const values = getParamValue('values', node, tensorMap, context);
- const shape = getParamValue('shape', node, tensorMap, context);
- return [scatterND(indices, values, shape)];
- }
- case 'GatherNd': {
- const x = getParamValue('x', node, tensorMap, context);
- const indices = getParamValue('indices', node, tensorMap, context);
- return [gatherND(x, indices)];
- }
- case 'SparseToDense': {
- const indices = getParamValue('sparseIndices', node, tensorMap, context);
- const shape = getParamValue('outputShape', node, tensorMap, context);
- const sparseValues = getParamValue('sparseValues', node, tensorMap, context);
- const defaultValue = getParamValue('defaultValue', node, tensorMap, context);
- return [sparseToDense(indices, sparseValues, shape, sparseValues.dtype === defaultValue.dtype ?
- defaultValue :
- cast(defaultValue, sparseValues.dtype))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$e = 'slice_join';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$f = (node, tensorMap, context) => {
- switch (node.op) {
- case 'FFT': {
- return [fft(getParamValue('x', node, tensorMap, context))];
- }
- case 'IFFT': {
- return [ifft(getParamValue('x', node, tensorMap, context))];
- }
- case 'RFFT': {
- return [rfft(getParamValue('x', node, tensorMap, context))];
- }
- case 'IRFFT': {
- return [irfft(getParamValue('x', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$f = 'spectral';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const executeOp$g = (node, tensorMap, context) => {
- switch (node.op) {
- case 'Cast': {
- return [cast(getParamValue('x', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
- }
- case 'ExpandDims': {
- const axis = getParamValue('axis', node, tensorMap, context);
- return [expandDims(getParamValue('x', node, tensorMap, context), axis)];
- }
- case 'Squeeze': {
- const axis = getParamValue('axis', node, tensorMap, context);
- return [squeeze(getParamValue('x', node, tensorMap, context), axis)];
- }
- case 'Reshape': {
- return [reshape(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
- }
- case 'MirrorPad': {
- return [mirrorPad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('mode', node, tensorMap, context))];
- }
- case 'PadV2':
- case 'Pad': {
- return [pad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('constantValue', node, tensorMap, context))];
- }
- case 'SpaceToBatchND': {
- const blockShape = getParamValue('blockShape', node, tensorMap, context);
- const paddings = getParamValue('paddings', node, tensorMap, context);
- return [spaceToBatchND(getParamValue('x', node, tensorMap, context), blockShape, paddings)];
- }
- case 'BatchToSpaceND': {
- const blockShape = getParamValue('blockShape', node, tensorMap, context);
- const crops = getParamValue('crops', node, tensorMap, context);
- return [batchToSpaceND(getParamValue('x', node, tensorMap, context), blockShape, crops)];
- }
- case 'DepthToSpace': {
- const blockSize = getParamValue('blockSize', node, tensorMap, context);
- const dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
- return [depthToSpace(getParamValue('x', node, tensorMap, context), blockSize, dataFormat)];
- }
- case 'BroadcastTo': {
- return [broadcastTo(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
- }
- default:
- throw TypeError(`Node type ${node.op} is not implemented`);
- }
- };
- const CATEGORY$g = 'transformation';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Executes the op defined by the node object.
- * @param node
- * @param tensorMap contains tensors for executed nodes and weights
- * @param context contains tensors and information for running the current node.
- * @param resourceManager Optional. Contains global resources of the model.
- */
- function executeOp$h(node, tensorMap, context, resourceManager) {
- const value = ((node, tensorMap, context) => {
- switch (node.category) {
- case 'arithmetic':
- return tidy(() => executeOp(node, tensorMap, context));
- case 'basic_math':
- return tidy(() => executeOp$1(node, tensorMap, context));
- case 'control':
- return executeOp$2(node, tensorMap, context);
- case 'convolution':
- return tidy(() => executeOp$3(node, tensorMap, context));
- case 'creation':
- return tidy(() => executeOp$4(node, tensorMap, context));
- case 'dynamic':
- return executeOp$5(node, tensorMap, context);
- case 'evaluation':
- return tidy(() => executeOp$6(node, tensorMap, context));
- case 'image':
- return tidy(() => executeOp$9(node, tensorMap, context));
- case 'graph':
- return tidy(() => executeOp$7(node, tensorMap, context));
- case 'logical':
- return tidy(() => executeOp$a(node, tensorMap, context));
- case 'matrices':
- return tidy(() => executeOp$b(node, tensorMap, context));
- case 'normalization':
- return tidy(() => executeOp$c(node, tensorMap, context));
- case 'reduction':
- return tidy(() => executeOp$d(node, tensorMap, context));
- case 'slice_join':
- return tidy(() => executeOp$e(node, tensorMap, context));
- case 'spectral':
- return tidy(() => executeOp$f(node, tensorMap, context));
- case 'transformation':
- return tidy(() => executeOp$g(node, tensorMap, context));
- case 'hash_table':
- return executeOp$8(node, tensorMap, context, resourceManager);
- case 'custom':
- const opMapper = getRegisteredOp(node.op);
- if (opMapper && opMapper.customExecutor) {
- return opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
- }
- else {
- throw TypeError(`Custom op ${node.op} is not registered.`);
- }
- default:
- throw TypeError(`Unknown op '${node.op}'. File an issue at ` +
- `https://github.com/tensorflow/tfjs/issues so we can add it` +
- `, or register a custom execution with tf.registerOp()`);
- }
- })(node, tensorMap, context);
- if (value instanceof Promise) {
- return value.then((data) => [].concat(data));
- }
- return [].concat(value);
- }
-
- /**
- * ExecutionContext captures the runtime environment of the node. It keeps
- * track of the current frame and iteration for the control flow ops.
- *
- * For example, typical Dynamic RNN model may contain loops, for which
- * TensorFlow will generate graphs with Enter/Exit nodes to control the
- * current execution frame, and NextIteration Nodes for iteration id increment.
- * For model with branch logic, TensorFLow will generate Switch/Merge ops.
- */
- class ExecutionContext {
- constructor(weightMap = {}, tensorArrayMap = {}, tensorListMap = {}, functionMap = {}) {
- this.weightMap = weightMap;
- this.tensorArrayMap = tensorArrayMap;
- this.tensorListMap = tensorListMap;
- this.functionMap = functionMap;
- this.rootContext = { id: 0, frameName: '', iterationId: 0 };
- this.contexts = [this.rootContext];
- this.lastId = 0;
- this.generateCurrentContextIds();
- }
- newFrame(id, frameName) {
- return { id, frameName, iterationId: 0 };
- }
- /**
- * Set the current context
- * @param contexts: ExecutionContextInfo[] the current path of execution
- * frames
- */
- set currentContext(contexts) {
- if (this.contexts !== contexts) {
- this.contexts = contexts;
- this.generateCurrentContextIds();
- }
- }
- get currentContext() {
- return this.contexts;
- }
- /**
- * Returns the current context in string format.
- */
- get currentContextId() {
- return this._currentContextIds[0];
- }
- /**
- * Returns the current context and all parent contexts in string format.
- * This allow access to the nodes in the current and parent frames.
- */
- get currentContextIds() {
- return this._currentContextIds;
- }
- generateCurrentContextIds() {
- const names = [];
- for (let i = 0; i < this.contexts.length - 1; i++) {
- const contexts = this.contexts.slice(0, this.contexts.length - i);
- names.push(this.contextIdforContexts(contexts));
- }
- names.push('');
- this._currentContextIds = names;
- }
- contextIdforContexts(contexts) {
- return contexts ?
- contexts
- .map(context => (context.id === 0 && context.iterationId === 0) ?
- '' :
- `${context.frameName}-${context.iterationId}`)
- .join('/') :
- '';
- }
- /**
- * Enter a new frame, a new context is pushed on the current context list.
- * @param frameId new frame id
- */
- enterFrame(frameId) {
- if (this.contexts) {
- this.lastId++;
- this.contexts = this.contexts.slice();
- this.contexts.push(this.newFrame(this.lastId, frameId));
- this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
- }
- }
- /**
- * Exit the current frame, the last context is removed from the current
- * context list.
- */
- exitFrame() {
- if (this.contexts && this.contexts.length > 1) {
- this.contexts = this.contexts.slice();
- this.contexts.splice(-1);
- this.currentContextIds.shift();
- }
- else {
- throw new Error('Cannot exit frame, the context is empty');
- }
- }
- /**
- * Enter the next iteration of a loop, the iteration id of last context is
- * increased.
- */
- nextIteration() {
- if (this.contexts && this.contexts.length > 0) {
- this.contexts = this.contexts.slice();
- this.lastId++;
- const context = Object.assign({}, this.contexts[this.contexts.length - 1]);
- context.iterationId += 1;
- context.id = this.lastId;
- this.contexts.splice(-1, 1, context);
- this._currentContextIds.splice(0, 1, this.contextIdforContexts(this.contexts));
- }
- else {
- throw new Error('Cannot increase frame iteration, the context is empty');
- }
- }
- getWeight(name) {
- return this.weightMap[name];
- }
- addTensorArray(tensorArray) {
- this.tensorArrayMap[tensorArray.id] = tensorArray;
- }
- getTensorArray(id) {
- return this.tensorArrayMap[id];
- }
- addTensorList(tensorList) {
- this.tensorListMap[tensorList.id] = tensorList;
- }
- getTensorList(id) {
- return this.tensorListMap[id];
- }
- dispose(keepIds) {
- for (const key in this.tensorArrayMap) {
- this.tensorArrayMap[key].clearAndClose(keepIds);
- }
- for (const key in this.tensorListMap) {
- this.tensorListMap[key].clearAndClose(keepIds);
- }
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Given graph inputs and desired outputs, find the minimal set of nodes
- * to execute in order to compute the outputs. In addition return other useful
- * info such:
- * - Missing inputs needed to compute the output.
- * - Whether the subgraph contains dynamic ops (control flow, dynamic shape).
- * - Alternative inputs in order to avoid async (dynamic op) execution.
- */
- function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
- const usedNodes = new Set();
- const missingInputs = [];
- let dynamicNode = null;
- let syncInputs = null;
- // Start with the outputs, going backwards and find all the nodes that are
- // needed to compute those outputs.
- const seen = new Set();
- const inputNodeNames = Object.keys(inputs).map(name => parseNodeName(name)[0]);
- let initNodeNames = [];
- if (initNodes != null) {
- initNodeNames = initNodes.map(node => parseNodeName(node.name)[0]);
- }
- const frontier = [...outputs];
- while (frontier.length > 0) {
- const node = frontier.pop();
- if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
- if (dynamicNode == null) {
- dynamicNode = node;
- syncInputs = dynamicNode.children.map(child => child.name)
- .filter(name => usedNodes.has(name));
- }
- }
- usedNodes.add(node.name);
- // Weights are dead end since we already have their values.
- if (weightMap[node.name] != null) {
- continue;
- }
- // This node is a dead end since it's one of the user-provided inputs.
- if (inputNodeNames.indexOf(node.name) !== -1) {
- continue;
- }
- // This node is a dead end since it doesn't have any inputs.
- if (initNodeNames.indexOf(node.name) !== -1) {
- continue;
- }
- if (node.inputs.length === 0) {
- missingInputs.push(node.name);
- continue;
- }
- node.inputs.forEach(input => {
- // Don't add to the frontier if it is already there.
- if (seen.has(input.name)) {
- return;
- }
- seen.add(input.name);
- frontier.push(input);
- });
- }
- return { inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs };
- }
- /**
- * Given the execution info, return a list of nodes in topological order that
- * need to be executed to compute the output.
- */
- function getNodesInTopologicalOrder(graph, weightMap, executionInfo) {
- const { usedNodes, inputs } = executionInfo;
- const frontier = [];
- const inputNodes = Object.keys(inputs)
- .map(name => parseNodeName(name)[0])
- .map(name => graph.nodes[name]);
- const initNodes = graph.initNodes;
- inputNodes.forEach(input => {
- if (usedNodes.has(input.name)) {
- frontier.push(input);
- }
- });
- graph.weights.forEach(weight => {
- if (usedNodes.has(weight.name)) {
- frontier.push(weight);
- }
- });
- if (initNodes != null) {
- initNodes.forEach(node => {
- if (usedNodes.has(node.name)) {
- frontier.push(node);
- }
- });
- }
- const seen = new Set();
- const orderedNodes = [];
- while (frontier.length > 0) {
- const node = frontier.pop();
- seen.add(node.name);
- if (!weightMap[node.name]) {
- orderedNodes.push(node);
- }
- node.children.forEach(child => {
- if (!seen.has(child.name) && usedNodes.has(child.name) &&
- child.inputs.every(input => seen.has(input.name))) {
- frontier.push(child);
- }
- });
- }
- return orderedNodes;
- }
- const CONTROL_FLOW_OPS = [
- 'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf',
- 'StatelessWhile', 'if', 'While'
- ];
- const DYNAMIC_SHAPE_OPS = [
- 'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where'
- ];
- const HASH_TABLE_OPS = [
- 'HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2',
- 'LookupTableFind', 'LookupTableFindV2'
- ];
- function isControlFlow(node) {
- return CONTROL_FLOW_OPS.indexOf(node.op) >= 0;
- }
- function isDynamicShape(node) {
- return DYNAMIC_SHAPE_OPS.indexOf(node.op) >= 0;
- }
- function isHashTable(node) {
- return HASH_TABLE_OPS.indexOf(node.op) >= 0;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class GraphExecutor {
- /**
- *
- * @param graph Graph the model or function graph to be executed.
- * @param parent When building function exector you need to set the parent
- * executor. Since the weights and function executor maps are set at parant
- * level, that function executor can access the function maps and weight maps
- * through the parent.
- */
- constructor(graph, parent) {
- this.graph = graph;
- this.parent = parent;
- this.compiledMap = new Map();
- this._weightMap = {};
- this.SEPERATOR = ',';
- this._functions = {};
- this._functionExecutorMap = {};
- this._outputs = graph.outputs;
- this._inputs = graph.inputs;
- this._initNodes = graph.initNodes;
- this._signature = graph.signature;
- this._functions = graph.functions;
- // create sub-graph executors
- if (graph.functions != null) {
- Object.keys(graph.functions).forEach(name => {
- this._functionExecutorMap[name] =
- new GraphExecutor(graph.functions[name], this);
- });
- }
- }
- get weightIds() {
- return this.parent ? this.parent.weightIds : this._weightIds;
- }
- get functionExecutorMap() {
- return this.parent ? this.parent.functionExecutorMap :
- this._functionExecutorMap;
- }
- get weightMap() {
- return this.parent ? this.parent.weightMap : this._weightMap;
- }
- set weightMap(weightMap) {
- const weightIds = Object.keys(weightMap).map(key => weightMap[key].map(tensor => tensor.id));
- this._weightIds = [].concat(...weightIds);
- this._weightMap = weightMap;
- }
- /**
- * Set `ResourceManager` shared by executors of a model.
- * @param resourceManager: `ResourceManager` of the `GraphModel`.
- */
- set resourceManager(resourceManager) {
- this._resourceManager = resourceManager;
- }
- get inputs() {
- return this._inputs.map(node => {
- return {
- name: node.name,
- shape: node.attrParams['shape'] ?
- node.attrParams['shape'].value :
- undefined,
- dtype: node.attrParams['dtype'] ?
- node.attrParams['dtype'].value :
- undefined
- };
- });
- }
- get outputs() {
- return this._outputs.map(node => {
- return {
- name: node.name,
- shape: node.attrParams['shape'] ?
- node.attrParams['shape'].value :
- undefined,
- dtype: node.attrParams['dtype'] ?
- node.attrParams['dtype'].value :
- undefined
- };
- });
- }
- get inputNodes() {
- return this._inputs.map(node => node.signatureKey || node.name);
- }
- get outputNodes() {
- return this._outputs.map((node) => {
- const name = node.signatureKey || node.name;
- return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name;
- });
- }
- get functions() {
- return Object.keys(this._functions).reduce((map, key) => {
- map[key] = this._functions[key].signature;
- return map;
- }, {});
- }
- getCompilationKey(inputs, outputs) {
- const sortedInputs = inputs.map(node => node.name).sort();
- const sortedOutputs = outputs.map(node => node.name).sort();
- return sortedInputs.join(this.SEPERATOR) + '--' +
- sortedOutputs.join(this.SEPERATOR);
- }
- /**
- * Compiles the inference graph and returns the minimal set of nodes that are
- * required for execution, in the correct execution order.
- */
- compile(inputs, outputs) {
- const executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
- const { missingInputs, dynamicNode, syncInputs } = executionInfo;
- if (dynamicNode != null) {
- throw new Error(`This execution contains the node '${dynamicNode.name}', which has ` +
- `the dynamic op '${dynamicNode.op}'. Please use ` +
- `model.executeAsync() instead. Alternatively, to avoid the ` +
- `dynamic ops, specify the inputs [${syncInputs}]`);
- }
- if (missingInputs.length > 0) {
- const outNames = outputs.map(n => n.name);
- const inNames = Object.keys(inputs);
- throw new Error(`Cannot compute the outputs [${outNames}] from the provided inputs ` +
- `[${inNames}]. Missing the following inputs: [${missingInputs}]`);
- }
- return getNodesInTopologicalOrder(this.graph, this.weightMap, executionInfo);
- }
- /**
- * Executes the inference for given input tensors.
- * @param inputs Tensor map for the model inputs, keyed by the input node
- * names.
- * @param outputs Optional. output node name from the Tensorflow model, if
- * no outputs are specified, the default outputs of the model would be used.
- * You can inspect intermediate nodes of the model by adding them to the
- * outputs array.
- */
- execute(inputs, outputs) {
- inputs = this.mapInputs(inputs);
- const names = Object.keys(inputs).sort();
- this.checkInputs(inputs);
- this.checkInputShapeAndType(inputs);
- outputs = this.mapOutputs(outputs);
- this.checkOutputs(outputs);
- const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
- const outputNodeNames = outputs.map(name => parseNodeName(name)[0]);
- let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
- // If no outputs are specified, then use the default outputs of the model.
- if (outputNodes.length === 0) {
- outputNodes = this._outputs;
- }
- const compilationKey = this.getCompilationKey(inputNodes, outputNodes);
- // Do nothing if the compiled graph cache contains the input.
- let orderedNodes = this.compiledMap.get(compilationKey);
- if (orderedNodes == null) {
- orderedNodes = this.compile(inputs, outputNodes);
- this.compiledMap.set(compilationKey, orderedNodes);
- }
- const tensorArrayMap = {};
- const tensorListMap = {};
- return tidy(() => {
- const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap);
- const tensorsMap = { ...this.weightMap };
- Object.keys(inputs).forEach(name => {
- const [nodeName, index] = parseNodeName(name);
- const tensors = [];
- tensors[index] = inputs[name];
- tensorsMap[nodeName] = tensors;
- });
- const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
- const intermediateTensorConsumerCount = {};
- for (let i = 0; i < orderedNodes.length; i++) {
- const node = orderedNodes[i];
- if (!tensorsMap[node.name]) {
- const tensors = executeOp$h(node, tensorsMap, context, this._resourceManager);
- if (tensors instanceof Promise) {
- throw new Error(`The execution of the op '${node.op}' returned a promise. ` +
- `Please use model.executeAsync() instead.`);
- }
- tensorsMap[node.name] = tensors;
- this.checkTensorForDisposal(node.name, node, tensorsMap, context, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount);
- }
- }
- // dispose the context for the root executor
- if (this.parent == null) {
- context.dispose(tensorsToKeep);
- }
- return outputs.map(name => getTensor(name, tensorsMap, context));
- });
- }
- getFrozenTensorIds(tensorMap) {
- const ids = [].concat.apply([], Object.keys(tensorMap)
- .map(key => tensorMap[key])
- .map(tensors => tensors.map(tensor => tensor.id)));
- return new Set(ids);
- }
- checkTensorForDisposal(nodeName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount) {
- // Skip output nodes and any control flow nodes, since its dependency is
- // tricky to track correctly.
- if (node.category === 'control' || outputNames.indexOf(nodeName) !== -1) {
- return;
- }
- tensorMap[nodeName].forEach(tensor => {
- if (tensor != null) {
- intermediateTensorConsumerCount[tensor.id] =
- (intermediateTensorConsumerCount[tensor.id] || 0) +
- node.children.length;
- }
- });
- node.inputs.forEach(input => {
- // Skip any control flow nodes, since its dependency is tricky to track
- // correctly.
- if (input.category !== 'control') {
- const tensors = getTensorsForCurrentContenxt(input.name, tensorMap, context);
- if (tensors != null) {
- tensors.forEach(tensor => {
- if (tensor && !tensorsToKeep.has(tensor.id)) {
- const count = intermediateTensorConsumerCount[tensor.id];
- if (count === 1) {
- tensor.dispose();
- delete intermediateTensorConsumerCount[tensor.id];
- }
- else if (count != null) {
- // only intermediate nodes has count set, inputs and weights are
- // not.
- intermediateTensorConsumerCount[tensor.id]--;
- }
- }
- });
- }
- }
- });
- }
- /**
- * Executes the inference for given input tensors in Async fashion.
- * @param inputs Tensor map for the model inputs, keyed by the input node
- * names.
- * @param outputs output node name from the Tensorflow model, if no outputs
- * are specified, the default outputs of the model would be used. You can
- * inspect intermediate nodes of the model by adding them to the outputs
- * array.
- */
- async executeAsync(inputs, outputs) {
- return this._executeAsync(inputs, outputs);
- }
- /**
- * Executes the inference for given input tensors in Async fashion.
- * @param inputs Tensor map for the model inputs, keyed by the input node
- * names.
- * @param outputs Optional. output node name from the Tensorflow model,
- * if no outputs are specified, the default outputs of the model would be
- * used. You can inspect intermediate nodes of the model by adding them to the
- * outputs array.
- * @param isFunctionExecution Optional. Flag for executing a function.
- * @param tensorArrayMap Optional, global TensorArray map by id. Used for
- * function execution.
- * @param tensorArrayMap Optinal global TensorList map by id. Used for
- * function execution.
- */
- async _executeAsync(inputs, outputs, isFunctionExecution = false, tensorArrayMap = {}, tensorListMap = {}) {
- if (!isFunctionExecution) {
- inputs = this.mapInputs(inputs);
- this.checkInputs(inputs);
- this.checkInputShapeAndType(inputs);
- outputs = this.mapOutputs(outputs);
- this.checkOutputs(outputs);
- }
- const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap);
- // Graph with control flow op requires runtime evaluation of the execution
- // order, while without control flow the execution order is pre-determined
- // in the compile method.
- const tensorMap = await this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution);
- const results = outputs.map(name => getTensor(name, tensorMap, context));
- // dispose all the intermediate tensors
- const outputIds = results.map(t => t.id);
- const inputIds = Object.keys(inputs).map(name => inputs[name].id);
- const keepIds = new Set([...outputIds, ...inputIds, ...this.weightIds]);
- Object.keys(tensorMap).forEach(key => {
- const tensorArray = tensorMap[key];
- tensorArray.forEach(tensor => {
- if (tensor && !tensor.isDisposed && !keepIds.has(tensor.id)) {
- tensor.dispose();
- }
- });
- });
- // dispose the context for the root executor
- if (this.parent == null) {
- context.dispose(keepIds);
- }
- return results;
- }
- async executeFunctionAsync(inputs, tensorArrayMap, tensorListMap) {
- const mappedInputs = inputs.reduce((map, tensor, index) => {
- map[this.inputs[index].name] = tensor;
- return map;
- }, {});
- return this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap);
- }
- /**
- * When there are control flow nodes in the graph, the graph execution use
- * ExecutionContext to keep track of the frames and loop iterators.
- * @param inputs placeholder tensors for the graph.
- * @param context the execution context object for current execution.
- * @param outputNames Optional. output node name from the Tensorflow model,
- * if no outputs are specified, the default outputs of the model would be
- * used. You can inspect intermediate nodes of the model by adding them to the
- * outputs array.
- * @param isFunctionExecution Flag for executing a function.
- */
- async executeWithControlFlow(inputs, context, outputNames, isFunctionExecution) {
- const names = Object.keys(inputs);
- const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
- const outputNodeNames = outputNames.map(name => parseNodeName(name)[0]);
- let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
- // If no outputs are specified, then use the default outputs of the model.
- if (outputNodes.length === 0) {
- outputNodes = this._outputs;
- }
- const { usedNodes, missingInputs, dynamicNode, syncInputs } = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes);
- // First nodes to execute include inputNodes, weights, and initNodes.
- const stack = [
- ...inputNodes, ...this.graph.weights, ...(this._initNodes || [])
- ].map(node => {
- return { node, contexts: context.currentContext };
- });
- const tensorsMap = { ...this.weightMap };
- Object.keys(inputs).forEach(name => {
- const [nodeName, index] = parseNodeName(name);
- const tensors = [];
- tensors[index] = inputs[name];
- tensorsMap[nodeName] = tensors;
- });
- const intermediateTensorConsumerCount = {};
- const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
- const added = {};
- while (stack.length > 0) {
- const promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount, usedNodes);
- await Promise.all(promises);
- }
- if (dynamicNode == null && !isFunctionExecution) {
- console.warn(`This model execution did not contain any nodes with control flow ` +
- `or dynamic output shapes. You can use model.execute() instead.`);
- }
- const missingOutputs = outputNodes
- .filter(node => !isControlFlow(node) &&
- !getTensor(node.name, tensorsMap, context))
- .map(node => node.name);
- if (missingOutputs.length > 0) {
- let alternativeMsg = '';
- if (dynamicNode != null) {
- alternativeMsg =
- `Alternatively, to avoid the dynamic ops, use model.execute() ` +
- `and specify the inputs [${syncInputs}]`;
- }
- throw new Error(`Cannot compute the outputs [${missingOutputs}] from the provided ` +
- `inputs [${names}]. Consider providing the following inputs: ` +
- `[${missingInputs}]. ${alternativeMsg}`);
- }
- return tensorsMap;
- }
- processStack(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNames, intermediateTensorConsumerCount, usedNodes) {
- const promises = [];
- while (stack.length > 0) {
- const item = stack.pop();
- context.currentContext = item.contexts;
- let nodeName = '';
- // The tensor of the Enter op with isConstant set should be set
- // in the parent scope, so it will be available as constant for the
- // whole loop.
- if (item.node.op === 'Enter' &&
- getParamValue('isConstant', item.node, tensorMap, context)) {
- [nodeName] = getNodeNameAndIndex(item.node.name, context);
- }
- // only process nodes that are not in the tensorMap yet, this include
- // inputNodes and internal initNodes.
- if (tensorMap[item.node.name] == null) {
- const tensors = executeOp$h(item.node, tensorMap, context, this._resourceManager);
- if (!nodeName) {
- [nodeName] = getNodeNameAndIndex(item.node.name, context);
- }
- const currentContext = context.currentContext;
- if (tensors instanceof Promise) {
- promises.push(tensors.then(t => {
- tensorMap[nodeName] = t;
- context.currentContext = currentContext;
- this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
- this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
- return t;
- }));
- }
- else {
- tensorMap[nodeName] = tensors;
- this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
- this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
- }
- }
- else {
- this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
- }
- }
- return promises;
- }
- processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
- node.children.forEach((childNode) => {
- const [nodeName,] = getNodeNameAndIndex(childNode.name, context);
- if (added[nodeName] || !usedNodes.has(childNode.name)) {
- return;
- }
- // Merge op can be pushed if any of its inputs has value.
- if (childNode.op === 'Merge') {
- if (childNode.inputNames.some(name => {
- return !!getTensor(name, tensorMap, context);
- })) {
- added[nodeName] = true;
- stack.push({ contexts: context.currentContext, node: childNode });
- }
- }
- else // Otherwise all inputs must to have value.
- if (childNode.inputNames.every(name => {
- return !!getTensor(name, tensorMap, context);
- })) {
- added[nodeName] = true;
- stack.push({ contexts: context.currentContext, node: childNode });
- }
- });
- }
- /**
- * Releases the memory used by the weight tensors.
- */
- dispose() {
- Object.keys(this.weightMap)
- .forEach(key => this.weightMap[key].forEach(tensor => tensor.dispose()));
- }
- checkInputShapeAndType(inputs) {
- Object.keys(inputs).forEach(name => {
- const input = inputs[name];
- const [nodeName,] = parseNodeName(name);
- const node = this.graph.nodes[nodeName];
- if (node.attrParams['shape'] && node.attrParams['shape'].value) {
- const shape = node.attrParams['shape'].value;
- const match = shape.length === input.shape.length &&
- input.shape.every((dim, index) => shape[index] === -1 || shape[index] === dim);
- assert(match, () => `The shape of dict['${node.name}'] provided in ` +
- `model.execute(dict) must be [${shape}], but was ` +
- `[${input.shape}]`);
- }
- if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
- assert(input.dtype === node.attrParams['dtype'].value, () => `The dtype of dict['${node.name}'] provided in ` +
- `model.execute(dict) must be ` +
- `${node.attrParams['dtype'].value}, but was ${input.dtype}`);
- }
- });
- }
- mapInputs(inputs) {
- const result = {};
- for (const inputName in inputs) {
- if (this._signature != null && this._signature.inputs != null &&
- this._signature.inputs[inputName] != null) {
- const tensor = this._signature.inputs[inputName];
- result[tensor.name] = inputs[inputName];
- }
- else {
- result[inputName] = inputs[inputName];
- }
- }
- return result;
- }
- checkInputs(inputs) {
- const notInGraph = Object.keys(inputs).filter(name => {
- const [nodeName] = parseNodeName(name);
- return this.graph.nodes[nodeName] == null;
- });
- if (notInGraph.length > 0) {
- throw new Error(`The dict provided in model.execute(dict) has ` +
- `keys: [${notInGraph}] that are not part of graph`);
- }
- }
- mapOutputs(outputs) {
- return outputs.map(name => {
- if (this._signature != null && this._signature.outputs != null &&
- this._signature.outputs[name] != null) {
- const tensor = this._signature.outputs[name];
- return tensor.name;
- }
- return name;
- }, {});
- }
- checkOutputs(outputs) {
- outputs.forEach(name => {
- const [normalizedName] = parseNodeName(name);
- if (!this.graph.nodes[normalizedName]) {
- throw new Error(`The output '${name}' is not found in the graph`);
- }
- });
- }
- }
-
- /**
- * Contains global resources of a model.
- */
- class ResourceManager {
- constructor(hashTableNameToHandle = {}, hashTableMap = {}) {
- this.hashTableNameToHandle = hashTableNameToHandle;
- this.hashTableMap = hashTableMap;
- }
- /**
- * Register a `HashTable` in the resource manager.
- *
- * The `HashTable` can be retrieved by `resourceManager.getHashTableById`,
- * where id is the table handle tensor's id.
- *
- * @param name Op node name that creates the `HashTable`.
- * @param hashTable The `HashTable` to be added to resource manager.
- */
- addHashTable(name, hashTable) {
- this.hashTableNameToHandle[name] = hashTable.handle;
- this.hashTableMap[hashTable.id] = hashTable;
- }
- /**
- * Get the table handle by node name.
- * @param name Op node name that creates the `HashTable`. This name is also
- * used in the inputs list of lookup and import `HashTable` ops.
- */
- getHashTableHandleByName(name) {
- return this.hashTableNameToHandle[name];
- }
- /**
- * Get the actual `HashTable` by its handle tensor's id.
- * @param id The id of the handle tensor.
- */
- getHashTableById(id) {
- return this.hashTableMap[id];
- }
- /**
- * Dispose `ResourceManager`, including its hashTables and tensors in them.
- */
- dispose() {
- for (const key in this.hashTableMap) {
- this.hashTableMap[key].clearAndClose();
- delete this.hashTableMap[key];
- }
- for (const name in this.hashTableNameToHandle) {
- this.hashTableNameToHandle[name].dispose();
- delete this.hashTableNameToHandle[name];
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
- const DEFAULT_MODEL_NAME = 'model.json';
- /**
- * A `tf.GraphModel` is a directed, acyclic graph built from a
- * SavedModel GraphDef and allows inference execution.
- *
- * A `tf.GraphModel` can only be created by loading from a model converted from
- * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
- * the command line converter tool and loaded via `tf.loadGraphModel`.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- class GraphModel {
- /**
- * @param modelUrl url for the model, or an `io.IOHandler`.
- * @param weightManifestUrl url for the weight file generated by
- * scripts/convert.py script.
- * @param requestOption options for Request, which allows to send credentials
- * and custom headers.
- * @param onProgress Optional, progress callback function, fired periodically
- * before the load is completed.
- */
- constructor(modelUrl, loadOptions = {}) {
- this.modelUrl = modelUrl;
- this.loadOptions = loadOptions;
- this.version = 'n/a';
- if (loadOptions == null) {
- this.loadOptions = {};
- }
- this.resourceManager = new ResourceManager();
- }
- // Returns the version information for the tensorflow model GraphDef.
- get modelVersion() {
- return this.version;
- }
- get inputNodes() {
- return this.executor.inputNodes;
- }
- get outputNodes() {
- return this.executor.outputNodes;
- }
- get inputs() {
- return this.executor.inputs;
- }
- get outputs() {
- return this.executor.outputs;
- }
- get weights() {
- return this.executor.weightMap;
- }
- findIOHandler() {
- const path = this.modelUrl;
- if (path.load != null) {
- // Path is an IO Handler.
- this.handler = path;
- }
- else if (this.loadOptions.requestInit != null) {
- this.handler = browserHTTPRequest(path, this.loadOptions);
- }
- else {
- const handlers = getLoadHandlers(path, this.loadOptions);
- if (handlers.length === 0) {
- // For backward compatibility: if no load handler can be found,
- // assume it is a relative http path.
- handlers.push(browserHTTPRequest(path, this.loadOptions));
- }
- else if (handlers.length > 1) {
- throw new Error(`Found more than one (${handlers.length}) load handlers for ` +
- `URL '${[path]}'`);
- }
- this.handler = handlers[0];
- }
- }
- /**
- * Loads the model and weight files, construct the in memory weight map and
- * compile the inference graph.
- */
- async load() {
- this.findIOHandler();
- if (this.handler.load == null) {
- throw new Error('Cannot proceed with model loading because the IOHandler provided ' +
- 'does not have the `load` method implemented.');
- }
- const artifacts = await this.handler.load();
- return this.loadSync(artifacts);
- }
- /**
- * Synchronously construct the in memory weight map and
- * compile the inference graph. Also initialize hashtable if any.
- *
- * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
- */
- loadSync(artifacts) {
- this.artifacts = artifacts;
- const graph = this.artifacts.modelTopology;
- let signature = {};
- if (this.artifacts.userDefinedMetadata != null) {
- signature = // tslint:disable-next-line:no-any
- this.artifacts.userDefinedMetadata.signature;
- }
- this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
- const weightMap = decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
- this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, signature));
- this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
- // Attach a model-level resourceManager to each executor to share resources,
- // such as `HashTable`.
- this.executor.resourceManager = this.resourceManager;
- if (artifacts.modelInitializer != null) {
- const initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
- this.initializer = new GraphExecutor(initializer);
- this.initializer.weightMap = this.executor.weightMap;
- // Attach a model-level resourceManager to the initializer, the
- // hashTables created from when executing the initializer will be stored
- // in the resourceManager.
- this.initializer.resourceManager = this.resourceManager;
- this.initializer.executeAsync({}, []);
- }
- return true;
- }
- /**
- * Save the configuration and/or weights of the GraphModel.
- *
- * An `IOHandler` is an object that has a `save` method of the proper
- * signature defined. The `save` method manages the storing or
- * transmission of serialized data ("artifacts") that represent the
- * model's topology and weights onto or via a specific medium, such as
- * file downloads, local storage, IndexedDB in the web browser and HTTP
- * requests to a server. TensorFlow.js provides `IOHandler`
- * implementations for a number of frequently used saving mediums, such as
- * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
- * for more details.
- *
- * This method also allows you to refer to certain types of `IOHandler`s
- * as URL-like string shortcuts, such as 'localstorage://' and
- * 'indexeddb://'.
- *
- * Example 1: Save `model`'s topology and weights to browser [local
- * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
- * then load it back.
- *
- * ```js
- * const modelUrl =
- * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
- * const model = await tf.loadGraphModel(modelUrl);
- * const zeros = tf.zeros([1, 224, 224, 3]);
- * model.predict(zeros).print();
- *
- * const saveResults = await model.save('localstorage://my-model-1');
- *
- * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
- * console.log('Prediction from loaded model:');
- * model.predict(zeros).print();
- * ```
- *
- * @param handlerOrURL An instance of `IOHandler` or a URL-like,
- * scheme-based string shortcut for `IOHandler`.
- * @param config Options for saving the model.
- * @returns A `Promise` of `SaveResult`, which summarizes the result of
- * the saving, such as byte sizes of the saved artifacts for the model's
- * topology and weight values.
- *
- * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
- */
- async save(handlerOrURL, config) {
- if (typeof handlerOrURL === 'string') {
- const handlers = getSaveHandlers(handlerOrURL);
- if (handlers.length === 0) {
- throw new Error(`Cannot find any save handlers for URL '${handlerOrURL}'`);
- }
- else if (handlers.length > 1) {
- throw new Error(`Found more than one (${handlers.length}) save handlers for ` +
- `URL '${handlerOrURL}'`);
- }
- handlerOrURL = handlers[0];
- }
- if (handlerOrURL.save == null) {
- throw new Error('GraphModel.save() cannot proceed because the IOHandler ' +
- 'provided does not have the `save` attribute defined.');
- }
- return handlerOrURL.save(this.artifacts);
- }
- /**
- * Execute the inference for the input tensors.
- *
- * @param input The input tensors, when there is single input for the model,
- * inputs param should be a `tf.Tensor`. For models with mutliple inputs,
- * inputs params should be in either `tf.Tensor`[] if the input order is
- * fixed, or otherwise NamedTensorMap format.
- *
- * For model with multiple inputs, we recommend you use NamedTensorMap as the
- * input type, if you use `tf.Tensor`[], the order of the array needs to
- * follow the
- * order of inputNodes array. @see {@link GraphModel.inputNodes}
- *
- * You can also feed any intermediate nodes using the NamedTensorMap as the
- * input type. For example, given the graph
- * InputNode => Intermediate => OutputNode,
- * you can execute the subgraph Intermediate => OutputNode by calling
- * model.execute('IntermediateNode' : tf.tensor(...));
- *
- * This is useful for models that uses tf.dynamic_rnn, where the intermediate
- * state needs to be fed manually.
- *
- * For batch inference execution, the tensors for each input need to be
- * concatenated together. For example with mobilenet, the required input shape
- * is [1, 244, 244, 3], which represents the [batch, height, width, channel].
- * If we are provide a batched data of 100 images, the input tensor should be
- * in the shape of [100, 244, 244, 3].
- *
- * @param config Prediction configuration for specifying the batch size and
- * output node names. Currently the batch size option is ignored for graph
- * model.
- *
- * @returns Inference result tensors. The output would be single `tf.Tensor`
- * if model has single output node, otherwise Tensor[] or NamedTensorMap[]
- * will be returned for model with multiple outputs.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- predict(inputs, config) {
- return this.execute(inputs, this.outputNodes);
- }
- normalizeInputs(inputs) {
- if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
- // The input is already a NamedTensorMap.
- return inputs;
- }
- inputs = Array.isArray(inputs) ? inputs : [inputs];
- if (inputs.length !== this.inputNodes.length) {
- throw new Error('Input tensor count mismatch,' +
- `the graph model has ${this.inputNodes.length} placeholders, ` +
- `while there are ${inputs.length} input tensors.`);
- }
- return this.inputNodes.reduce((map, inputName, i) => {
- map[inputName] = inputs[i];
- return map;
- }, {});
- }
- normalizeOutputs(outputs) {
- outputs = outputs || this.outputNodes;
- return !Array.isArray(outputs) ? [outputs] : outputs;
- }
- /**
- * Executes inference for the model for given input tensors.
- * @param inputs tensor, tensor array or tensor map of the inputs for the
- * model, keyed by the input node names.
- * @param outputs output node name from the Tensorflow model, if no
- * outputs are specified, the default outputs of the model would be used.
- * You can inspect intermediate nodes of the model by adding them to the
- * outputs array.
- *
- * @returns A single tensor if provided with a single output or no outputs
- * are provided and there is only one default output, otherwise return a
- * tensor array. The order of the tensor array is the same as the outputs
- * if provided, otherwise the order of outputNodes attribute of the model.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- execute(inputs, outputs) {
- inputs = this.normalizeInputs(inputs);
- outputs = this.normalizeOutputs(outputs);
- const result = this.executor.execute(inputs, outputs);
- return result.length > 1 ? result : result[0];
- }
- /**
- * Executes inference for the model for given input tensors in async
- * fashion, use this method when your model contains control flow ops.
- * @param inputs tensor, tensor array or tensor map of the inputs for the
- * model, keyed by the input node names.
- * @param outputs output node name from the Tensorflow model, if no outputs
- * are specified, the default outputs of the model would be used. You can
- * inspect intermediate nodes of the model by adding them to the outputs
- * array.
- *
- * @returns A Promise of single tensor if provided with a single output or
- * no outputs are provided and there is only one default output, otherwise
- * return a tensor map.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- async executeAsync(inputs, outputs) {
- inputs = this.normalizeInputs(inputs);
- outputs = this.normalizeOutputs(outputs);
- const result = await this.executor.executeAsync(inputs, outputs);
- return result.length > 1 ? result : result[0];
- }
- convertTensorMapToTensorsMap(map) {
- return Object.keys(map).reduce((newMap, key) => {
- newMap[key] = [map[key]];
- return newMap;
- }, {});
- }
- /**
- * Releases the memory used by the weight tensors and resourceManager.
- *
- * @doc {heading: 'Models', subheading: 'Classes'}
- */
- dispose() {
- this.executor.dispose();
- if (this.initializer) {
- this.initializer.dispose();
- }
- this.resourceManager.dispose();
- }
- }
- /**
- * Load a graph model given a URL to the model definition.
- *
- * Example of loading MobileNetV2 from a URL and making a prediction with a
- * zeros input:
- *
- * ```js
- * const modelUrl =
- * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
- * const model = await tf.loadGraphModel(modelUrl);
- * const zeros = tf.zeros([1, 224, 224, 3]);
- * model.predict(zeros).print();
- * ```
- *
- * Example of loading MobileNetV2 from a TF Hub URL and making a prediction with
- * a zeros input:
- *
- * ```js
- * const modelUrl =
- * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
- * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
- * const zeros = tf.zeros([1, 224, 224, 3]);
- * model.predict(zeros).print();
- * ```
- * @param modelUrl The url or an `io.IOHandler` that loads the model.
- * @param options Options for the HTTP request, which allows to send credentials
- * and custom headers.
- *
- * @doc {heading: 'Models', subheading: 'Loading'}
- */
- async function loadGraphModel(modelUrl, options = {}) {
- if (modelUrl == null) {
- throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
- 'or an IOHandler that loads the model');
- }
- if (options == null) {
- options = {};
- }
- if (options.fromTFHub) {
- if (modelUrl.load == null) {
- if (!modelUrl.endsWith('/')) {
- modelUrl = modelUrl + '/';
- }
- modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
- }
- }
- const model = new GraphModel(modelUrl, options);
- await model.load();
- return model;
- }
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$2 = '0.0.0';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * Apply a mapping function to a nested structure in a recursive manner.
- *
- * The result of the mapping is an object with the same nested structure (i.e.,
- * of arrays and dicts) as the input, except that some subtrees are replaced,
- * according to the results of the mapping function.
- *
- * Mappings are memoized. Thus, if the nested structure contains the same
- * object in multiple positions, the output will contain the same mapped object
- * in those positions. Cycles are not supported, however.
- *
- * @param input: The object to which to apply the mapping function.
- * @param mapFn: A function that expects a single node of the object tree, and
- * returns a `DeepMapResult`. The `DeepMapResult` either provides a
- * replacement value for that node (i.e., replacing the subtree), or indicates
- * that the node should be processed recursively.
- */
- function deepMap(input, mapFn) {
- return deepMapInternal(input, mapFn);
- }
- /**
- * @param seen: A Map of known object mappings (i.e., memoized results of
- * `mapFn()`)
- * @param containedIn: An set containing objects on the reference path currently
- * being processed (used to detect cycles).
- */
- function deepMapInternal(input, mapFn, seen = new Map(), containedIn = new Set()) {
- if (input == null) {
- return null;
- }
- if (containedIn.has(input)) {
- throw new Error('Circular references are not supported.');
- }
- if (seen.has(input)) {
- return seen.get(input);
- }
- const result = mapFn(input);
- if (result.recurse && result.value !== null) {
- throw new Error('A deep map function may not return both a value and recurse=true.');
- }
- if (!result.recurse) {
- seen.set(input, result.value);
- return result.value;
- }
- else if (isIterable$1(input)) {
- // tslint:disable-next-line:no-any
- const mappedIterable = Array.isArray(input) ? [] : {};
- containedIn.add(input);
- for (const k in input) {
- const child = input[k];
- const childResult = deepMapInternal(child, mapFn, seen, containedIn);
- mappedIterable[k] = childResult;
- }
- containedIn.delete(input);
- return mappedIterable;
- }
- else {
- throw new Error(`Can't recurse into non-iterable type: ${input}`);
- }
- }
- // TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
- // with zip()
- /**
- * Zip nested structures together in a recursive manner.
- *
- * This has the effect of transposing or pivoting data, e.g. converting it from
- * a row-major representation to a column-major representation.
- *
- * For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
- * `{a: [1, 3], b: [2, 4]}`.
- *
- * The inputs should all have the same nested structure (i.e., of arrays and
- * dicts). The result is a single object with the same nested structure, where
- * the leaves are arrays collecting the values of the inputs at that location
- * (or, optionally, the result of a custom function applied to those arrays).
- *
- * @param inputs: An array of the objects to zip together.
- * @param zipFn: (optional) A function that expects an array of elements at a
- * single node of the object tree, and returns a `DeepMapResult`. The
- * `DeepMapResult` either provides a result value for that node (i.e.,
- * representing the subtree), or indicates that the node should be processed
- * recursively. The default zipFn recurses as far as possible and places
- * arrays at the leaves.
- */
- function deepZip(inputs, zipFn = zipToList) {
- return deepZipInternal(inputs, zipFn);
- }
- /**
- * @param containedIn: An set containing objects on the reference path currently
- * being processed (used to detect cycles).
- */
- function deepZipInternal(inputs, zipFn, containedIn = new Set()) {
- // The recursion follows the structure of input 0; it's assumed that all the
- // other inputs have the same structure.
- const input = inputs[0];
- if (containedIn.has(input)) {
- throw new Error('Circular references are not supported.');
- }
- const result = zipFn(inputs);
- if (result.recurse && result.value !== null) {
- throw new Error('A deep zip function may not return both a value and recurse=true.');
- }
- if (!result.recurse) {
- return result.value;
- }
- else if (isIterable$1(input)) {
- // tslint:disable-next-line:no-any
- const mappedIterable = Array.isArray(input) ? [] : {};
- containedIn.add(input);
- for (const k in input) {
- const children = inputs.map(x => x[k]);
- const childResult = deepZipInternal(children, zipFn, containedIn);
- mappedIterable[k] = childResult;
- }
- containedIn.delete(input);
- return mappedIterable;
- }
- else {
- throw new Error(`Can't recurse into non-iterable type: ${input}`);
- }
- }
- // tslint:disable-next-line:no-any
- function zipToList(x) {
- if (x === null) {
- return null;
- }
- // TODO(soergel): validate array type?
- if (isIterable$1(x[0])) {
- return { value: null, recurse: true };
- }
- else {
- return { value: x, recurse: false };
- }
- }
- /**
- * Apply an async mapping function to a nested structure in a recursive manner.
- *
- * This first creates a nested structure of Promises, and then awaits all of
- * those, resulting in a single Promise for a resolved nested structure.
- *
- * The result of the mapping is an object with the same nested structure (i.e.,
- * of arrays and dicts) as the input, except that some subtrees are replaced,
- * according to the results of the mapping function.
- *
- * Mappings are memoized. Thus, if the nested structure contains the same
- * object in multiple positions, the output will contain the same mapped object
- * in those positions. Cycles are not supported, however.
- *
- * @param input: The object to which to apply the mapping function.
- * @param mapFn: A function that expects a single node of the object tree, and
- * returns a `DeepMapAsyncResult`. The `DeepMapAsyncResult` either provides
- * a `Promise` for a replacement value for that node (i.e., replacing the
- * subtree), or indicates that the node should be processed recursively. Note
- * that the decision whether or not to recurse must be made immediately; only
- * the mapped value may be promised.
- */
- async function deepMapAndAwaitAll(input, mapFn) {
- const seen = new Map();
- // First do a normal deepMap, collecting Promises in 'seen' as a side effect.
- deepMapInternal(input, mapFn, seen);
- // Replace the Promises in 'seen' in place.
- // Note TypeScript provides no async map iteration, and regular map iteration
- // is broken too, so sadly we have to do Array.from() to make it work.
- // (There's no advantage to Promise.all(), and that would be tricky anyway.)
- for (const key of Array.from(seen.keys())) {
- const value = seen.get(key);
- if (value instanceof Promise) {
- const mappedValue = await value;
- seen.set(key, mappedValue);
- }
- }
- // Normal deepMap again, this time filling in the resolved values.
- // It's unfortunate that we have to do two passes.
- // TODO(soergel): test performance and think harder about a fast solution.
- const result = deepMapInternal(input, mapFn, seen);
- return result;
- }
- /**
- * Determine whether the argument is iterable.
- *
- * @returns true if the argument is an array or any non-Tensor object.
- */
- // tslint:disable-next-line:no-any
- function isIterable$1(obj) {
- return obj != null && (!ArrayBuffer.isView(obj)) &&
- (Array.isArray(obj) ||
- (typeof obj === 'object' && !(obj instanceof Tensor)));
- }
- /**
- * Determine whether the argument can be converted to Tensor.
- *
- * Tensors, primitives, arrays, and TypedArrays all qualify; anything else does
- * not.
- *
- * @returns true if the argument can be converted to Tensor.
- */
- // tslint:disable-next-line:no-any
- function canTensorify(obj) {
- return obj == null || isPrimitive(obj) || Array.isArray(obj) ||
- (typeof obj === 'object' && (obj instanceof Tensor)) ||
- isTypedArray(obj);
- }
- /**
- * Returns true if the given `value` is a primitive type. Otherwise returns
- * false. This is equivalant to node util.isPrimitive
- */
- function isPrimitive(value) {
- return (value === null ||
- (typeof value !== 'object' && typeof value !== 'function'));
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- function deepClone(container) {
- return deepMap(container, cloneIfTensor);
- }
- // tslint:disable-next-line: no-any
- function cloneIfTensor(item) {
- if (item instanceof Tensor) {
- return ({ value: item.clone(), recurse: false });
- }
- else if (isIterable$1(item)) {
- return { value: null, recurse: true };
- }
- else {
- return { value: item, recurse: false };
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * A ring buffer, providing O(1) FIFO, LIFO, and related operations.
- */
- class RingBuffer {
- /**
- * Constructs a `RingBuffer`.
- * @param capacity The number of items that the buffer can accomodate.
- */
- constructor(capacity) {
- this.capacity = capacity;
- // Note we store the indices in the range 0 <= index < 2*capacity.
- // This allows us to distinguish the full from the empty case.
- // See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
- this.begin = 0; // inclusive
- this.end = 0; // exclusive
- if (capacity == null) {
- throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
- }
- if (capacity < 1) {
- throw new RangeError('Can\'t create ring buffer of capacity < 1.');
- }
- this.data = new Array(capacity);
- this.doubledCapacity = 2 * capacity;
- }
- /**
- * Map any index into the range 0 <= index < 2*capacity.
- */
- wrap(index) {
- // don't trust % on negative numbers
- while (index < 0) {
- index += this.doubledCapacity;
- }
- return index % this.doubledCapacity;
- }
- get(index) {
- if (index < 0) {
- throw new RangeError('Can\'t get item at a negative index.');
- }
- return this.data[index % this.capacity];
- }
- set(index, value) {
- if (index < 0) {
- throw new RangeError('Can\'t set item at a negative index.');
- }
- this.data[index % this.capacity] = value;
- }
- /**
- * Returns the current number of items in the buffer.
- */
- length() {
- let length = this.end - this.begin;
- if (length < 0) {
- length = this.doubledCapacity + length;
- }
- return length;
- }
- /**
- * Reports whether the buffer is full.
- * @returns true if the number of items in the buffer equals its capacity, and
- * false otherwise.
- */
- isFull() {
- return this.length() === this.capacity;
- }
- /**
- * Reports whether the buffer is empty.
- * @returns true if the number of items in the buffer equals zero, and
- * false otherwise.
- */
- isEmpty() {
- return this.length() === 0;
- }
- /**
- * Adds an item to the end of the buffer.
- */
- push(value) {
- if (this.isFull()) {
- throw new RangeError('Ring buffer is full.');
- }
- this.set(this.end, value);
- this.end = this.wrap(this.end + 1);
- }
- /**
- * Adds many items to the end of the buffer, in order.
- */
- pushAll(values) {
- for (const value of values) {
- this.push(value);
- }
- }
- /**
- * Removes and returns the last item in the buffer.
- */
- pop() {
- if (this.isEmpty()) {
- throw new RangeError('Ring buffer is empty.');
- }
- this.end = this.wrap(this.end - 1);
- const result = this.get(this.end);
- this.set(this.end, undefined);
- return result;
- }
- /**
- * Adds an item to the beginning of the buffer.
- */
- unshift(value) {
- if (this.isFull()) {
- throw new RangeError('Ring buffer is full.');
- }
- this.begin = this.wrap(this.begin - 1);
- this.set(this.begin, value);
- }
- /**
- * Removes and returns the first item in the buffer.
- */
- shift() {
- if (this.isEmpty()) {
- throw new RangeError('Ring buffer is empty.');
- }
- const result = this.get(this.begin);
- this.set(this.begin, undefined);
- this.begin = this.wrap(this.begin + 1);
- return result;
- }
- /**
- * Removes and returns a specific item in the buffer, and moves the last item
- * to the vacated slot. This is useful for implementing a shuffling stream.
- * Note that this operation necessarily scrambles the original order.
- *
- * @param relativeIndex: the index of the item to remove, relative to the
- * first item in the buffer (e.g., hiding the ring nature of the underlying
- * storage).
- */
- shuffleExcise(relativeIndex) {
- if (this.isEmpty()) {
- throw new RangeError('Ring buffer is empty.');
- }
- const index = this.wrap(this.begin + relativeIndex);
- const result = this.get(index);
- this.set(index, this.pop());
- return result;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- class GrowingRingBuffer extends RingBuffer {
- /**
- * Constructs a `GrowingRingBuffer`.
- */
- constructor() {
- super(GrowingRingBuffer.INITIAL_CAPACITY);
- }
- isFull() {
- return false;
- }
- push(value) {
- if (super.isFull()) {
- this.expand();
- }
- super.push(value);
- }
- unshift(value) {
- if (super.isFull()) {
- this.expand();
- }
- super.unshift(value);
- }
- /**
- * Doubles the capacity of the buffer.
- */
- expand() {
- const newCapacity = this.capacity * 2;
- const newData = new Array(newCapacity);
- const len = this.length();
- // Rotate the buffer to start at index 0 again, since we can't just
- // allocate more space at the end.
- for (let i = 0; i < len; i++) {
- newData[i] = this.get(this.wrap(this.begin + i));
- }
- this.data = newData;
- this.capacity = newCapacity;
- this.doubledCapacity = 2 * this.capacity;
- this.begin = 0;
- this.end = len;
- }
- }
- GrowingRingBuffer.INITIAL_CAPACITY = 32;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- // Here we implement a simple asynchronous iterator.
- // This lets us avoid using either third-party stream libraries or
- // recent TypeScript language support requiring polyfills.
- /**
- * Create a `LazyIterator` from an array of items.
- */
- function iteratorFromItems(items) {
- return new ArrayIterator(items);
- }
- /**
- * Create a `LazyIterator` of incrementing integers.
- */
- function iteratorFromIncrementing(start) {
- let i = start;
- return iteratorFromFunction(() => ({ value: i++, done: false }));
- }
- /**
- * Create a `LazyIterator` from a function.
- *
- * ```js
- * let i = -1;
- * const func = () =>
- * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
- * const iter = tf.data.iteratorFromFunction(func);
- * await iter.forEachAsync(e => console.log(e));
- * ```
- *
- * @param func A function that produces data on each call.
- */
- function iteratorFromFunction(func) {
- return new FunctionCallIterator(func);
- }
- /**
- * Create a `LazyIterator` by concatenating underlying streams, which are
- * themselves provided as a stream.
- *
- * This can also be thought of as a "stream flatten" operation.
- *
- * @param baseIterators A stream of streams to be concatenated.
- * @param baseErrorHandler An optional function that can intercept `Error`s
- * raised during a `next()` call on the base stream. This function can decide
- * whether the error should be propagated, whether the error should be
- * ignored, or whether the base stream should be terminated.
- */
- function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
- return new ChainedIterator(baseIterators, baseErrorHandler);
- }
- /**
- * Create a `LazyIterator` by concatenating streams produced by calling a
- * stream-generating function a given number of times.
- *
- * Since a `LazyIterator` is read-once, it cannot be repeated, but this
- * function can be used to achieve a similar effect:
- *
- * LazyIterator.ofConcatenatedFunction(() => new MyIterator(), 6);
- *
- * @param iteratorFunc: A function that produces a new stream on each call.
- * @param count: The number of times to call the function.
- * @param baseErrorHandler An optional function that can intercept `Error`s
- * raised during a `next()` call on the base stream. This function can decide
- * whether the error should be propagated, whether the error should be
- * ignored, or whether the base stream should be terminated.
- */
- function iteratorFromConcatenatedFunction(iteratorFunc, count, baseErrorHandler) {
- return iteratorFromConcatenated(iteratorFromFunction(iteratorFunc).take(count), baseErrorHandler);
- }
- /**
- * Create a `LazyIterator` by zipping together an array, dict, or nested
- * structure of `LazyIterator`s (and perhaps additional constants).
- *
- * The underlying streams must provide elements in a consistent order such
- * that they correspond.
- *
- * Typically, the underlying streams should have the same number of
- * elements. If they do not, the behavior is determined by the
- * `mismatchMode` argument.
- *
- * The nested structure of the `iterators` argument determines the
- * structure of elements in the resulting iterator.
- *
- * @param iterators: An array or object containing LazyIterators at the
- * leaves.
- * @param mismatchMode: Determines what to do when one underlying iterator
- * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
- * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
- * causes the zipped iterator to terminate with the furst underlying
- * streams, so elements remaining on the longer streams are ignored.
- * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
- * in nulls for the exhausted streams, until all streams are exhausted.
- */
- function iteratorFromZipped(iterators, mismatchMode = ZipMismatchMode.FAIL) {
- return new ZipIterator(iterators, mismatchMode);
- }
- /**
- * An asynchronous iterator, providing lazy access to a potentially
- * unbounded stream of elements.
- *
- * Iterator can be obtained from a dataset:
- * `const iter = await dataset.iterator();`
- */
- class LazyIterator {
- /**
- * Collect all remaining elements of a bounded stream into an array.
- * Obviously this will succeed only for small streams that fit in memory.
- * Useful for testing.
- *
- * @returns A Promise for an array of stream elements, which will resolve
- * when the stream is exhausted.
- */
- async toArray() {
- const result = [];
- let x = await this.next();
- while (!x.done) {
- result.push(x.value);
- x = await this.next();
- }
- return result;
- }
- /**
- * Collect all elements of this dataset into an array with prefetching 100
- * elements. This is useful for testing, because the prefetch changes the
- * order in which the Promises are resolved along the processing pipeline.
- * This may help expose bugs where results are dependent on the order of
- * Promise resolution rather than on the logical order of the stream (i.e.,
- * due to hidden mutable state).
- *
- * @returns A Promise for an array of stream elements, which will resolve
- * when the stream is exhausted.
- */
- async toArrayForTest() {
- const stream = this.prefetch(100);
- const result = [];
- let x = await stream.next();
- while (!x.done) {
- result.push(x.value);
- x = await stream.next();
- }
- return result;
- }
- /**
- * Draw items from the stream until it is exhausted.
- *
- * This can be useful when the stream has side effects but no output. In
- * that case, calling this function guarantees that the stream will be
- * fully processed.
- */
- async resolveFully() {
- let x = await this.next();
- while (!x.done) {
- x = await this.next();
- }
- }
- /**
- * Draw items from the stream until it is exhausted, or a predicate fails.
- *
- * This can be useful when the stream has side effects but no output. In
- * that case, calling this function guarantees that the stream will be
- * fully processed.
- */
- async resolveWhile(predicate) {
- let x = await this.next();
- let shouldContinue = predicate(x.value);
- while ((!x.done) && shouldContinue) {
- x = await this.next();
- shouldContinue = predicate(x.value);
- }
- }
- /**
- * Handles errors thrown on this stream using a provided handler function.
- *
- * @param handler A function that handles any `Error` thrown during a `next()`
- * call and returns true if the stream should continue (dropping the failed
- * call) or false if the stream should quietly terminate. If the handler
- * itself throws (or rethrows) an `Error`, that will be propagated.
- *
- * @returns A `LazyIterator` of elements passed through from upstream,
- * possibly filtering or terminating on upstream `next()` calls that
- * throw an `Error`.
- */
- handleErrors(handler) {
- return new ErrorHandlingLazyIterator(this, handler);
- }
- // TODO(soergel): Implement reduce() etc.
- /**
- * Filters this stream according to `predicate`.
- *
- * @param predicate A function mapping a stream element to a boolean or a
- * `Promise` for one.
- *
- * @returns A `LazyIterator` of elements for which the predicate was true.
- */
- filter(predicate) {
- return new FilterIterator(this, predicate);
- }
- /**
- * Maps this stream through a 1-to-1 transform.
- *
- * @param transform A function mapping a stream element to a transformed
- * element.
- *
- * @returns A `LazyIterator` of transformed elements.
- */
- map(transform) {
- return new MapIterator(this, transform);
- }
- /**
- * Maps this stream through an async 1-to-1 transform.
- *
- * @param transform A function mapping a stream element to a `Promise` for a
- * transformed stream element.
- *
- * @returns A `LazyIterator` of transformed elements.
- */
- mapAsync(transform) {
- return new AsyncMapIterator(this, transform);
- }
- /**
- * Maps this stream through a 1-to-1 transform, forcing serial execution.
- *
- * @param transform A function mapping a stream element to a transformed
- * element.
- *
- * @returns A `LazyIterator` of transformed elements.
- */
- serialMapAsync(transform) {
- return new AsyncMapIterator(this, transform).serial();
- }
- /**
- * Maps this stream through a 1-to-many transform.
- *
- * @param transform A function mapping a stream element to an array of
- * transformed elements.
- *
- * @returns A `DataStream` of transformed elements.
- */
- flatmap(transform) {
- return new FlatmapIterator(this, transform);
- }
- /**
- * Apply a function to every element of the stream.
- *
- * @param f A function to apply to each stream element.
- */
- async forEachAsync(f) {
- return this.map(f).resolveFully();
- }
- /**
- * Apply a function to every element of the stream, forcing serial execution.
- *
- * @param f A function to apply to each stream element. Should return 'true'
- * to indicate that the stream should continue, or 'false' to cause it to
- * terminate.
- */
- async serialForEach(f) {
- return this.serialMapAsync(f).resolveWhile(x => (x === true));
- }
- /**
- * Groups elements into batches, represented as arrays of elements.
- *
- * We can think of the elements of this iterator as 'rows' (even if they are
- * nested structures). By the same token, consecutive values for a given
- * key within the elements form a 'column'. This matches the usual sense of
- * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
- *
- * Thus, "Row-major" means that the resulting batch is simply a collection of
- * rows: `[row1, row2, row3, ...]`. This is contrast to the column-major
- * form, which is needed for vectorized computation.
- *
- * @param batchSize The number of elements desired per batch.
- * @param smallLastBatch Whether to emit the final batch when it has fewer
- * than batchSize elements. Default true.
- * @returns A `LazyIterator` of batches of elements, represented as arrays
- * of the original element type.
- */
- rowMajorBatch(batchSize, smallLastBatch = true) {
- return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
- }
- /**
- * Groups elements into batches, represented in column-major form.
- *
- * We can think of the elements of this iterator as 'rows' (even if they are
- * nested structures). By the same token, consecutive values for a given
- * key within the elements form a 'column'. This matches the usual sense of
- * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
- *
- * Thus, "column-major" means that the resulting batch is a (potentially
- * nested) structure representing the columns. Each column entry, then,
- * contains a collection of the values found in that column for a range of
- * input elements. This representation allows for vectorized computation, in
- * contrast to the row-major form.
- *
- * The inputs should all have the same nested structure (i.e., of arrays and
- * dicts). The result is a single object with the same nested structure,
- * where the leaves are arrays collecting the values of the inputs at that
- * location (or, optionally, the result of a custom function applied to those
- * arrays).
- *
- * @param batchSize The number of elements desired per batch.
- * @param smallLastBatch Whether to emit the final batch when it has fewer
- * than batchSize elements. Default true.
- * @param zipFn: (optional) A function that expects an array of elements at a
- * single node of the object tree, and returns a `DeepMapResult`. The
- * `DeepMapResult` either provides a result value for that node (i.e.,
- * representing the subtree), or indicates that the node should be processed
- * recursively. The default zipFn recurses as far as possible and places
- * arrays at the leaves.
- * @returns A `LazyIterator` of batches of elements, represented as an object
- * with collections at the leaves.
- */
- columnMajorBatch(batchSize, smallLastBatch = true,
- // tslint:disable-next-line:no-any
- zipFn = zipToList) {
- // First collect the desired number of input elements as a row-major batch.
- const rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
- // Now 'rotate' or 'pivot' the data, collecting all values from each column
- // in the batch (i.e., for each key within the elements) into an array.
- return rowBatches.map(x => deepZip(x, zipFn));
- }
- /**
- * Concatenate this `LazyIterator` with another.
- *
- * @param iterator A `LazyIterator` to be concatenated onto this one.
- * @param baseErrorHandler An optional function that can intercept `Error`s
- * raised during a `next()` call on the base stream. This function can
- * decide whether the error should be propagated, whether the error should
- * be ignored, or whether the base stream should be terminated.
- * @returns A `LazyIterator`.
- */
- concatenate(iterator, baseErrorHandler) {
- return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
- }
- /**
- * Limits this stream to return at most `count` items.
- *
- * @param count The maximum number of items to provide from the stream. If
- * a negative or undefined value is given, the entire stream is returned
- * unaltered.
- */
- take(count) {
- if (count < 0 || count == null) {
- return this;
- }
- return new TakeIterator(this, count);
- }
- /**
- * Skips the first `count` items in this stream.
- *
- * @param count The number of items to skip. If a negative or undefined
- * value is given, the entire stream is returned unaltered.
- */
- skip(count) {
- if (count < 0 || count == null) {
- return this;
- }
- return new SkipIterator(this, count);
- }
- /**
- * Prefetch the first `bufferSize` items in this stream.
- *
- * Note this prefetches Promises, but makes no guarantees about when those
- * Promises resolve.
- *
- * @param bufferSize: An integer specifying the number of elements to be
- * prefetched.
- */
- prefetch(bufferSize) {
- return new PrefetchIterator(this, bufferSize);
- }
- // TODO(soergel): deep sharded shuffle, where supported
- /**
- * Randomly shuffles the elements of this stream.
- *
- * @param bufferSize: An integer specifying the number of elements from
- * this stream from which the new stream will sample.
- * @param seed: (Optional.) An integer specifying the random seed that
- * will be used to create the distribution.
- */
- shuffle(windowSize, seed) {
- return new ShuffleIterator(this, windowSize, seed);
- }
- /**
- * Force an iterator to execute serially: each next() call will await the
- * prior one, so that they cannot execute concurrently.
- */
- serial() {
- return new SerialIterator(this);
- }
- }
- // ============================================================================
- // The following private classes serve to implement the chainable methods
- // on LazyIterator. Unfortunately they can't be placed in separate files,
- // due to resulting trouble with circular imports.
- // ============================================================================
- // Iterators that just extend LazyIterator directly
- // ============================================================================
- class ArrayIterator extends LazyIterator {
- constructor(items) {
- super();
- this.items = items;
- this.trav = 0;
- }
- summary() {
- return `Array of ${this.items.length} items`;
- }
- async next() {
- if (this.trav >= this.items.length) {
- return { value: null, done: true };
- }
- const item = this.items[this.trav];
- this.trav++;
- return { value: deepClone(item), done: false };
- }
- }
- class FunctionCallIterator extends LazyIterator {
- constructor(nextFn) {
- super();
- this.nextFn = nextFn;
- }
- summary() {
- return `Function call`;
- }
- async next() {
- try {
- return this.nextFn();
- }
- catch (e) {
- // Modify the error message but leave the stack trace intact
- e.message =
- `Error thrown while iterating through a dataset: ${e.message}`;
- throw e;
- }
- }
- }
- class SerialIterator extends LazyIterator {
- constructor(upstream) {
- super();
- this.upstream = upstream;
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- summary() {
- return `${this.upstream.summary()} -> Serial`;
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- return this.upstream.next();
- }
- }
- class SkipIterator extends LazyIterator {
- constructor(upstream, maxCount) {
- super();
- this.upstream = upstream;
- this.maxCount = maxCount;
- // Local state that should not be clobbered by out-of-order execution.
- this.count = 0;
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- summary() {
- return `${this.upstream.summary()} -> Skip`;
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- // TODO(soergel): consider tradeoffs of reading in parallel, eg.
- // collecting next() promises in an Array and then waiting for
- // Promise.all() of those. Benefit: pseudo-parallel execution. Drawback:
- // maybe delayed GC.
- while (this.count++ < this.maxCount) {
- const skipped = await this.upstream.next();
- // short-circuit if upstream is already empty
- if (skipped.done) {
- return skipped;
- }
- dispose(skipped.value);
- }
- return this.upstream.next();
- }
- }
- class TakeIterator extends LazyIterator {
- constructor(upstream, maxCount) {
- super();
- this.upstream = upstream;
- this.maxCount = maxCount;
- this.count = 0;
- }
- summary() {
- return `${this.upstream.summary()} -> Take`;
- }
- async next() {
- if (this.count++ >= this.maxCount) {
- return { value: null, done: true };
- }
- return this.upstream.next();
- }
- }
- // Note this batch just groups items into row-wise element arrays.
- // Rotating these to a column-wise representation happens only at the dataset
- // level.
- class RowMajorBatchIterator extends LazyIterator {
- constructor(upstream, batchSize, enableSmallLastBatch = true) {
- super();
- this.upstream = upstream;
- this.batchSize = batchSize;
- this.enableSmallLastBatch = enableSmallLastBatch;
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- summary() {
- return `${this.upstream.summary()} -> RowMajorBatch`;
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- const batch = [];
- while (batch.length < this.batchSize) {
- const item = await this.upstream.next();
- if (item.done) {
- if (this.enableSmallLastBatch && batch.length > 0) {
- return { value: batch, done: false };
- }
- return { value: null, done: true };
- }
- batch.push(item.value);
- }
- return { value: batch, done: false };
- }
- }
- class FilterIterator extends LazyIterator {
- constructor(upstream, predicate) {
- super();
- this.upstream = upstream;
- this.predicate = predicate;
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- summary() {
- return `${this.upstream.summary()} -> Filter`;
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- while (true) {
- const item = await this.upstream.next();
- if (item.done || this.predicate(item.value)) {
- return item;
- }
- dispose(item.value);
- }
- }
- }
- class MapIterator extends LazyIterator {
- constructor(upstream, transform) {
- super();
- this.upstream = upstream;
- this.transform = transform;
- }
- summary() {
- return `${this.upstream.summary()} -> Map`;
- }
- async next() {
- const item = await this.upstream.next();
- if (item.done) {
- return { value: null, done: true };
- }
- const inputTensors = getTensorsInContainer(item.value);
- // Careful: the transform may mutate the item in place.
- // That's why we have to remember the input Tensors above, and then
- // below dispose only those that were not passed through to the output.
- // Note too that the transform function is responsible for tidying
- // any intermediate Tensors. Here we are concerned only about the
- // inputs.
- const mapped = this.transform(item.value);
- const outputTensors = getTensorsInContainer(mapped);
- // TODO(soergel) faster intersection
- // TODO(soergel) move to tf.disposeExcept(in, out)?
- for (const t of inputTensors) {
- if (!isTensorInList(t, outputTensors)) {
- t.dispose();
- }
- }
- return { value: mapped, done: false };
- }
- }
- class ErrorHandlingLazyIterator extends LazyIterator {
- constructor(upstream, handler) {
- super();
- this.upstream = upstream;
- this.handler = handler;
- this.count = 0;
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- summary() {
- return `${this.upstream.summary()} -> handleErrors`;
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- while (true) {
- try {
- return await this.upstream.next();
- }
- catch (e) {
- if (!this.handler(e)) {
- return { value: null, done: true };
- }
- // If the handler returns true, loop and fetch the next upstream item.
- // If the upstream iterator throws an endless stream of errors, and if
- // the handler says to ignore them, then we loop forever here. That is
- // the correct behavior-- it's up to the handler to decide when to stop.
- }
- }
- }
- }
- class AsyncMapIterator extends LazyIterator {
- constructor(upstream, transform) {
- super();
- this.upstream = upstream;
- this.transform = transform;
- }
- summary() {
- return `${this.upstream.summary()} -> AsyncMap`;
- }
- async next() {
- const item = await this.upstream.next();
- if (item.done) {
- return { value: null, done: true };
- }
- const inputTensors = getTensorsInContainer(item.value);
- // Careful: the transform may mutate the item in place.
- // That's why we have to remember the input Tensors above, and then
- // below dispose only those that were not passed through to the output.
- // Note too that the transform function is responsible for tidying
- // any intermediate Tensors. Here we are concerned only about the
- // inputs.
- const mapped = await this.transform(item.value);
- const outputTensors = getTensorsInContainer(mapped);
- // TODO(soergel) faster intersection
- // TODO(soergel) move to tf.disposeExcept(in, out)?
- for (const t of inputTensors) {
- if (!isTensorInList(t, outputTensors)) {
- t.dispose();
- }
- }
- return { value: mapped, done: false };
- }
- }
- // Iterators that maintain a queue of pending items
- // ============================================================================
- /**
- * A base class for transforming streams that operate by maintaining an
- * output queue of elements that are ready to return via next(). This is
- * commonly required when the transformation is 1-to-many: A call to next()
- * may trigger a call to the underlying stream, which will produce many
- * mapped elements of this stream-- of which we need to return only one, so
- * we have to queue the rest.
- */
- class OneToManyIterator extends LazyIterator {
- constructor() {
- super();
- this.outputQueue = new GrowingRingBuffer();
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- async serialNext() {
- // Fetch so that the queue contains at least one item if possible.
- // If the upstream source is exhausted, AND there are no items left in
- // the output queue, then this stream is also exhausted.
- while (this.outputQueue.length() === 0) {
- // TODO(soergel): consider parallel reads.
- if (!await this.pump()) {
- return { value: null, done: true };
- }
- }
- return { value: this.outputQueue.shift(), done: false };
- }
- }
- class FlatmapIterator extends OneToManyIterator {
- constructor(upstream, transform) {
- super();
- this.upstream = upstream;
- this.transform = transform;
- }
- summary() {
- return `${this.upstream.summary()} -> Flatmap`;
- }
- async pump() {
- const item = await this.upstream.next();
- if (item.done) {
- return false;
- }
- const inputTensors = getTensorsInContainer(item.value);
- // Careful: the transform may mutate the item in place.
- // that's why we have to remember the input Tensors above, and then
- // below dispose only those that were not passed through to the output.
- // Note too that the transform function is responsible for tidying any
- // intermediate Tensors. Here we are concerned only about the inputs.
- const mappedArray = this.transform(item.value);
- const outputTensors = getTensorsInContainer(mappedArray);
- this.outputQueue.pushAll(mappedArray);
- // TODO(soergel) faster intersection, and deduplicate outputTensors
- // TODO(soergel) move to tf.disposeExcept(in, out)?
- for (const t of inputTensors) {
- if (!isTensorInList(t, outputTensors)) {
- t.dispose();
- }
- }
- return true;
- }
- }
- /**
- * Provides a `LazyIterator` that concatenates a stream of underlying
- * streams.
- *
- * Doing this in a concurrency-safe way requires some trickery. In
- * particular, we want this stream to return the elements from the
- * underlying streams in the correct order according to when next() was
- * called, even if the resulting Promises resolve in a different order.
- */
- class ChainedIterator extends LazyIterator {
- constructor(iterators, baseErrorHandler) {
- super();
- this.baseErrorHandler = baseErrorHandler;
- // Strict Promise execution order:
- // a next() call may not even begin until the previous one completes.
- this.lastRead = null;
- // Local state that should not be clobbered by out-of-order execution.
- this.iterator = null;
- this.moreIterators = iterators;
- }
- summary() {
- const upstreamSummaries = 'TODO: fill in upstream of chained summaries';
- return `${upstreamSummaries} -> Chained`;
- }
- async next() {
- this.lastRead = this.readFromChain(this.lastRead);
- return this.lastRead;
- }
- async readFromChain(lastRead) {
- // Must await on the previous read since the previous read may have advanced
- // the stream of streams, from which we need to read.
- // This is unfortunate since we can't parallelize reads. Which means
- // prefetching of chained streams is a no-op.
- // One solution is to prefetch immediately upstream of this.
- await lastRead;
- if (this.iterator == null) {
- const iteratorResult = await this.moreIterators.next();
- if (iteratorResult.done) {
- // No more streams to stream from.
- return { value: null, done: true };
- }
- this.iterator = iteratorResult.value;
- if (this.baseErrorHandler != null) {
- this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
- }
- }
- const itemResult = await this.iterator.next();
- if (itemResult.done) {
- this.iterator = null;
- return this.readFromChain(lastRead);
- }
- return itemResult;
- }
- }
- var ZipMismatchMode;
- (function (ZipMismatchMode) {
- ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL";
- ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST";
- ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
- })(ZipMismatchMode || (ZipMismatchMode = {}));
- /**
- * Provides a `LazyIterator` that zips together an array, dict, or nested
- * structure of `LazyIterator`s (and perhaps additional constants).
- *
- * The underlying streams must provide elements in a consistent order such
- * that they correspond.
- *
- * Typically, the underlying streams should have the same number of
- * elements. If they do not, the behavior is determined by the
- * `mismatchMode` argument.
- *
- * The nested structure of the `iterators` argument determines the
- * structure of elements in the resulting iterator.
- *
- * Doing this in a concurrency-safe way requires some trickery. In
- * particular, we want this stream to return the elements from the
- * underlying streams in the correct order according to when next() was
- * called, even if the resulting Promises resolve in a different order.
- *
- * @param iterators: An array or object containing LazyIterators at the
- * leaves.
- * @param mismatchMode: Determines what to do when one underlying iterator
- * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
- * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
- * causes the zipped iterator to terminate with the furst underlying
- * streams, so elements remaining on the longer streams are ignored.
- * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
- * in nulls for the exhausted streams, until all streams are exhausted.
- */
- class ZipIterator extends LazyIterator {
- constructor(iterators, mismatchMode = ZipMismatchMode.FAIL) {
- super();
- this.iterators = iterators;
- this.mismatchMode = mismatchMode;
- this.count = 0;
- this.currentPromise = null;
- }
- summary() {
- const upstreamSummaries = 'TODO: fill in upstream of zip summaries';
- return `{${upstreamSummaries}} -> Zip`;
- }
- async nextState(afterState) {
- // This chaining ensures that the underlying next() are not even called
- // before the previous ones have resolved.
- await afterState;
- // Collect underlying iterator "done" signals as a side effect in
- // getNext()
- let numIterators = 0;
- let iteratorsDone = 0;
- function getNext(container) {
- if (container instanceof LazyIterator) {
- const result = container.next();
- return {
- value: result.then(x => {
- numIterators++;
- if (x.done) {
- iteratorsDone++;
- }
- return x.value;
- }),
- recurse: false
- };
- }
- else {
- return { value: null, recurse: true };
- }
- }
- const mapped = await deepMapAndAwaitAll(this.iterators, getNext);
- if (numIterators === iteratorsDone) {
- // The streams have all ended.
- return { value: null, done: true };
- }
- if (iteratorsDone > 0) {
- switch (this.mismatchMode) {
- case ZipMismatchMode.FAIL:
- throw new Error('Zipped streams should have the same length. ' +
- `Mismatched at element ${this.count}.`);
- case ZipMismatchMode.SHORTEST:
- return { value: null, done: true };
- case ZipMismatchMode.LONGEST:
- default:
- // Continue. The exhausted streams already produced value: null.
- }
- }
- this.count++;
- return { value: mapped, done: false };
- }
- async next() {
- this.currentPromise = this.nextState(this.currentPromise);
- return this.currentPromise;
- }
- }
- // Iterators that maintain a ring buffer of pending promises
- // ============================================================================
- /**
- * A stream that prefetches a given number of items from an upstream source,
- * returning them in FIFO order.
- *
- * Note this prefetches Promises, but makes no guarantees about when those
- * Promises resolve.
- */
- class PrefetchIterator extends LazyIterator {
- constructor(upstream, bufferSize) {
- super();
- this.upstream = upstream;
- this.bufferSize = bufferSize;
- this.buffer = new RingBuffer(bufferSize);
- }
- summary() {
- return `${this.upstream.summary()} -> Prefetch`;
- }
- /**
- * Refill the prefetch buffer. Returns only after the buffer is full, or
- * the upstream source is exhausted.
- */
- refill() {
- while (!this.buffer.isFull()) {
- const v = this.upstream.next();
- this.buffer.push(v);
- }
- }
- next() {
- this.refill();
- // This shift will never throw an error because the buffer is always
- // full after a refill. If the stream is exhausted, the buffer will be
- // full of Promises that will resolve to the end-of-stream signal.
- return this.buffer.shift();
- }
- }
- /**
- * A stream that performs a sliding-window random shuffle on an upstream
- * source. This is like a `PrefetchIterator` except that the items are
- * returned in randomized order. Mixing naturally improves as the buffer
- * size increases.
- */
- class ShuffleIterator extends PrefetchIterator {
- constructor(upstream, windowSize, seed) {
- super(upstream, windowSize);
- this.upstream = upstream;
- this.windowSize = windowSize;
- // Local state that should not be clobbered by out-of-order execution.
- this.upstreamExhausted = false;
- this.random = seedrandom_1(seed || now().toString());
- this.lastRead = Promise.resolve({ value: null, done: false });
- }
- async next() {
- // This sets this.lastRead to a new Promise right away, as opposed to
- // saying `await this.lastRead; this.lastRead = this.serialNext();` which
- // would not work because this.nextRead would be updated only after the
- // promise resolves.
- this.lastRead = this.lastRead.then(() => this.serialNext());
- return this.lastRead;
- }
- randomInt(max) {
- return Math.floor(this.random() * max);
- }
- chooseIndex() {
- return this.randomInt(this.buffer.length());
- }
- async serialNext() {
- // TODO(soergel): consider performance
- if (!this.upstreamExhausted) {
- this.refill();
- }
- while (!this.buffer.isEmpty()) {
- const chosenIndex = this.chooseIndex();
- const result = await this.buffer.shuffleExcise(chosenIndex);
- if (result.done) {
- this.upstreamExhausted = true;
- }
- else {
- this.refill();
- return result;
- }
- }
- return { value: null, done: true };
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- // TODO(soergel): consider vectorized operations within the pipeline.
- /**
- * Represents a potentially large list of independent data elements (typically
- * 'samples' or 'examples').
- *
- * A 'data example' may be a primitive, an array, a map from string keys to
- * values, or any nested structure of these.
- *
- * A `Dataset` represents an ordered collection of elements, together with a
- * chain of transformations to be performed on those elements. Each
- * transformation is a method of `Dataset` that returns another `Dataset`, so
- * these may be chained, e.g.
- * `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
- *
- * Data loading and transformation is done in a lazy, streaming fashion. The
- * dataset may be iterated over multiple times; each iteration starts the data
- * loading anew and recapitulates the transformations.
- *
- * A `Dataset` is typically processed as a stream of unbatched examples --i.e.,
- * its transformations are applied one example at a time. Batching produces a
- * new `Dataset` where each element is a batch. Batching should usually come
- * last in a pipeline, because data transformations are easier to express on a
- * per-example basis than on a per-batch basis.
- *
- * The following code examples are calling `await dataset.forEachAsync(...)` to
- * iterate once over the entire dataset in order to print out the data.
- *
- * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
- */
- class Dataset {
- constructor() {
- this.size = null;
- }
- // TODO(soergel): Make Datasets report whether repeated iterator() calls
- // produce the same result (e.g., reading from a file) or different results
- // (e.g., from the webcam). Currently we don't make this distinction but it
- // could be important for the user to know.
- // abstract isDeterministic(): boolean;
- /**
- * Groups elements into batches.
- *
- * It is assumed that each of the incoming dataset elements has the same
- * structure-- i.e. the same set of keys at each location in an object
- * hierarchy. For each key, the resulting `Dataset` provides a batched
- * element collecting all of the incoming values for that key.
- *
- * * Incoming primitives are grouped into a 1-D Tensor.
- * * Incoming Tensors are grouped into a new Tensor where the 0'th axis is
- * the batch dimension.
- * * Incoming arrays are converted to Tensor and then batched.
- * * A nested array is interpreted as an n-D Tensor, so the batched result
- * has n+1 dimensions.
- * * An array that cannot be converted to Tensor produces an error.
- *
- * If an array should not be batched as a unit, it should first be converted
- * to an object with integer keys.
- *
- * Here are a few examples:
- *
- * Batch a dataset of numbers:
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
- * await a.forEachAsync(e => e.print());
- * ```
- *
- * Batch a dataset of arrays:
- * ```js
- * const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
- * await b.forEachAsync(e => e.print());
- * ```
- *
- * Batch a dataset of objects:
- * ```js
- * const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
- * {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
- * {a: 8, b: 18}]).batch(4);
- * await c.forEachAsync(e => {
- * console.log('{');
- * for(var key in e) {
- * console.log(key+':');
- * e[key].print();
- * }
- * console.log('}');
- * })
- * ```
- *
- * @param batchSize The number of elements desired per batch.
- * @param smallLastBatch Whether to emit the final batch when it has fewer
- * than batchSize elements. Default true.
- * @returns A `Dataset`, from which a stream of batches can be obtained.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- batch(batchSize, smallLastBatch = true) {
- const base = this;
- assert(batchSize > 0, () => `batchSize needs to be positive, but it is
- ${batchSize}`);
- let size;
- if (this.size === Infinity || this.size == null) {
- // If the size of this dataset is infinity or null, the new size keeps the
- // same.
- size = this.size;
- }
- else if (smallLastBatch) {
- // If the size of this dataset is known and include small last batch, the
- // new size is full batch count plus last batch.
- size = Math.ceil(this.size / batchSize);
- }
- else {
- // If the size of this dataset is known and not include small last batch,
- // the new size is full batch count.
- size = Math.floor(this.size / batchSize);
- }
- return datasetFromIteratorFn(async () => {
- return (await base.iterator())
- .columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat);
- }, size);
- }
- /**
- * Concatenates this `Dataset` with another.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3]);
- * const b = tf.data.array([4, 5, 6]);
- * const c = a.concatenate(b);
- * await c.forEachAsync(e => console.log(e));
- * ```
- *
- * @param dataset A `Dataset` to be concatenated onto this one.
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- concatenate(dataset) {
- const base = this;
- let size;
- if (this.size === Infinity || dataset.size === Infinity) {
- // If the size of any of these two dataset is infinity, new size is
- // infinity.
- size = Infinity;
- }
- else if (this.size != null && dataset.size != null) {
- // If the size of both datasets are known and not infinity, new size is
- // sum the size of these two datasets.
- size = this.size + dataset.size;
- }
- else {
- // If neither of these two datasets has infinite size and any of these two
- // datasets' size is null, the new size is null.
- size = null;
- }
- return datasetFromIteratorFn(async () => (await base.iterator()).concatenate(await dataset.iterator()), size);
- }
- /**
- * Filters this dataset according to `predicate`.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
- * .filter(x => x%2 === 0);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param predicate A function mapping a dataset element to a boolean or a
- * `Promise` for one.
- *
- * @returns A `Dataset` of elements for which the predicate was true.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- filter(predicate) {
- const base = this;
- let size;
- if (this.size === Infinity) {
- // If the size of this dataset is infinity, new size is infinity
- size = Infinity;
- }
- else {
- // If this dataset has limited elements, new size is null because it might
- // exhausted randomly.
- size = null;
- }
- return datasetFromIteratorFn(async () => {
- return (await base.iterator()).filter(x => tidy(() => predicate(x)));
- }, size);
- }
- /**
- * Apply a function to every element of the dataset.
- *
- * After the function is applied to a dataset element, any Tensors contained
- * within that element are disposed.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3]);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param f A function to apply to each dataset element.
- * @returns A `Promise` that resolves after all elements have been processed.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- async forEachAsync(f) {
- return (await this.iterator()).forEachAsync(f);
- }
- /**
- * Maps this dataset through a 1-to-1 transform.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3]).map(x => x*x);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param transform A function mapping a dataset element to a transformed
- * dataset element.
- *
- * @returns A `Dataset` of transformed elements.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- map(transform) {
- const base = this;
- return datasetFromIteratorFn(async () => {
- return (await base.iterator()).map(x => tidy(() => transform(x)));
- }, this.size);
- }
- /**
- * Maps this dataset through an async 1-to-1 transform.
- *
- * ```js
- * const a =
- * tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
- * setTimeout(() => {
- * resolve(x * x);
- * }, Math.random()*1000 + 500);
- * }));
- * console.log(await a.toArray());
- * ```
- *
- * @param transform A function mapping a dataset element to a `Promise` for a
- * transformed dataset element. This transform is responsible for disposing
- * any intermediate `Tensor`s, i.e. by wrapping its computation in
- * `tf.tidy()`; that cannot be automated here (as it is in the synchronous
- * `map()` case).
- *
- * @returns A `Dataset` of transformed elements.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- mapAsync(transform) {
- const base = this;
- return datasetFromIteratorFn(async () => {
- return (await base.iterator()).mapAsync(transform);
- }, this.size);
- }
- /**
- * Creates a `Dataset` that prefetches elements from this dataset.
- *
- * @param bufferSize: An integer specifying the number of elements to be
- * prefetched.
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- prefetch(bufferSize) {
- if (bufferSize == null) {
- throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
- }
- const base = this;
- return datasetFromIteratorFn(async () => (await base.iterator()).prefetch(bufferSize), this.size);
- }
- /**
- * Repeats this dataset `count` times.
- *
- * NOTE: If this dataset is a function of global state (e.g. a random number
- * generator), then different repetitions may produce different elements.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3]).repeat(3);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param count: (Optional) An integer, representing the number of times
- * the dataset should be repeated. The default behavior (if `count` is
- * `undefined` or negative) is for the dataset be repeated indefinitely.
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- repeat(count) {
- const base = this;
- let size;
- if (this.size != null && count > 0) {
- // If this dataset has size and count is positive, new size is current
- // size multiply count. This also covers the case that current size is
- // infinity.
- size = this.size * count;
- }
- else if (count === 0) {
- // If count is 0, new size is 0.
- size = 0;
- }
- else if (this.size != null && (count === undefined || count < 0)) {
- // If this dataset has size and count is undefined or negative, the
- // dataset will be repeated indefinitely and new size is infinity.
- size = Infinity;
- }
- else {
- // If the size of this dataset is null, the new dataset's size is null.
- size = null;
- }
- return datasetFromIteratorFn(async () => {
- const iteratorIterator = iteratorFromFunction(async () => ({ value: await base.iterator(), done: false }));
- return iteratorFromConcatenated(iteratorIterator.take(count));
- }, size);
- }
- /**
- * Creates a `Dataset` that skips `count` initial elements from this dataset.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param count: The number of elements of this dataset that should be skipped
- * to form the new dataset. If `count` is greater than the size of this
- * dataset, the new dataset will contain no elements. If `count`
- * is `undefined` or negative, skips the entire dataset.
- *
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- skip(count) {
- const base = this;
- let size;
- if (this.size != null && count >= 0 && this.size >= count) {
- // If the size of this dataset is greater than count, the new dataset's
- // size is current size minus skipped size.This also covers the case that
- // current size is infinity.
- size = this.size - count;
- }
- else if (this.size != null &&
- (this.size < count || count === undefined || count < 0)) {
- // If the size of this dataset is smaller than count, or count is
- // undefined or negative, skips the entire dataset and the new size is 0.
- size = 0;
- }
- else {
- // If the size of this dataset is null, the new dataset's size is null.
- size = null;
- }
- return datasetFromIteratorFn(async () => (await base.iterator()).skip(count), size);
- }
- /**
- * Pseudorandomly shuffles the elements of this dataset. This is done in a
- * streaming manner, by sampling from a given number of prefetched elements.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param bufferSize: An integer specifying the number of elements from this
- * dataset from which the new dataset will sample.
- * @param seed: (Optional) An integer specifying the random seed that will
- * be used to create the distribution.
- * @param reshuffleEachIteration: (Optional) A boolean, which if true
- * indicates that the dataset should be pseudorandomly reshuffled each time
- * it is iterated over. If false, elements will be returned in the same
- * shuffled order on each iteration. (Defaults to `true`.)
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- shuffle(bufferSize, seed, reshuffleEachIteration = true) {
- if (bufferSize == null || bufferSize < 0) {
- if (this.size == null) {
- throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
- }
- else {
- throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' +
- 'If your data fits in main memory (for regular JS objects), ' +
- 'and/or GPU memory (for `tf.Tensor`s), consider setting ' +
- `bufferSize to the dataset size (${this.size} elements)`);
- }
- }
- const base = this;
- const random = seedrandom_1(seed || now().toString());
- return datasetFromIteratorFn(async () => {
- let seed2 = random.int32();
- if (reshuffleEachIteration) {
- seed2 += random.int32();
- }
- return (await base.iterator()).shuffle(bufferSize, seed2.toString());
- }, this.size);
- }
- /**
- * Creates a `Dataset` with at most `count` initial elements from this
- * dataset.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * @param count: The number of elements of this dataset that should be taken
- * to form the new dataset. If `count` is `undefined` or negative, or if
- * `count` is greater than the size of this dataset, the new dataset will
- * contain all elements of this dataset.
- * @returns A `Dataset`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- take(count) {
- const base = this;
- let size;
- if (this.size != null && this.size > count) {
- // If the size of this dataset is greater than count, the new dataset's
- // size is count.
- size = count;
- }
- else if (this.size != null && this.size <= count) {
- // If the size of this dataset is equal or smaller than count, the new
- // dataset's size is the size of this dataset.
- size = this.size;
- }
- else {
- // If the size of this dataset is null, the new dataset's size is null.
- size = null;
- }
- return datasetFromIteratorFn(async () => (await base.iterator()).take(count), size);
- }
- /**
- * Collect all elements of this dataset into an array.
- *
- * Obviously this will succeed only for small datasets that fit in memory.
- * Useful for testing and generally should be avoided if possible.
- *
- * ```js
- * const a = tf.data.array([1, 2, 3, 4, 5, 6]);
- * console.log(await a.toArray());
- * ```
- *
- * @returns A Promise for an array of elements, which will resolve
- * when a new stream has been obtained and fully consumed.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- async toArray() {
- if (this.size === Infinity) {
- throw new Error('Can not convert infinite data stream to array.');
- }
- return (await this.iterator()).toArray();
- }
- /**
- * Collect all elements of this dataset into an array with prefetching 100
- * elements. This is useful for testing, because the prefetch changes the
- * order in which the Promises are resolved along the processing pipeline.
- * This may help expose bugs where results are dependent on the order of
- * Promise resolution rather than on the logical order of the stream (i.e.,
- * due to hidden mutable state).
- *
- * @returns A Promise for an array of elements, which will resolve
- * when a new stream has been obtained and fully consumed.
- */
- async toArrayForTest() {
- if (this.size === Infinity) {
- throw new Error('Can not convert infinite data stream to array.');
- }
- return (await this.iterator()).toArrayForTest();
- }
- }
- // TODO(soergel): deep sharded shuffle, where supported
- Dataset.MAX_BUFFER_SIZE = 10000;
- /**
- * Create a `Dataset` defined by a provided iterator() function.
- *
- * ```js
- * let i = -1;
- * const func = () =>
- * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
- * const iter = tf.data.iteratorFromFunction(func);
- * const ds = tf.data.datasetFromIteratorFn(iter);
- * await ds.forEachAsync(e => console.log(e));
- * ```
- */
- function datasetFromIteratorFn(iteratorFn, size = null) {
- return new class extends Dataset {
- constructor() {
- super(...arguments);
- this.size = size;
- }
- /*
- * Provide a new stream of elements. Note this will also start new streams
- * from any underlying `Dataset`s.
- */
- async iterator() {
- return iteratorFn();
- }
- }();
- }
- /**
- * Create a `Dataset` from an array of elements.
- *
- * Create a Dataset from an array of objects:
- * ```js
- * const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
- * await a.forEachAsync(e => console.log(e));
- * ```
- *
- * Create a Dataset from an array of numbers:
- * ```js
- * const a = tf.data.array([4, 5, 6]);
- * await a.forEachAsync(e => console.log(e));
- * ```
- * @param items An array of elements that will be parsed as items in a dataset.
- *
- * @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
- */
- function array(items) {
- return datasetFromIteratorFn(async () => iteratorFromItems(items), items.length);
- }
- /**
- * Create a `Dataset` by zipping together an array, dict, or nested
- * structure of `Dataset`s (and perhaps additional constants).
- * The underlying datasets must provide elements in a consistent order such that
- * they correspond.
- *
- * The number of elements in the resulting dataset is the same as the size of
- * the smallest dataset in datasets.
- *
- * The nested structure of the `datasets` argument determines the
- * structure of elements in the resulting iterator.
- *
- * Note this means that, given an array of two datasets that produce dict
- * elements, the result is a dataset that produces elements that are arrays
- * of two dicts:
- *
- * Zip an array of datasets:
- * ```js
- * console.log('Zip two datasets of objects:');
- * const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
- * const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
- * const ds3 = tf.data.zip([ds1, ds2]);
- * await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
- *
- * // If the goal is to merge the dicts in order to produce elements like
- * // {a: ..., b: ...}, this requires a second step such as:
- * console.log('Merge the objects:');
- * const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
- * await ds4.forEachAsync(e => console.log(e));
- * ```
- *
- * Zip a dict of datasets:
- * ```js
- * const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
- * const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
- * const c = tf.data.zip({c: a, d: b});
- * await c.forEachAsync(e => console.log(JSON.stringify(e)));
- * ```
- *
- * @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
- */
- function zip(datasets) {
- // manually type-check the argument for JS users
- if (!isIterable$1(datasets)) {
- throw new Error('The argument to zip() must be an object or array.');
- }
- let size;
- if (Array.isArray(datasets)) {
- for (let i = 0; i < datasets.length; i++) {
- size = size == null ? datasets[i].size :
- Math.min(size, datasets[i].size);
- }
- }
- else if (datasets instanceof Object) {
- for (const ds in datasets) {
- size = size == null ? datasets[ds].size :
- Math.min(size, datasets[ds].size);
- }
- }
- return datasetFromIteratorFn(async () => {
- const streams = await deepMapAndAwaitAll(datasets, d => {
- if (d instanceof Dataset) {
- return { value: d.iterator(), recurse: false };
- }
- else if (isIterable$1(d)) {
- return { value: null, recurse: true };
- }
- else {
- throw new Error('Leaves of the structure passed to zip() must be Datasets, ' +
- 'not primitives.');
- }
- });
- return iteratorFromZipped(streams, ZipMismatchMode.SHORTEST);
- }, size);
- }
- /**
- * A zip function for use with deepZip, passed via the columnMajorBatch call.
- *
- * Accepts an array of identically-structured nested elements and either batches
- * them (if they are primitives, numeric arrays, or Tensors) or requests
- * recursion (if not).
- */
- // tslint:disable-next-line:no-any
- function deepBatchConcat(rows) {
- if (rows === null) {
- return null;
- }
- // use the first item to decide whether to recurse or batch here.
- const exampleRow = rows[0];
- if (canTensorify(exampleRow)) {
- // rows is an array of primitives, Tensors, or arrays. Batch them.
- const value = batchConcat(rows);
- return { value, recurse: false };
- }
- // the example row is an object, so recurse into it.
- return { value: null, recurse: true };
- }
- /**
- * Assembles a list of same-shaped numbers, number arrays, or Tensors
- * into a single new Tensor where axis 0 is the batch dimension.
- */
- function batchConcat(arrays) {
- if (arrays.length === 0) {
- // We can't return an empty Tensor because we don't know the element shape.
- throw new Error('Can\'t make a batch of zero elements.');
- }
- if (arrays[0] instanceof Tensor) {
- // Input is an array of Tensors
- return stack(arrays);
- }
- else {
- // Input is a possibly-nested array of numbers.
- return tensor(arrays);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * Represents a potentially large collection of text lines.
- *
- * The results are not batched.
- */
- class TextLineDataset extends Dataset {
- /**
- * Create a `TextLineDataset`.
- *
- * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
- */
- constructor(input) {
- super();
- this.input = input;
- }
- async iterator() {
- const inputIterator = await this.input.iterator();
- const utf8Iterator = inputIterator.decodeUTF8();
- const lineIterator = utf8Iterator.split('\n').map(line => {
- // Windows/DOS format text file has extra line breaker at the end of line.
- if (line.endsWith('\r')) {
- line = line.slice(0, -1);
- }
- return line;
- });
- return lineIterator;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- const CODE_QUOTE = '"';
- const STATE_OUT = Symbol('out');
- const STATE_FIELD = Symbol('field');
- const STATE_QUOTE = Symbol('quote');
- const STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
- const STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
- /**
- * Represents a potentially large collection of delimited text records.
- *
- * The produced `TensorContainer`s each contain one key-value pair for
- * every column of the table. When a field is empty in the incoming data, the
- * resulting value is `undefined`, or throw error if it is required. Values
- * that can be parsed as numbers are emitted as type `number`, other values
- * are parsed as `string`.
- *
- * The results are not batched.
- *
- * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
- */
- class CSVDataset extends Dataset {
- /**
- * Create a `CSVDataset`.
- *
- * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
- * @param csvConfig (Optional) A CSVConfig object that contains configurations
- * of reading and decoding from CSV file(s).
- *
- * hasHeader: (Optional) A boolean value that indicates whether the first
- * row of provided CSV file is a header line with column names, and should
- * not be included in the data. Defaults to `true`.
- *
- * columnNames: (Optional) A list of strings that corresponds to
- * the CSV column names, in order. If provided, it ignores the column
- * names inferred from the header row. If not provided, infers the column
- * names from the first row of the records. If hasHeader is false and
- * columnNames is not provided, this method throws an error.
- *
- * columnConfigs: (Optional) A dictionary whose key is column names, value
- * is an object stating if this column is required, column's data type,
- * default value, and if this column is label. If provided, keys must
- * correspond to names provided in columnNames or inferred from the file
- * header lines. If isLabel is true any column, returns an array of two
- * items: the first item is a dict of features key/value pairs, the second
- * item is a dict of labels key/value pairs. If no feature is marked as
- * label, returns a dict of features only.
- *
- * configuredColumnsOnly (Optional) If true, only columns provided in
- * columnConfigs will be parsed and provided during iteration.
- *
- * delimiter (Optional) The string used to parse each line of the input
- * file. Defaults to `,`.
- */
- constructor(input, csvConfig) {
- super();
- this.input = input;
- this.hasHeader = true;
- this.fullColumnNames = null;
- this.columnNamesValidated = false;
- this.columnConfigs = null;
- this.configuredColumnsOnly = false;
- this.delimiter = ',';
- this.delimWhitespace = false;
- this.base = new TextLineDataset(input);
- if (!csvConfig) {
- csvConfig = {};
- }
- this.hasHeader = csvConfig.hasHeader === false ? false : true;
- this.fullColumnNames = csvConfig.columnNames;
- this.columnConfigs = csvConfig.columnConfigs;
- this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
- if (csvConfig.delimWhitespace) {
- assert(csvConfig.delimiter == null, () => 'Delimiter should not be provided when delimWhitespace is true.');
- this.delimWhitespace = true;
- this.delimiter = ' ';
- }
- else {
- this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ',';
- }
- }
- /**
- * Returns column names of the csv dataset. If `configuredColumnsOnly` is
- * true, return column names in `columnConfigs`. If `configuredColumnsOnly` is
- * false and `columnNames` is provided, `columnNames`. If
- * `configuredColumnsOnly` is false and `columnNames` is not provided, return
- * all column names parsed from the csv file. For example usage please go to
- * `tf.data.csv`.
- *
- * @doc {heading: 'Data', subheading: 'Classes'}
- */
- async columnNames() {
- if (!this.columnNamesValidated) {
- await this.setColumnNames();
- }
- return this.configuredColumnsOnly ? Object.keys(this.columnConfigs) :
- this.fullColumnNames;
- }
    /* Resolves `this.fullColumnNames` and validates it, then sets
     * `this.columnNamesValidated = true`. Rules:
     * 1) If `columnNames` is provided as string[], use this string[] as output
     * keys in corresponding order. The length must match the number of inferred
     * columns if `hasHeader` is true.
     * 2) If `columnNames` is not provided, parse header line as `columnNames` if
     * hasHeader is true. If `hasHeader` is false, throw an error.
     * 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
     * exist in parsed `columnNames`.
     * Also rejects duplicate column names.
     */
    async setColumnNames() {
        // Reads (and consumes for validation purposes) the header line, or null
        // when `hasHeader` is false.
        const columnNamesFromFile = await this.maybeReadHeaderLine();
        if (!this.fullColumnNames && !columnNamesFromFile) {
            // Throw an error if columnNames is not provided and no header line.
            throw new Error('Column names must be provided if there is no header line.');
        }
        else if (this.fullColumnNames && columnNamesFromFile) {
            // Check provided columnNames match header line.
            assert(columnNamesFromFile.length === this.fullColumnNames.length, () => 'The length of provided columnNames (' +
                this.fullColumnNames.length.toString() +
                ') does not match the length of the header line read from ' +
                'file (' + columnNamesFromFile.length.toString() + ').');
        }
        if (!this.fullColumnNames) {
            this.fullColumnNames = columnNamesFromFile;
        }
        // Check if there are duplicate column names: count occurrences of each
        // name, then collect any that appear more than once.
        const counts = this.fullColumnNames.reduce((countAcc, name) => {
            countAcc[name] = (countAcc[name] + 1) || 1;
            return countAcc;
        }, {});
        const duplicateNames = Object.keys(counts).filter((name) => (counts[name] > 1));
        assert(duplicateNames.length === 0, () => 'Duplicate column names found: ' + duplicateNames.toString());
        // Check if keys in columnConfigs match columnNames.
        if (this.columnConfigs) {
            for (const key of Object.keys(this.columnConfigs)) {
                const index = this.fullColumnNames.indexOf(key);
                if (index === -1) {
                    throw new Error('The key "' + key +
                        '" provided in columnConfigs does not match any of the column ' +
                        'names (' + this.fullColumnNames.toString() + ').');
                }
            }
        }
        // Mark validation complete so columnNames()/iterator() skip this work.
        this.columnNamesValidated = true;
    }
- async maybeReadHeaderLine() {
- if (this.hasHeader) {
- const iter = await this.base.iterator();
- const firstElement = await iter.next();
- if (firstElement.done) {
- throw new Error('No data was found for CSV parsing.');
- }
- const firstLine = firstElement.value;
- const headers = this.parseRow(firstLine, false);
- return headers;
- }
- else {
- return null;
- }
- }
- async iterator() {
- if (!this.columnNamesValidated) {
- await this.setColumnNames();
- }
- let lines = await this.base.iterator();
- if (this.hasHeader) {
- // We previously read the first line to get the columnNames.
- // Now that we're providing data, skip it.
- lines = lines.skip(1);
- }
- return lines.map(x => this.makeDataElement(x));
- }
    /**
     * Converts one CSV data line into a dataset element.
     *
     * Each cell is parsed according to the matching entry in
     * `this.columnConfigs` (dtype, default, required, isLabel); cells without a
     * config are auto-typed (number when numeric, string otherwise). Columns
     * flagged `isLabel` go into a `ys` object, the rest into `xs`/features.
     *
     * @param line One CSV data line (no line terminator).
     * @returns Either a plain features object, or `{xs: features, ys: labels}`
     *     when at least one column is configured as a label.
     * @throws Error when a required or label column is empty and has no default.
     */
    makeDataElement(line) {
        const values = this.parseRow(line);
        const features = {};
        const labels = {};
        for (let i = 0; i < this.fullColumnNames.length; i++) {
            const key = this.fullColumnNames[i];
            const config = this.columnConfigs ? this.columnConfigs[key] : null;
            if (this.configuredColumnsOnly && !config) {
                // This column is not selected.
                continue;
            }
            else {
                const value = values[i];
                let parsedValue = null;
                if (value === '') {
                    // If default value is provided, use it. If default value is not
                    // provided, set as undefined.
                    if (config && config.default !== undefined) {
                        parsedValue = config.default;
                    }
                    else if (config && (config.required || config.isLabel)) {
                        throw new Error(`Required column ${key} is empty in this line: ${line}`);
                    }
                    else {
                        parsedValue = undefined;
                    }
                }
                else {
                    // A value is present, so parse it based on type
                    const valueAsNum = Number(value);
                    if (isNaN(valueAsNum)) {
                        // The value is a string and this column is declared as boolean
                        // in config, parse it as boolean.
                        if (config && config.dtype === 'bool') {
                            parsedValue = this.getBoolean(value);
                        }
                        else {
                            // Set value as string
                            parsedValue = value;
                        }
                    }
                    else if (!config || !config.dtype) {
                        // If this value is a number and no type config is provided, return
                        // it as number.
                        parsedValue = valueAsNum;
                    }
                    else {
                        // If this value is a number and data type is provided, parse it
                        // according to provided data type.
                        switch (config.dtype) {
                            case 'float32':
                                parsedValue = valueAsNum;
                                break;
                            case 'int32':
                                // NOTE(review): Math.floor, not truncation — negative
                                // non-integers round toward -Infinity.
                                parsedValue = Math.floor(valueAsNum);
                                break;
                            case 'bool':
                                // Numeric strings reach here too, e.g. '1' -> 1, '0' -> 0.
                                parsedValue = this.getBoolean(value);
                                break;
                            default:
                                parsedValue = valueAsNum;
                        }
                    }
                }
                // Check if this column is label.
                (config && config.isLabel) ? labels[key] = parsedValue :
                    features[key] = parsedValue;
            }
        }
        // If label exists, return an object of features and labels as {xs:features,
        // ys:labels}, otherwise return features only.
        if (Object.keys(labels).length === 0) {
            return features;
        }
        else {
            return { xs: features, ys: labels };
        }
    }
- getBoolean(value) {
- if (value === '1' || value.toLowerCase() === 'true') {
- return 1;
- }
- else {
- return 0;
- }
- }
    // adapted from https://beta.observablehq.com/@mbostock/streaming-csv
    /**
     * Splits one CSV line into an array of raw cell strings, honoring quoted
     * fields via a small state machine (states STATE_OUT / STATE_FIELD /
     * STATE_QUOTE / STATE_QUOTE_AFTER_QUOTE / STATE_WITHIN_QUOTE_IN_QUOTE and
     * quote character CODE_QUOTE are module-level constants declared elsewhere
     * in this file).
     *
     * The surrounding quotes of a quoted cell are stripped from the result.
     * When the delimiter is a space and `delimWhitespace` is set, consecutive
     * delimiters collapse instead of producing empty cells.
     *
     * @param line A single CSV line (no line terminator).
     * @param validateElementCount When true, throws if the parsed cell count
     *     differs from `this.fullColumnNames.length`.
     * @returns string[] of cell values.
     */
    parseRow(line, validateElementCount = true) {
        const result = [];
        let readOffset = 0;
        const readLength = line.length;
        let currentState = STATE_OUT;
        // Goes through the line to parse quote.
        for (let i = 0; i < readLength; i++) {
            switch (currentState) {
                // Before enter a new field
                case STATE_OUT:
                    switch (line.charAt(i)) {
                        // Enter a quoted field
                        case CODE_QUOTE:
                            readOffset = i + 1;
                            currentState = STATE_QUOTE;
                            break;
                        // Read an empty field
                        case this.delimiter:
                            readOffset = i + 1;
                            // If delimiter is white space and configured to collapse
                            // multiple white spaces, ignore this white space.
                            if (this.delimiter === ' ' && this.delimWhitespace) {
                                break;
                            }
                            result.push('');
                            currentState = STATE_OUT;
                            break;
                        // Enter an unquoted field
                        default:
                            currentState = STATE_FIELD;
                            readOffset = i;
                            break;
                    }
                    break;
                // In an unquoted field
                case STATE_FIELD:
                    switch (line.charAt(i)) {
                        // Exit an unquoted field, add it to result
                        case this.delimiter:
                            result.push(line.substring(readOffset, i));
                            currentState = STATE_OUT;
                            readOffset = i + 1;
                            break;
                        default:
                    }
                    break;
                // In a quoted field
                case STATE_QUOTE:
                    switch (line.charAt(i)) {
                        // Read a quote after a quote
                        case CODE_QUOTE:
                            currentState = STATE_QUOTE_AFTER_QUOTE;
                            break;
                        default:
                    }
                    break;
                // This state means it's right after a second quote in a field
                case STATE_QUOTE_AFTER_QUOTE:
                    switch (line.charAt(i)) {
                        // Finished a quoted field
                        case this.delimiter:
                            // Exclude the closing quote (i - 1) from the cell value.
                            result.push(line.substring(readOffset, i - 1));
                            currentState = STATE_OUT;
                            readOffset = i + 1;
                            break;
                        // Finished a quoted part in a quoted field
                        case CODE_QUOTE:
                            currentState = STATE_QUOTE;
                            break;
                        // In a quoted part in a quoted field
                        default:
                            currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
                            break;
                    }
                    break;
                case STATE_WITHIN_QUOTE_IN_QUOTE:
                    switch (line.charAt(i)) {
                        // Exit a quoted part in a quoted field
                        case CODE_QUOTE:
                            currentState = STATE_QUOTE;
                            break;
                        default:
                    }
                    break;
                default:
            }
        }
        // Adds last item based on if it is quoted.
        if (currentState === STATE_QUOTE_AFTER_QUOTE) {
            // Drop the trailing closing quote of a quoted final cell.
            result.push(line.substring(readOffset, readLength - 1));
        }
        else {
            result.push(line.substring(readOffset));
        }
        // Check if each row has the same number of elements as column names.
        if (validateElementCount && result.length !== this.fullColumnNames.length) {
            throw new Error(`Invalid row in csv file. Should have ${this.fullColumnNames.length} elements in a row, but got ${result}`);
        }
        return result;
    }
- }
- // TODO(soergel): add more basic datasets for parity with tf.data
- // tf.data.FixedLengthRecordDataset()
- // tf.data.TFRecordDataset()
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
/**
 * Provide a stream of tensors from microphone audio stream. The tensors are
 * representing audio data as frequency-domain spectrogram generated with
 * browser's native FFT. Tensors representing time-domain waveform is available
 * based on configuration. Only works in browser environment.
 */
class MicrophoneIterator extends LazyIterator {
    /**
     * @param microphoneConfig Settings: fftSize (power of 2 in [2^4, 2^14],
     *     default 1024), numFramesPerSpectrogram (default 43), sampleRateHz,
     *     columnTruncateLength, audioTrackConstraints, smoothingTimeConstant,
     *     includeSpectrogram (default true), includeWaveform (default false).
     * @throws Error for an invalid fftSize, or when both output kinds are
     *     disabled.
     */
    constructor(microphoneConfig) {
        super();
        this.microphoneConfig = microphoneConfig;
        this.isClosed = false;
        this.fftSize = microphoneConfig.fftSize || 1024;
        const fftSizeLog2 = Math.log2(this.fftSize);
        // fftSize must be an exact power of 2 between 2^4 and 2^14.
        if (this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 ||
            !Number.isInteger(fftSizeLog2)) {
            throw new Error(`Invalid fftSize: it must be a power of 2 between ` +
                `2 to 4 and 2 to 14, but got ${this.fftSize}`);
        }
        this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
        this.sampleRateHz = microphoneConfig.sampleRateHz;
        this.columnTruncateLength =
            microphoneConfig.columnTruncateLength || this.fftSize;
        this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
        this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
        // Spectrogram output defaults to on; waveform output defaults to off.
        this.includeSpectrogram =
            microphoneConfig.includeSpectrogram === false ? false : true;
        this.includeWaveform =
            microphoneConfig.includeWaveform === true ? true : false;
        if (!this.includeSpectrogram && !this.includeWaveform) {
            throw new Error('Both includeSpectrogram and includeWaveform are false. ' +
                'At least one type of data should be returned.');
        }
    }
    summary() {
        return `microphone`;
    }
    // Construct a MicrophoneIterator and start the audio stream.
    static async create(microphoneConfig = {}) {
        if (env().get('IS_NODE')) {
            throw new Error('microphone API is only supported in browser environment.');
        }
        const microphoneIterator = new MicrophoneIterator(microphoneConfig);
        // Call async function start() to initialize the audio stream.
        await microphoneIterator.start();
        return microphoneIterator;
    }
    // Start the audio stream and FFT.
    async start() {
        try {
            this.stream = await navigator.mediaDevices.getUserMedia({
                audio: this.audioTrackConstraints == null ? true :
                    this.audioTrackConstraints,
                video: false
            });
        }
        catch (e) {
            // Fixed: this is the audio stream; the old message said "video
            // stream" (copy-paste from the webcam iterator).
            throw new Error(`Error thrown while initializing audio stream: ${e.message}`);
        }
        if (!this.stream) {
            throw new Error('Could not obtain audio from microphone.');
        }
        const ctxConstructor =
            // tslint:disable-next-line:no-any
            window.AudioContext || window.webkitAudioContext;
        this.audioContext = new ctxConstructor();
        if (!this.sampleRateHz) {
            // If sample rate is not provided, use the available sample rate on
            // device.
            this.sampleRateHz = this.audioContext.sampleRate;
        }
        else if (this.audioContext.sampleRate !== this.sampleRateHz) {
            throw new Error(`Mismatch in sampling rate: ` +
                `Expected: ${this.sampleRateHz}; ` +
                `Actual: ${this.audioContext.sampleRate}`);
        }
        const streamSource = this.audioContext.createMediaStreamSource(this.stream);
        this.analyser = this.audioContext.createAnalyser();
        // AnalyserNode.fftSize counts time-domain samples, producing
        // fftSize / 2 frequency bins; doubling keeps `this.fftSize` bins.
        this.analyser.fftSize = this.fftSize * 2;
        this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
        streamSource.connect(this.analyser);
        this.freqData = new Float32Array(this.fftSize);
        this.timeData = new Float32Array(this.fftSize);
        return;
    }
    /**
     * Collects one spectrogram/waveform sample from the stream.
     * @returns {value, done} where value holds 'spectrogram' and 'waveform'
     *     tensors (each possibly undefined per configuration); done is true
     *     only after stop().
     */
    async next() {
        if (this.isClosed) {
            return { value: null, done: true };
        }
        let spectrogramTensor;
        let waveformTensor;
        const audioDataQueue = await this.getAudioData();
        if (this.includeSpectrogram) {
            const freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
            spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
        }
        if (this.includeWaveform) {
            const timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
            waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
        }
        return {
            value: { 'spectrogram': spectrogramTensor, 'waveform': waveformTensor },
            done: false
        };
    }
    // Capture one result from the audio stream, and extract the value from
    // iterator.next() result.
    async capture() {
        return (await this.next()).value;
    }
    // Polls the analyser on an interval until numFrames frames are gathered.
    async getAudioData() {
        const freqDataQueue = [];
        const timeDataQueue = [];
        let currentFrames = 0;
        return new Promise(resolve => {
            const intervalID = setInterval(() => {
                if (this.includeSpectrogram) {
                    this.analyser.getFloatFrequencyData(this.freqData);
                    // If the audio stream is initializing, return empty queue.
                    // NOTE(review): this early resolve does not clearInterval;
                    // the interval keeps ticking until numFrames is reached
                    // (later resolve calls are no-ops) — confirm intended.
                    if (this.freqData[0] === -Infinity) {
                        resolve({ freqDataQueue, timeDataQueue });
                    }
                    freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
                }
                if (this.includeWaveform) {
                    this.analyser.getFloatTimeDomainData(this.timeData);
                    timeDataQueue.push(this.timeData.slice());
                }
                // Clean interval and return when all frames have been collected
                if (++currentFrames === this.numFrames) {
                    clearInterval(intervalID);
                    resolve({ freqDataQueue, timeDataQueue });
                }
            }, this.fftSize / this.sampleRateHz * 1e3);
        });
    }
    // Stop the audio stream and pause the iterator.
    stop() {
        if (!this.isClosed) {
            this.isClosed = true;
            this.analyser.disconnect();
            this.audioContext.close();
            if (this.stream != null && this.stream.getTracks().length > 0) {
                this.stream.getTracks()[0].stop();
            }
        }
    }
    // Override toArray() function to prevent collecting.
    toArray() {
        throw new Error('Can not convert infinite audio stream to array.');
    }
    // Return audio sampling rate in Hz
    getSampleRate() {
        return this.sampleRateHz;
    }
    // Concatenates a queue of equally-sized Float32Arrays into one array.
    flattenQueue(queue) {
        const frameSize = queue[0].length;
        const freqData = new Float32Array(queue.length * frameSize);
        queue.forEach((data, i) => freqData.set(data, i * frameSize));
        return freqData;
    }
    // Builds a tensor of `shape` from audio data, zero-padding at the front
    // when the data is shorter than the shape requires.
    getTensorFromAudioDataArray(freqData, shape) {
        const vals = new Float32Array(sizeFromShape(shape));
        // If the data is less than the output shape, the rest is padded with zeros.
        vals.set(freqData, vals.length - freqData.length);
        return tensor(vals, shape);
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
/**
 * Provide a stream of image tensors from webcam video stream. Only works in
 * browser environment.
 */
class WebcamIterator extends LazyIterator {
    /**
     * @param webcamVideoElement The HTMLVideoElement backing the stream; its
     *     width/height drive the getUserMedia request and resize decisions.
     * @param webcamConfig Settings: resizeWidth/resizeHeight, centerCrop,
     *     facingMode, deviceId.
     */
    constructor(webcamVideoElement, webcamConfig) {
        super();
        this.webcamVideoElement = webcamVideoElement;
        this.webcamConfig = webcamConfig;
        // Closed until start() succeeds.
        this.isClosed = true;
        this.resize = false;
        if (this.needToResize()) {
            this.resize = true;
            this.cropSize =
                [this.webcamConfig.resizeHeight, this.webcamConfig.resizeWidth];
            this.cropBoxInd = tensor1d([0], 'int32');
            if (this.webcamConfig.centerCrop) {
                // Calculate the box based on resizing shape: a normalized
                // [y1, x1, y2, x2] box centered in the source frame.
                const widthCroppingRatio = this.webcamConfig.resizeWidth * 1.0 / this.webcamVideoElement.width;
                const heightCroppingRatio = this.webcamConfig.resizeHeight * 1.0 /
                    this.webcamVideoElement.height;
                const widthCropStart = (1 - widthCroppingRatio) / 2;
                const heightCropStart = (1 - heightCroppingRatio) / 2;
                const widthCropEnd = widthCropStart + widthCroppingRatio;
                const heightCropEnd = heightCroppingRatio + heightCropStart;
                this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
            }
            else {
                // No center crop: use the full frame, then resize.
                this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
            }
        }
    }
    summary() {
        return `webcam`;
    }
    // Construct a WebcamIterator and start its video stream.
    static async create(webcamVideoElement, webcamConfig = {}) {
        if (env().get('IS_NODE')) {
            throw new Error('tf.data.webcam is only supported in browser environment.');
        }
        if (!webcamVideoElement) {
            // If webcam video element is not provided, create a hidden video element
            // with provided width and height.
            webcamVideoElement = document.createElement('video');
            if (!webcamConfig.resizeWidth || !webcamConfig.resizeHeight) {
                throw new Error('Please provide webcam video element, or resizeWidth and ' +
                    'resizeHeight to create a hidden video element.');
            }
            webcamVideoElement.width = webcamConfig.resizeWidth;
            webcamVideoElement.height = webcamConfig.resizeHeight;
        }
        const webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig);
        // Call async function to initialize the video stream.
        await webcamIterator.start();
        return webcamIterator;
    }
    // Async function to start video stream.
    async start() {
        if (this.webcamConfig.facingMode) {
            assert((this.webcamConfig.facingMode === 'user') ||
                (this.webcamConfig.facingMode === 'environment'), () => `Invalid webcam facing mode: ${this.webcamConfig.facingMode}. ` +
                `Please provide 'user' or 'environment'`);
        }
        try {
            this.stream = await navigator.mediaDevices.getUserMedia({
                video: {
                    deviceId: this.webcamConfig.deviceId,
                    facingMode: this.webcamConfig.facingMode ?
                        this.webcamConfig.facingMode :
                        'user',
                    width: this.webcamVideoElement.width,
                    height: this.webcamVideoElement.height
                }
            });
        }
        catch (e) {
            // Modify the error message but leave the stack trace intact
            e.message = `Error thrown while initializing video stream: ${e.message}`;
            throw e;
        }
        if (!this.stream) {
            throw new Error('Could not obtain video from webcam.');
        }
        // Older browsers may not have srcObject
        try {
            this.webcamVideoElement.srcObject = this.stream;
        }
        catch (error) {
            console.log(error);
            this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
        }
        // Start the webcam video stream
        // NOTE(review): play() returns a Promise in modern browsers; it is not
        // awaited here — a rejection (e.g. autoplay policy) would be unhandled.
        this.webcamVideoElement.play();
        this.isClosed = false;
        return new Promise(resolve => {
            // Add event listener to make sure the webcam has been fully initialized.
            this.webcamVideoElement.onloadedmetadata = () => {
                resolve();
            };
        });
    }
    /**
     * Grabs the current video frame as a tensor.
     * @returns {value, done}: the (possibly cropped/resized) frame tensor, or
     *     {value: null, done: true} once the iterator is stopped.
     */
    async next() {
        if (this.isClosed) {
            return { value: null, done: true };
        }
        let img;
        try {
            img = fromPixels(this.webcamVideoElement);
        }
        catch (e) {
            throw new Error(`Error thrown converting video to pixels: ${JSON.stringify(e)}`);
        }
        if (this.resize) {
            try {
                return { value: this.cropAndResizeFrame(img), done: false };
            }
            catch (e) {
                throw new Error(`Error thrown cropping the video: ${e.message}`);
            }
            finally {
                // The raw frame is consumed by cropAndResizeFrame; free it.
                img.dispose();
            }
        }
        else {
            return { value: img, done: false };
        }
    }
    needToResize() {
        // If resizeWidth and resizeHeight are provided, and different from the
        // width and height of original HTMLVideoElement, then resizing and cropping
        // is required.
        if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight &&
            (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth ||
                this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
            return true;
        }
        return false;
    }
    // Cropping and resizing each frame based on config
    cropAndResizeFrame(img) {
        return tidy(() => {
            const expandedImage = img.toFloat().expandDims(0);
            let resizedImage;
            resizedImage = image.cropAndResize(expandedImage, this.cropBox, this.cropBoxInd, this.cropSize, 'bilinear');
            // Extract image from batch cropping.
            const shape = resizedImage.shape;
            return resizedImage.reshape(shape.slice(1));
        });
    }
    // Capture one frame from the video stream, and extract the value from
    // iterator.next() result.
    async capture() {
        return (await this.next()).value;
    }
    // Stop the video stream and pause webcam iterator.
    stop() {
        const tracks = this.stream.getTracks();
        tracks.forEach(track => track.stop());
        try {
            this.webcamVideoElement.srcObject = null;
        }
        catch (error) {
            console.log(error);
            this.webcamVideoElement.src = null;
        }
        this.isClosed = true;
    }
    // Override toArray() function to prevent collecting.
    toArray() {
        throw new Error('Can not convert infinite video stream to array.');
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
/**
 * Represents a data source readable as a stream of binary data chunks.
 *
 * Because `Dataset`s can be read repeatedly (via `Dataset.iterator()`), this
 * provides a means to repeatedly create streams from the underlying data
 * sources.
 *
 * This is an abstract marker base class; concrete sources implement the
 * chunk-iterator creation.
 */
class DataSource {
}
- // TODO(soergel): consider convenience factory functions here
- // in combination with chainable source->dataset above, e.g.:
- // tf.data.url(...).asCsvDataset().shuffle().batch()
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
class StringIterator extends LazyIterator {
    /**
     * Splits a string stream on a given separator.
     *
     * The incoming chunk boundaries carry no semantic meaning; conceptually
     * the upstream is the concatenation of its elements. The outgoing stream
     * yields the pieces a standard string split() would produce — even pieces
     * that spanned incoming chunks — with separators excluded.
     *
     * A typical usage is splitting a text file (a stream with arbitrary chunk
     * boundaries) into lines.
     *
     * @param upstream A readable stream of strings that can be treated as
     *     concatenated.
     * @param separator A character to split on.
     */
    split(separator) {
        const splitStream = new SplitIterator(this, separator);
        return splitStream;
    }
}
- // ============================================================================
- // The following private classes serve to implement the chainable methods
- // on StringIterator. Unfortunately they can't be placed in separate files, due
- // to resulting trouble with circular imports.
- // ============================================================================
- // We wanted multiple inheritance, e.g.
- // class SplitIterator extends QueueIterator, StringIterator
- // but the TypeScript mixin approach is a bit hacky, so we take this adapter
- // approach instead.
// Adapter exposing SplitIteratorImpl (a OneToManyIterator) through the
// StringIterator interface, since multiple inheritance is unavailable.
class SplitIterator extends StringIterator {
    constructor(upstream, separator) {
        super();
        this.upstream = upstream;
        // All splitting work is delegated to the impl.
        this.impl = new SplitIteratorImpl(upstream, separator);
    }
    summary() {
        return this.impl.summary();
    }
    async next() {
        return this.impl.next();
    }
}
/**
 * Splits an upstream string stream on a separator, buffering any partial
 * trailing piece of each chunk until the next chunk (or end of stream).
 */
class SplitIteratorImpl extends OneToManyIterator {
    constructor(upstream, separator) {
        super();
        this.upstream = upstream;
        this.separator = separator;
        // A partial string at the end of an upstream chunk
        this.carryover = '';
    }
    summary() {
        return `${this.upstream.summary()} -> Split('${this.separator}')`;
    }
    async pump() {
        const chunkResult = await this.upstream.next();
        if (chunkResult.done) {
            if (this.carryover === '') {
                return false;
            }
            // Pretend that the pump succeeded in order to emit the small last batch.
            // The next pump() call will actually fail.
            this.outputQueue.push(this.carryover);
            this.carryover = '';
            return true;
        }
        const pieces = chunkResult.value.split(this.separator);
        // Note the behavior: " ab ".split(' ') === ['', 'ab', '']
        // Thus the carryover may be '' if the separator falls on a chunk
        // boundary; this produces the correct result.
        pieces[0] = this.carryover + pieces[0];
        // Emit every piece except the last, which may be incomplete.
        const lastIndex = pieces.length - 1;
        for (let i = 0; i < lastIndex; i++) {
            this.outputQueue.push(pieces[i]);
        }
        this.carryover = pieces[lastIndex];
        return true;
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
class ByteChunkIterator extends LazyIterator {
    /**
     * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
     *
     * The byte arrays produced by this ByteChunkIterator are interpreted as
     * concatenated. No assumptions are made about the boundaries of the
     * incoming chunks, so a multi-byte UTF8 encoding of a character may span
     * the boundary between chunks. This naturally happens, for instance, when
     * reading fixed-size byte arrays from a file.
     */
    decodeUTF8() {
        return new Utf8Iterator(this);
    }
}
- // ============================================================================
- // The following private classes serve to implement the chainable methods
- // on ByteChunkIterator. Unfortunately they can't be placed in separate files,
- // due to resulting trouble with circular imports.
- // ============================================================================
- // We wanted multiple inheritance, e.g.
- // class Utf8Iterator extends QueueIterator, StringIterator
- // but the TypeScript mixin approach is a bit hacky, so we take this adapter
- // approach instead.
// Adapter exposing Utf8IteratorImpl (a OneToManyIterator) through the
// StringIterator interface, since multiple inheritance is unavailable.
class Utf8Iterator extends StringIterator {
    constructor(upstream) {
        super();
        this.upstream = upstream;
        // All decoding work is delegated to the impl.
        this.impl = new Utf8IteratorImpl(upstream);
    }
    summary() {
        return this.impl.summary();
    }
    async next() {
        return this.impl.next();
    }
}
/**
 * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
 *
 * Incoming byte-array boundaries may split a multi-byte UTF8 character, so
 * incomplete character bytes at a chunk's end must carry over into the next
 * chunk before decoding. The native decoders handle that automatically:
 * TextDecoder (browser, with {stream: true}) and string_decoder (node).
 *
 * In a machine-learning input pipeline, UTF8 decoding is needed to parse text
 * files of training examples or prediction requests (e.g., CSV or JSON). The
 * built-in FileReader.readAsText() decoding cannot be used here because this
 * is a streaming context, which FileReader does not support.
 *
 * @param upstream A `LazyIterator` of `Uint8Arrays` containing UTF8-encoded
 *   text, interpreted as concatenated; chunk boundaries may fall inside a
 *   multi-byte character.
 */
class Utf8IteratorImpl extends OneToManyIterator {
    constructor(upstream) {
        super();
        this.upstream = upstream;
        // Pick the platform's native streaming decoder once, up front.
        if (env().get('IS_BROWSER')) {
            this.decoder = new TextDecoder('utf-8');
        }
        else {
            // tslint:disable-next-line:no-require-imports
            const { StringDecoder } = require('string_decoder');
            this.decoder = new StringDecoder('utf8');
        }
    }
    summary() {
        return `${this.upstream.summary()} -> Utf8`;
    }
    async pump() {
        const chunkResult = await this.upstream.next();
        if (chunkResult.done) {
            return false;
        }
        const chunk = chunkResult.value;
        // Both decoders retain trailing partial-character bytes internally.
        const text = env().get('IS_BROWSER') ?
            this.decoder.decode(chunk, { stream: true }) :
            this.decoder.write(Buffer.from(chunk.buffer));
        this.outputQueue.push(text);
        return true;
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
/**
 * Provide a stream of chunks from a File, Blob, or Uint8Array.
 * @param file The source File, Blob or Uint8Array.
 * @param options Optional settings controlling file reading: `offset` (start
 *     byte, default 0) and `chunkSize` (bytes per chunk, default 1MB).
 * @returns a lazy Iterator of Uint8Arrays containing sequential chunks of the
 *   input File, Blob or Uint8Array.
 */
class FileChunkIterator extends ByteChunkIterator {
    constructor(file, options = {}) {
        super();
        this.file = file;
        this.options = options;
        // File/Blob only exist in the browser; Uint8Array works everywhere.
        assert((file instanceof Uint8Array) ||
            (env().get('IS_BROWSER') ?
                (file instanceof File || file instanceof Blob) :
                false), () => 'FileChunkIterator only supports File, Blob and Uint8Array ' +
            'right now.');
        this.offset = options.offset || 0;
        // default 1MB chunk has tolerable perf on large files
        this.chunkSize = options.chunkSize || 1024 * 1024;
    }
    summary() {
        return `FileChunks ${this.file}`;
    }
    /**
     * Reads the next chunk of up to `chunkSize` bytes starting at `offset`.
     * @returns {value: Uint8Array, done: false}, or {value: null, done: true}
     *     once offset has passed the end of the source.
     */
    async next() {
        // Uint8Array length is byteLength; File/Blob length is size.
        if (this.offset >= ((this.file instanceof Uint8Array) ?
            this.file.byteLength :
            this.file.size)) {
            return { value: null, done: true };
        }
        const chunk = new Promise((resolve, reject) => {
            const end = this.offset + this.chunkSize;
            if (this.file instanceof Uint8Array) {
                // Note if end > this.uint8Array.byteLength, we just get a small last
                // chunk.
                resolve(new Uint8Array(this.file.slice(this.offset, end)));
            }
            else {
                // This branch assumes that this.file type is File or Blob, which
                // means it is in the browser environment.
                // TODO(soergel): is this a performance issue?
                const fileReader = new FileReader();
                fileReader.onload = (event) => {
                    let data = fileReader.result;
                    // Not sure we can trust the return type of
                    // FileReader.readAsArrayBuffer See e.g.
                    // https://github.com/node-file-api/FileReader/issues/2
                    if (data instanceof ArrayBuffer) {
                        data = new Uint8Array(data);
                    }
                    if (!(data instanceof Uint8Array)) {
                        return reject(new TypeError('FileReader returned unknown type.'));
                    }
                    resolve(data);
                };
                fileReader.onabort = (event) => {
                    return reject(new Error('Aborted'));
                };
                fileReader.onerror = (event) => {
                    return reject(new Error(event.type));
                };
                // TODO(soergel): better handle onabort, onerror
                // Note if end > this.file.size, we just get a small last chunk.
                const slice = this.file.slice(this.offset, end);
                // We can't use readAsText here (even if we know the file is text)
                // because the slice boundary may fall within a multi-byte character.
                fileReader.readAsArrayBuffer(slice);
            }
            // Advance the cursor eagerly (inside the executor, which runs
            // synchronously), before the async read settles.
            this.offset = end;
        });
        return { value: (await chunk), done: false };
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * Provide a stream of chunks from a URL.
- *
- * Note this class first downloads the entire file into memory before providing
- * the first element from the stream. This is because the Fetch API does not
- * yet reliably provide a reader stream for the response body.
- */
- async function urlChunkIterator(url, options = {}) {
- let urlString;
- let requestInit;
- if ((typeof url) === 'string') {
- urlString = url;
- }
- else {
- urlString = url.url;
- requestInit = getRequestInitFromRequest(url);
- }
- const response = await fetch$1(urlString, requestInit);
- if (response.ok) {
- const uint8Array = new Uint8Array(await response.arrayBuffer());
- return new FileChunkIterator(uint8Array, options);
- }
- else {
- throw new Error(response.statusText);
- }
- }
- // Generate RequestInit from Request to match tf.util.fetch signature.
- const getRequestInitFromRequest = (request) => {
- const init = {
- method: request.method,
- headers: request.headers,
- body: request.body,
- mode: request.mode,
- credentials: request.credentials,
- cache: request.cache,
- redirect: request.redirect,
- referrer: request.referrer,
- integrity: request.integrity,
- };
- return init;
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- // Skip tslint any type check cause this method is aiming to check type of
- // input.
- // tslint:disable-next-line:no-any
- function isLocalPath(source) {
- return (typeof source === 'string') && source.substr(0, 7) === 'file://';
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * Represents a file, blob, or Uint8Array readable as a stream of binary data
- * chunks.
- */
- class FileDataSource extends DataSource {
- /**
- * Create a `FileDataSource`.
- *
- * @param input Local file path, or `File`/`Blob`/`Uint8Array` object to
- * read. Local file only works in node environment.
- * @param options Options passed to the underlying `FileChunkIterator`s,
- * such as {chunksize: 1024}.
- */
- constructor(input, options = {}) {
- super();
- this.input = input;
- this.options = options;
- }
- async iterator() {
- if (isLocalPath(this.input) && env().get('IS_NODE')) {
- // tslint:disable-next-line:no-require-imports
- const fs = require('fs');
- this.input = fs.readFileSync(this.input.substr(7));
- }
- // TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
- // with file in browser.
- return new FileChunkIterator(this.input, this.options);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /*
- * Represents a URL readable as a stream of binary data chunks.
- */
- class URLDataSource extends DataSource {
- /**
- * Create a `URLDataSource`.
- *
- * @param url A source URL string, or a `Request` object.
- * @param options Options passed to the underlying `FileChunkIterator`s,
- * such as {chunksize: 1024}.
- */
- constructor(url, fileOptions = {}) {
- super();
- this.url = url;
- this.fileOptions = fileOptions;
- }
- // TODO(soergel): provide appropriate caching options. Currently this
- // will download the URL anew for each call to iterator(). Since we have
- // to treat the downloaded file as a blob/buffer anyway, we may as well retain
- // it-- but that raises GC issues. Also we may want a persistent disk cache.
- async iterator() {
- if (isLocalPath(this.url)) {
- return (new FileDataSource(this.url, this.fileOptions))
- .iterator();
- }
- else {
- return urlChunkIterator(this.url, this.fileOptions);
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
- /**
- * Create a `CSVDataset` by reading and decoding CSV file(s) from provided URL
- * or local path if it's in Node environment.
- *
- * Note: If isLabel in columnConfigs is `true` for at least one column, the
- * element in returned `CSVDataset` will be an object of
- * `{xs:features, ys:labels}`: xs is a dict of features key/value pairs, ys
- * is a dict of labels key/value pairs. If no column is marked as label,
- * returns a dict of features only.
- *
- * ```js
- * const csvUrl =
- * 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
- *
- * async function run() {
- * // We want to predict the column "medv", which represents a median value of
- * // a home (in $1000s), so we mark it as a label.
- * const csvDataset = tf.data.csv(
- * csvUrl, {
- * columnConfigs: {
- * medv: {
- * isLabel: true
- * }
- * }
- * });
- *
- * // Number of features is the number of column names minus one for the label
- * // column.
- * const numOfFeatures = (await csvDataset.columnNames()).length - 1;
- *
- * // Prepare the Dataset for training.
- * const flattenedDataset =
- * csvDataset
- * .map(({xs, ys}) =>
- * {
- * // Convert xs(features) and ys(labels) from object form (keyed by
- * // column name) to array form.
- * return {xs:Object.values(xs), ys:Object.values(ys)};
- * })
- * .batch(10);
- *
- * // Define the model.
- * const model = tf.sequential();
- * model.add(tf.layers.dense({
- * inputShape: [numOfFeatures],
- * units: 1
- * }));
- * model.compile({
- * optimizer: tf.train.sgd(0.000001),
- * loss: 'meanSquaredError'
- * });
- *
- * // Fit the model using the prepared Dataset
- * return model.fitDataset(flattenedDataset, {
- * epochs: 10,
- * callbacks: {
- * onEpochEnd: async (epoch, logs) => {
- * console.log(epoch + ':' + logs.loss);
- * }
- * }
- * });
- * }
- *
- * await run();
- * ```
- *
- * @param source URL or local path to get CSV file. If it's a local path, it
- * must have prefix `file://` and it only works in node environment.
- * @param csvConfig (Optional) A CSVConfig object that contains configurations
- * of reading and decoding from CSV file(s).
- *
- * @doc {
- * heading: 'Data',
- * subheading: 'Creation',
- * namespace: 'data',
- * configParamIndices: [1]
- * }
- */
- function csv(source, csvConfig = {}) {
- return new CSVDataset(new URLDataSource(source), csvConfig);
- }
- /**
- * Create a `Dataset` that produces each element by calling a provided function.
- *
- * Note that repeated iterations over this `Dataset` may produce different
- * results, because the function will be called anew for each element of each
- * iteration.
- *
- * Also, beware that the sequence of calls to this function may be out of order
- * in time with respect to the logical order of the Dataset. This is due to the
- * asynchronous lazy nature of stream processing, and depends on downstream
- * transformations (e.g. .shuffle()). If the provided function is pure, this is
- * no problem, but if it is a closure over a mutable state (e.g., a traversal
- * pointer), then the order of the produced elements may be scrambled.
- *
- * ```js
- * let i = -1;
- * const func = () =>
- * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
- * const ds = tf.data.func(func);
- * await ds.forEachAsync(e => console.log(e));
- * ```
- *
- * @param f A function that produces one data element on each call.
- */
- function func(f) {
- const iter = iteratorFromFunction(f);
- return datasetFromIteratorFn(async () => iter);
- }
- /**
- * Create a `Dataset` that produces each element from provided JavaScript
- * generator, which is a function*
- * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
- * or a function that returns an
- * iterator
- * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions).
- *
- * The returned iterator should have `.next()` function that returns element in
- * format of `{value: TensorContainer, done:boolean}`.
- *
- * Example of creating a dataset from an iterator factory:
- * ```js
- * function makeIterator() {
- * const numElements = 10;
- * let index = 0;
- *
- * const iterator = {
- * next: () => {
- * let result;
- * if (index < numElements) {
- * result = {value: index, done: false};
- * index++;
- * return result;
- * }
- * return {value: index, done: true};
- * }
- * };
- * return iterator;
- * }
- * const ds = tf.data.generator(makeIterator);
- * await ds.forEachAsync(e => console.log(e));
- * ```
- *
- * Example of creating a dataset from a generator:
- * ```js
- * function* dataGenerator() {
- * const numElements = 10;
- * let index = 0;
- * while (index < numElements) {
- * const x = index;
- * index++;
- * yield x;
- * }
- * }
- *
- * const ds = tf.data.generator(dataGenerator);
- * await ds.forEachAsync(e => console.log(e));
- * ```
- *
- * @param generator A Javascript generator function that returns a JavaScript
- * iterator.
- *
- * @doc {
- * heading: 'Data',
- * subheading: 'Creation',
- * namespace: 'data',
- * configParamIndices: [1]
- * }
- */
- function generator(generator) {
- return datasetFromIteratorFn(async () => {
- const gen = await generator();
- return iteratorFromFunction(() => gen.next());
- });
- }
- /**
- * Create an iterator that generate `Tensor`s from webcam video stream. This API
- * only works in Browser environment when the device has webcam.
- *
- * Note: this code snippet only works when the device has a webcam. It will
- * request permission to open the webcam when running.
- * ```js
- * const videoElement = document.createElement('video');
- * videoElement.width = 100;
- * videoElement.height = 100;
- * const cam = await tf.data.webcam(videoElement);
- * const img = await cam.capture();
- * img.print();
- * cam.stop();
- * ```
- *
- * @param webcamVideoElement A `HTMLVideoElement` used to play video from
- * webcam. If this element is not provided, a hidden `HTMLVideoElement` will
- * be created. In that case, `resizeWidth` and `resizeHeight` must be
- * provided to set the generated tensor shape.
- * @param webcamConfig A `WebcamConfig` object that contains configurations of
- * reading and manipulating data from webcam video stream.
- *
- * @doc {
- * heading: 'Data',
- * subheading: 'Creation',
- * namespace: 'data',
- * ignoreCI: true
- * }
- */
- async function webcam(webcamVideoElement, webcamConfig) {
- return WebcamIterator.create(webcamVideoElement, webcamConfig);
- }
- /**
- * Create an iterator that generate frequency-domain spectrogram `Tensor`s from
- * microphone audio stream with browser's native FFT. This API only works in
- * browser environment when the device has microphone.
- *
- * Note: this code snippet only works when the device has a microphone. It will
- * request permission to open the microphone when running.
- * ```js
- * const mic = await tf.data.microphone({
- * fftSize: 1024,
- * columnTruncateLength: 232,
- * numFramesPerSpectrogram: 43,
- * sampleRateHz:44100,
- * includeSpectrogram: true,
- * includeWaveform: true
- * });
- * const audioData = await mic.capture();
- * const spectrogramTensor = audioData.spectrogram;
- * spectrogramTensor.print();
- * const waveformTensor = audioData.waveform;
- * waveformTensor.print();
- * mic.stop();
- * ```
- *
- * @param microphoneConfig A `MicrophoneConfig` object that contains
- * configurations of reading audio data from microphone.
- *
- * @doc {
- * heading: 'Data',
- * subheading: 'Creation',
- * namespace: 'data',
- * ignoreCI: true
- * }
- */
- async function microphone(microphoneConfig) {
- return MicrophoneIterator.create(microphoneConfig);
- }
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$3 = '0.0.0';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- var index = /*#__PURE__*/Object.freeze({
- __proto__: null,
- array: array,
- Dataset: Dataset,
- zip: zip,
- CSVDataset: CSVDataset,
- TextLineDataset: TextLineDataset,
- csv: csv,
- func: func,
- generator: generator,
- microphone: microphone,
- webcam: webcam,
- FileDataSource: FileDataSource,
- URLDataSource: URLDataSource,
- version_data: version$3
- });
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function assertNotComplex(tensor, opName) {
- if (!Array.isArray(tensor)) {
- tensor = [tensor];
- }
- tensor.forEach(t => {
- if (t != null) {
- assert(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
- }
- });
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl;
- const split$4 = split$1;
- const tile$3 = tile$1;
- const topkImpl$1 = topkImpl;
- const whereImpl$1 = whereImpl;
- /**
- * @deprecated remove once all fused kernels are modularized.
- *
- * Use fused_utils.applyActivation instead.
- */
- function mapActivation(x, activation, preluActivationWeights) {
- if (activation === 'linear') {
- return clone(x);
- }
- else if (activation === 'relu') {
- return relu(x);
- }
- else if (activation === 'elu') {
- return elu(x);
- }
- else if (activation === 'relu6') {
- return relu6(x);
- }
- else if (activation === 'prelu') {
- return prelu(x, preluActivationWeights);
- }
- throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
- }
- class MathBackendCPU extends KernelBackend {
- constructor() {
- super();
- this.blockSize = 48;
- this.firstUse = true;
- this.data = new DataStorage(this, engine());
- }
- write(values, shape, dtype) {
- if (this.firstUse) {
- this.firstUse = false;
- if (env().get('IS_NODE')) {
- warn('\n============================\n' +
- 'Hi there 👋. Looks like you are running TensorFlow.js in ' +
- 'Node.js. To speed things up dramatically, install our node ' +
- 'backend, which binds to TensorFlow C++, by running ' +
- 'npm i @tensorflow/tfjs-node, ' +
- 'or npm i @tensorflow/tfjs-node-gpu if you have CUDA. ' +
- 'Then call require(\'@tensorflow/tfjs-node\'); (-gpu ' +
- 'suffix for CUDA) at the start of your program. ' +
- 'Visit https://github.com/tensorflow/tfjs-node for more details.' +
- '\n============================');
- }
- }
- const dataId = {};
- this.data.set(dataId, { values, dtype, refCount: 1 });
- return dataId;
- }
- /**
- * Create a data bucket in cpu backend.
- * @param shape Shape of the `TensorInfo`.
- * @param dtype DType of the `TensorInfo`.
- * @param values The value of the `TensorInfo` stored as a flattened array.
- */
- makeTensorInfo(shape, dtype, values) {
- const outId = this.write(values, shape, dtype);
- return { dataId: outId, shape, dtype };
- }
- /** Increase refCount of a `TensorData`. */
- incRef(dataId) {
- const tensorData = this.data.get(dataId);
- tensorData.refCount++;
- }
- /** Decrease refCount of a `TensorData`. */
- decRef(dataId) {
- if (this.data.has(dataId)) {
- const tensorData = this.data.get(dataId);
- tensorData.refCount--;
- }
- }
- move(dataId, values, shape, dtype) {
- this.data.set(dataId, { values, dtype, refCount: 1 });
- }
- numDataIds() {
- return this.data.numDataIds();
- }
- async read(dataId) {
- return this.readSync(dataId);
- }
- readSync(dataId) {
- const { dtype, complexTensorInfos } = this.data.get(dataId);
- if (dtype === 'complex64') {
- const realValues = this.readSync(complexTensorInfos.real.dataId);
- const imagValues = this.readSync(complexTensorInfos.imag.dataId);
- return mergeRealAndImagArrays(realValues, imagValues);
- }
- return this.data.get(dataId).values;
- }
- bufferSync(t) {
- const data = this.readSync(t.dataId);
- let decodedData = data;
- if (t.dtype === 'string') {
- try {
- // Decode the bytes into string.
- decodedData = data.map(d => decodeString(d));
- }
- catch {
- throw new Error('Failed to decode encoded string bytes into utf-8');
- }
- }
- return buffer(t.shape, t.dtype, decodedData);
- }
- makeOutput(values, shape, dtype) {
- const dataId = this.write(values, shape, dtype);
- return engine().makeTensorFromDataId(dataId, shape, dtype, this);
- }
- disposeData(dataId) {
- if (this.data.has(dataId)) {
- const { complexTensorInfos } = this.data.get(dataId);
- if (complexTensorInfos != null) {
- this.disposeData(complexTensorInfos.real.dataId);
- this.disposeData(complexTensorInfos.imag.dataId);
- }
- this.data.delete(dataId);
- }
- }
- disposeIntermediateTensorInfo(tensorInfo) {
- const dataId = tensorInfo.dataId;
- if (this.data.has(dataId)) {
- const tensorData = this.data.get(dataId);
- tensorData.refCount--;
- if (tensorData.refCount < 1) {
- this.disposeData(dataId);
- }
- }
- }
- async time(f) {
- const start = now();
- f();
- const kernelMs = now() - start;
- return { kernelMs };
- }
- memory() {
- return {
- // Unreliable due to automatic gc. The numbers above are cumulative.
- unreliable: true,
- reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
- 'collection, the true allocated memory may be less.']
- };
- }
- stridedSlice(x, begin, end, strides) {
- assertNotComplex(x, 'stridedSlice');
- const outShape = computeOutShape(begin, end, strides);
- if (outShape.some(axis => axis === 0)) {
- return tensor([], outShape);
- }
- const buffer$1 = buffer(outShape, x.dtype);
- const xBuf = this.bufferSync(x);
- for (let i = 0; i < buffer$1.size; i++) {
- const loc = buffer$1.indexToLoc(i);
- const newLoc = new Array(loc.length);
- for (let j = 0; j < newLoc.length; j++) {
- newLoc[j] = loc[j] * strides[j] + begin[j];
- }
- buffer$1.set(xBuf.get(...newLoc), ...loc);
- }
- return buffer$1.toTensor();
- }
- diag(x) {
- const xVals = this.readSync(x.dataId);
- const buffer$1 = buffer([x.size, x.size], x.dtype);
- const vals = buffer$1.values;
- for (let i = 0; i < xVals.length; i++) {
- vals[i * x.size + i] = xVals[i];
- }
- return buffer$1.toTensor();
- }
- unstack(x, axis) {
- const num = x.shape[axis];
- const outShape = new Array(x.rank - 1);
- let outIndex = 0;
- for (let i = 0; i < x.rank; i++) {
- if (i !== axis) {
- outShape[outIndex++] = x.shape[i];
- }
- }
- const begin = new Array(x.rank).fill(0);
- const size = x.shape.slice();
- size[axis] = 1;
- const res = new Array(num);
- for (let i = 0; i < res.length; i++) {
- begin[axis] = i;
- res[i] = slice(x, begin, size).reshape(outShape);
- }
- return res;
- }
- reverse(x, axis) {
- assertNotComplex(x, 'reverse');
- const buffer$1 = buffer(x.shape, x.dtype);
- const xBuf = this.bufferSync(x);
- for (let i = 0; i < buffer$1.size; i++) {
- const outLoc = buffer$1.indexToLoc(i);
- const inLoc = outLoc.slice();
- axis.forEach(ax => inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]);
- buffer$1.set(xBuf.get(...inLoc), ...outLoc);
- }
- return buffer$1.toTensor();
- }
- neg(x) {
- assertNotComplex(x, 'neg');
- // TODO(lina128): Use mul directly once neg is modularized.
- return mul(scalar(-1), x);
- }
- addN(tensors) {
- assertNotComplex(tensors, 'addN');
- const vals = tensors.map(t => this.readSync(t.dataId));
- const result = buffer(tensors[0].shape, tensors[0].dtype);
- const resultVals = result.values;
- for (let i = 0; i < tensors.length; i++) {
- const currVals = vals[i];
- for (let j = 0; j < resultVals.length; j++) {
- resultVals[j] += currVals[j];
- }
- }
- return result.toTensor();
- }
- softmax(logits, dim) {
- const axes = parseAxisParam([dim], logits.shape);
- // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
- // modularization.
- const maxLogit = max(logits, axes);
- const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
- // TODO(lina128): Use sub directly once softmax is modularized.
- const a = sub(logits, maxLogit.reshape(expandedShape));
- const b = exp(a);
- const sumExp = this.sum(b, axes).reshape(expandedShape);
- // TODO(annxingyuan): Call divImpl rather than op as part of softmax
- // kernel modularization.
- return div(b, sumExp);
- }
- pow(a, b) {
- assertNotComplex([a, b], 'pow');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aValue, bValue) => Math.pow(aValue, bValue));
- }
- floorDiv(a, b) {
- assertNotComplex([a, b], 'floorDiv');
- const op = (a, b) => Math.floor(a / b);
- const outputDtype = 'int32';
- return this.broadcastedBinaryOp(a, b, outputDtype, op);
- }
- sum(x, axes) {
- assertNotComplex(x, 'sum');
- assertAxesAreInnerMostDims('sum', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const resultDtype = upcastType(x.dtype, 'int32');
- const result = zeros(outShape, resultDtype);
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let sum = 0;
- for (let j = 0; j < reduceSize; ++j) {
- sum += aVals[offset + j];
- }
- vals[i] = sum;
- }
- return result;
- }
- prod(x, axes) {
- assertNotComplex(x, 'sum');
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const resultDtype = upcastType(x.dtype, 'int32');
- const result = zeros(outShape, resultDtype);
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let prod = 1;
- for (let j = 0; j < reduceSize; ++j) {
- prod *= aVals[offset + j];
- }
- vals[i] = prod;
- }
- return result;
- }
- unsortedSegmentSum(x, segmentIds, numSegments) {
- assertNotComplex(x, 'unsortedSegmentSum');
- const res = [];
- // Reshape the segment id's so that they can be broadcast with
- // x. The new shape should be [segmentIds.shape, 1, ..., 1]
- const numIters = x.rank - segmentIds.rank;
- for (let i = 0; i < numIters; ++i) {
- segmentIds = segmentIds.expandDims(i + 1);
- }
- for (let i = 0; i < numSegments; ++i) {
- const segmentId = scalar(i, 'int32');
- const mask = equal(segmentId, segmentIds).asType('float32');
- const sum = mask.mul(x).sum(0);
- res.push(sum);
- }
- return stack(res);
- }
- argMin(x, axis) {
- assertNotComplex(x, 'argMin');
- const axes = [axis];
- assertAxesAreInnerMostDims('argMin', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const result = zeros(outShape, 'int32');
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let min = aVals[offset];
- let minIndex = 0;
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- if (value < min) {
- min = value;
- minIndex = j;
- }
- }
- vals[i] = minIndex;
- }
- return result;
- }
- argMax(x, axis) {
- assertNotComplex(x, 'argMax');
- const axes = [axis];
- assertAxesAreInnerMostDims('argMax', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const result = zeros(outShape, 'int32');
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let max = aVals[offset];
- let maxIndex = 0;
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- if (value > max) {
- max = value;
- maxIndex = j;
- }
- }
- vals[i] = maxIndex;
- }
- return result;
- }
- cumsum(x, axis, exclusive, reverse) {
- assertNotComplex(x, 'cumsum');
- if (axis !== x.rank - 1) {
- throw new Error(`backend.cumsum in CPU expects an inner-most axis=${x.rank - 1} ` +
- `but got axis=${axis}`);
- }
- const resultDtype = upcastType(x.dtype, 'int32');
- const result = zeros(x.shape, resultDtype);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- const finalDim = x.shape[x.rank - 1];
- const indexAdjuster = reverse ?
- (i, j) => i + finalDim - j - 1 :
- (i, j) => i + j;
- for (let i = 0; i < aVals.length; i += finalDim) {
- for (let j = 0; j < finalDim; j++) {
- const idx = indexAdjuster(i, j);
- if (j === 0) {
- vals[idx] = exclusive ? 0 : aVals[idx];
- }
- else {
- const prevIdx = indexAdjuster(i, j - 1);
- vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :
- aVals[idx] + vals[prevIdx];
- }
- }
- }
- return result;
- }
- equal(a, b) {
- assertNotComplex([a, b], 'equal');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal === bVal) ? 1 : 0;
- });
- }
- notEqual(a, b) {
- assertNotComplex([a, b], 'notEqual');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal !== bVal) ? 1 : 0;
- });
- }
- less(a, b) {
- assertNotComplex([a, b], 'less');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal < bVal) ? 1 : 0;
- });
- }
- lessEqual(a, b) {
- assertNotComplex([a, b], 'lessEqual');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal <= bVal) ? 1 : 0;
- });
- }
- greater(a, b) {
- assertNotComplex([a, b], 'greater');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal > bVal) ? 1 : 0;
- });
- }
- greaterEqual(a, b) {
- assertNotComplex([a, b], 'greaterEqual');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return (aVal >= bVal) ? 1 : 0;
- });
- }
- logicalAnd(a, b) {
- assertNotComplex([a, b], 'logicalAnd');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return aVal && bVal;
- });
- }
- logicalOr(a, b) {
- assertNotComplex([a, b], 'logicalOr');
- return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {
- return aVal || bVal;
- });
- }
- select(condition, a, b) {
- assertNotComplex([condition, a, b], 'select');
- const values = this.readSync(condition.dataId);
- const aValues = this.readSync(a.dataId);
- const bValues = this.readSync(b.dataId);
- const result = zeros(a.shape, upcastType(a.dtype, b.dtype));
- const newValues = this.readSync(result.dataId);
- let index = 0;
- const offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ?
- 1 :
- sizeFromShape(a.shape.slice(1));
- for (let i = 0; i < values.length; i++) {
- for (let j = 0; j < offset; j++) {
- if (values[i] === 1) {
- newValues[index++] = aValues[i];
- }
- else {
- newValues[index++] = bValues[i];
- }
- }
- }
- return result;
- }
- where(condition) {
- assertNotComplex([condition], 'where');
- const condVals = this.readSync(condition.dataId);
- return whereImpl$1(condition.shape, condVals);
- }
- topk(x, k, sorted) {
- assertNotComplex(x, 'topk');
- const xVals = this.readSync(x.dataId);
- return topkImpl$1(xVals, x.shape, x.dtype, k, sorted);
- }
- min(x, axes) {
- assertNotComplex(x, 'min');
- assertAxesAreInnerMostDims('min', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const result = zeros(outShape, x.dtype);
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let min = aVals[offset];
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- if (value < min) {
- min = value;
- }
- }
- vals[i] = min;
- }
- return result;
- }
- minimum(a, b) {
- assertNotComplex([a, b], 'minimum');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => Math.min(aVal, bVal));
- }
- mod(a, b) {
- assertNotComplex([a, b], 'mod');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => {
- const rem = aVal % bVal;
- if ((aVal < 0 && bVal < 0) || (aVal >= 0 && bVal >= 0)) {
- return rem;
- }
- else {
- return (rem + bVal) % bVal;
- }
- });
- }
- maximum(a, b) {
- assertNotComplex([a, b], 'maximum');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => Math.max(aVal, bVal));
- }
- all(x, axes) {
- assertNotComplex(x, 'all');
- assertAxesAreInnerMostDims('all', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const result = zeros(outShape, x.dtype);
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let all = aVals[offset];
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- all = all && value;
- }
- vals[i] = all;
- }
- return result;
- }
- any(x, axes) {
- assertNotComplex(x, 'any');
- assertAxesAreInnerMostDims('any', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const result = zeros(outShape, x.dtype);
- const reduceSize = sizeFromShape(reduceShape);
- const vals = this.readSync(result.dataId);
- const aVals = this.readSync(x.dataId);
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let anyVal = aVals[offset];
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- anyVal = anyVal || value;
- }
- vals[i] = anyVal;
- }
- return result;
- }
- squaredDifference(a, b) {
- assertNotComplex([a, b], 'squaredDifference');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => {
- const diff = aVal - bVal;
- return diff * diff;
- });
- }
- eluDer(dy, y) {
- assertNotComplex([dy, y], 'eluDer');
- const resultValues = new Float32Array(y.size);
- const values = this.readSync(y.dataId);
- const dyValues = this.readSync(dy.dataId);
- for (let i = 0; i < values.length; ++i) {
- const v = values[i];
- if (v >= 1) {
- resultValues[i] = dyValues[i];
- }
- else {
- resultValues[i] = dyValues[i] * (v + 1);
- }
- }
- return this.makeOutput(resultValues, y.shape, 'float32');
- }
- atan2(a, b) {
- assertNotComplex([a, b], 'atan2');
- return this.broadcastedBinaryOp(a, b, a.dtype, (aValue, bValue) => Math.atan2(aValue, bValue));
- }
- conv3d(x, filter, convInfo) {
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const padFront = convInfo.padInfo.front;
- const padLeft = convInfo.padInfo.left;
- const padTop = convInfo.padInfo.top;
- const y = buffer(convInfo.outShape, x.dtype);
- const xVals = this.readSync(x.dataId);
- const wVals = this.readSync(filter.dataId);
- const yVals = y.values;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- const xOffset1 = b * x.strides[0];
- const yOffset1 = b * y.strides[0];
- for (let yF = 0; yF < convInfo.outDepth; ++yF) {
- const yOffset2 = yOffset1 + yF * y.strides[1];
- const xFCorner = yF * convInfo.strideDepth - padFront;
- for (let wF = 0; wF < filterDepth; wF++) {
- const xF = xFCorner + wF * dilationDepth;
- if (xF < 0 || xF >= convInfo.inDepth) {
- continue;
- }
- const wOffset1 = wF * filter.strides[0];
- const xOffset2 = xOffset1 + xF * x.strides[1];
- for (let yR = 0; yR < convInfo.outHeight; ++yR) {
- const yOffset3 = yOffset2 + yR * y.strides[2];
- const xRCorner = yR * convInfo.strideHeight - padTop;
- for (let wR = 0; wR < filterHeight; wR++) {
- const xR = xRCorner + wR * dilationHeight;
- if (xR < 0 || xR >= convInfo.inHeight) {
- continue;
- }
- const wOffset2 = wOffset1 + wR * filter.strides[1];
- const xOffset3 = xOffset2 + xR * x.strides[2];
- for (let yC = 0; yC < convInfo.outWidth; ++yC) {
- const yOffset4 = yOffset3 + yC * convInfo.outChannels;
- const xCCorner = yC * convInfo.strideWidth - padLeft;
- for (let wC = 0; wC < filterWidth; wC++) {
- const xC = xCCorner + wC * dilationWidth;
- if (xC < 0 || xC >= convInfo.inWidth) {
- continue;
- }
- const wOffset3 = wOffset2 + wC * filter.strides[2];
- const xOffset4 = xOffset3 + xC * convInfo.inChannels;
- let wOffset4 = wOffset3;
- for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
- const xVal = xVals[xOffset4 + d1];
- for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
- yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
- }
- wOffset4 += convInfo.outChannels;
- }
- }
- }
- }
- }
- }
- }
- }
- return y.toTensor();
- }
- conv2dDerInput(dy, filter, convInfo) {
- assertNotComplex([dy, filter], 'conv2dDerInput');
- const dx = buffer(convInfo.inShape, 'float32');
- const dxValues = dx.values;
- const dyValues = this.readSync(dy.dataId);
- const fltValues = this.readSync(filter.dataId);
- const [fltS0, fltS1, fltS2] = filter.strides;
- const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth, dataFormat } = convInfo;
- const topPad = filterHeight - 1 - convInfo.padInfo.top;
- const leftPad = filterWidth - 1 - convInfo.padInfo.left;
- const isChannelsLast = dataFormat === 'channelsLast';
- const xBatchStride = dx.strides[0];
- const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
- const xColStride = isChannelsLast ? dx.strides[2] : 1;
- const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
- const yBatchStride = dy.strides[0];
- const yRowStride = isChannelsLast ? dy.strides[1] : dy.strides[2];
- const yColStride = isChannelsLast ? dy.strides[2] : 1;
- const yChannelStride = isChannelsLast ? 1 : dy.strides[1];
- for (let b = 0; b < batchSize; ++b) {
- for (let d1 = 0; d1 < inChannels; ++d1) {
- for (let xR = 0; xR < inHeight; ++xR) {
- const xRCorner = xR - topPad;
- const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
- const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
- for (let xC = 0; xC < inWidth; ++xC) {
- const xCCorner = xC - leftPad;
- const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
- const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
- let dotProd = 0;
- for (let yR = xRMin; yR < yRMax; ++yR) {
- const wR = yR * strideHeight - xRCorner;
- for (let yC = xCMin; yC < yCMax; ++yC) {
- const wC = yC * strideWidth - xCCorner;
- const dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
- const fltOffset = fltS0 * (filterHeight - 1 - wR) +
- fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
- for (let d2 = 0; d2 < outChannels; ++d2) {
- const pixel = dyValues[dyOffset + yChannelStride * d2];
- const weight = fltValues[fltOffset + d2];
- dotProd += pixel * weight;
- }
- }
- }
- const dxOffset = xBatchStride * b + xRowStride * xR +
- xColStride * xC + xChannelStride * d1;
- dxValues[dxOffset] = dotProd;
- }
- }
- }
- }
- return dx.toTensor();
- }
- conv3dDerInput(dy, filter, convInfo) {
- const dx = buffer(convInfo.inShape, 'float32');
- const dxValues = dx.values;
- const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
- const dyValues = this.readSync(dy.dataId);
- const [dyS0, dyS1, dyS2, dyS3] = dy.strides;
- const fltValues = this.readSync(filter.dataId);
- const [fltS0, fltS1, fltS2, fltS3] = filter.strides;
- const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
- const frontPad = filterDepth - 1 - convInfo.padInfo.front;
- const topPad = filterHeight - 1 - convInfo.padInfo.top;
- const leftPad = filterWidth - 1 - convInfo.padInfo.left;
- for (let b = 0; b < batchSize; ++b) {
- for (let d1 = 0; d1 < inChannels; ++d1) {
- // Frames of depth
- for (let xF = 0; xF < inDepth; ++xF) {
- const xFCorner = xF - frontPad;
- const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
- const yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
- // Rows as per standard 2d matrix notation
- for (let xR = 0; xR < inHeight; ++xR) {
- const xRCorner = xR - topPad;
- const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
- const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
- // Columns as per standard 2d matrix notation
- for (let xC = 0; xC < inWidth; ++xC) {
- const xCCorner = xC - leftPad;
- const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
- const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
- let dotProd = 0;
- for (let yF = xFMin; yF < yFMax; ++yF) {
- const wF = yF * strideDepth - xFCorner;
- for (let yR = xRMin; yR < yRMax; ++yR) {
- const wR = yR * strideHeight - xRCorner;
- for (let yC = xCMin; yC < yCMax; ++yC) {
- const wC = yC * strideWidth - xCCorner;
- const dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
- const fltOffset = fltS0 * (filterDepth - 1 - wF) +
- fltS1 * (filterHeight - 1 - wR) +
- fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
- for (let d2 = 0; d2 < outChannels; ++d2) {
- const pixel = dyValues[dyOffset + d2];
- const weight = fltValues[fltOffset + d2];
- dotProd += pixel * weight;
- }
- }
- }
- }
- dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
- dotProd;
- }
- }
- }
- }
- }
- return dx.toTensor();
- }
- conv2dDerFilter(x, dy, convInfo) {
- assertNotComplex([x, dy], 'conv2dDerFilter');
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- const dW = buffer(convInfo.filterShape, 'float32');
- const leftPad = convInfo.padInfo.left;
- const topPad = convInfo.padInfo.top;
- const xBuf = this.bufferSync(x);
- const dyBuf = this.bufferSync(dy);
- for (let wR = 0; wR < filterHeight; ++wR) {
- const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
- const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
- for (let wC = 0; wC < filterWidth; ++wC) {
- const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
- const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
- for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
- for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
- // Need to convolve.
- let dotProd = 0;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- for (let yR = yRMin; yR < yRMax; ++yR) {
- const xR = wR + yR * strideHeight - topPad;
- for (let yC = yCMin; yC < yCMax; ++yC) {
- const xC = wC + yC * strideWidth - leftPad;
- if (isChannelsLast) {
- dotProd +=
- xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
- }
- else {
- dotProd +=
- xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);
- }
- }
- }
- }
- dW.set(dotProd, wR, wC, d1, d2);
- }
- }
- }
- }
- return dW.toTensor();
- }
- conv3dDerFilter(x, dy, convInfo) {
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dw = buffer(convInfo.filterShape, 'float32');
- const dwValues = dw.values;
- const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
- const dyValues = this.readSync(dy.dataId);
- const [dyS0, dyS1, dyS2, dyS3] = dy.strides;
- const xValues = this.readSync(x.dataId);
- const [xS0, xS1, xS2, xS3] = x.strides;
- const frontPad = convInfo.padInfo.front;
- const leftPad = convInfo.padInfo.left;
- const topPad = convInfo.padInfo.top;
- for (let wF = 0; wF < filterDepth; ++wF) {
- const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
- const yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
- const wOffset1 = wF * dwS0;
- for (let wR = 0; wR < filterHeight; ++wR) {
- const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
- const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
- const wOffset2 = wR * dwS1 + wOffset1;
- for (let wC = 0; wC < filterWidth; ++wC) {
- const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
- const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
- const wOffset3 = wC * dwS2 + wOffset2;
- for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
- const wOffset4 = d1 * dwS3 + wOffset3;
- for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
- let dotProd = 0;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- const xOffset1 = b * xS0;
- const yOffset1 = b * dyS0;
- for (let yF = yFMin; yF < yFMax; ++yF) {
- const xF = wF + yF * strideDepth - frontPad;
- const xOffset2 = xF * xS1 + xOffset1;
- const yOffset2 = yF * dyS1 + yOffset1;
- for (let yR = yRMin; yR < yRMax; ++yR) {
- const xR = wR + yR * strideHeight - topPad;
- const xOffset3 = xR * xS2 + xOffset2;
- const yOffset3 = yR * dyS2 + yOffset2;
- for (let yC = yCMin; yC < yCMax; ++yC) {
- const xC = wC + yC * strideWidth - leftPad;
- const xOffset4 = xC * xS3 + xOffset3;
- const yOffset4 = yC * dyS3 + yOffset3;
- dotProd +=
- xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
- }
- }
- }
- }
- dwValues[wOffset4 + d2] = dotProd;
- }
- }
- }
- }
- }
- return dw.toTensor();
- }
- fusedDepthwiseConv2D({ input, filter, convInfo, bias, activation, preluActivationWeights }) {
- let result = this.depthwiseConv2D(input, filter, convInfo);
- if (bias) {
- // TODO(lina128): Use add directly once fusedDepthwiseConv2D is
- // modularized.
- result = add$1(result, bias);
- }
- if (activation) {
- result =
- mapActivation(result, activation, preluActivationWeights);
- }
- return result;
- }
- depthwiseConv2D(x, filter, convInfo) {
- assertNotComplex([x, filter], 'depthwiseConv2D');
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const padLeft = convInfo.padInfo.left;
- const padTop = convInfo.padInfo.top;
- const chMul = convInfo.outChannels / convInfo.inChannels;
- const y = buffer(convInfo.outShape, x.dtype);
- const xVals = this.readSync(x.dataId);
- const wVals = this.readSync(filter.dataId);
- const yVals = y.values;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- const xOffset1 = b * x.strides[0];
- const yOffset1 = b * y.strides[0];
- for (let yR = 0; yR < convInfo.outHeight; ++yR) {
- const yOffset2 = yOffset1 + yR * y.strides[1];
- const xRCorner = yR * convInfo.strideHeight - padLeft;
- for (let wR = 0; wR < filterHeight; ++wR) {
- const xR = xRCorner + wR * dilationHeight;
- if (xR < 0 || xR >= convInfo.inHeight) {
- continue;
- }
- const wOffset1 = wR * filter.strides[0];
- const xOffset2 = xOffset1 + xR * x.strides[1];
- for (let yC = 0; yC < convInfo.outWidth; ++yC) {
- const yOffset3 = yOffset2 + yC * y.strides[2];
- const xCCorner = yC * convInfo.strideWidth - padTop;
- for (let wC = 0; wC < filterWidth; ++wC) {
- const xC = xCCorner + wC * dilationWidth;
- if (xC < 0 || xC >= convInfo.inWidth) {
- continue;
- }
- const wOffset2 = wOffset1 + wC * filter.strides[1];
- const xOffset3 = xOffset2 + xC * convInfo.inChannels;
- let yOffset4 = yOffset3;
- let wOffset3 = wOffset2;
- for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
- const xVal = xVals[xOffset3 + d1];
- for (let q = 0; q < chMul; ++q) {
- yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
- }
- yOffset4 += chMul;
- wOffset3 += chMul;
- }
- }
- }
- }
- }
- }
- return y.toTensor();
- }
- depthwiseConv2DDerInput(dy, filter, convInfo) {
- assertNotComplex([dy, filter], 'depthwiseConv2DDerInput');
- const dx = buffer(convInfo.inShape, 'float32');
- const dxValues = dx.values;
- const [dxS0, dxS1, dxS2] = dx.strides;
- const dyValues = this.readSync(dy.dataId);
- const [dyS0, dyS1, dyS2] = dy.strides;
- const fltValues = this.readSync(filter.dataId);
- const [fltS0, fltS1, fltS2] = filter.strides;
- const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
- const topPad = filterHeight - 1 - convInfo.padInfo.top;
- const leftPad = filterWidth - 1 - convInfo.padInfo.left;
- const chMul = outChannels / inChannels;
- for (let b = 0; b < batchSize; ++b) {
- for (let d1 = 0; d1 < inChannels; ++d1) {
- for (let xR = 0; xR < inHeight; ++xR) {
- const xRCorner = xR - topPad;
- const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
- const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
- for (let xC = 0; xC < inWidth; ++xC) {
- const xCCorner = xC - leftPad;
- const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
- const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
- let dotProd = 0;
- for (let yR = xRMin; yR < yRMax; ++yR) {
- const wR = yR * strideHeight - xRCorner;
- for (let yC = xCMin; yC < yCMax; ++yC) {
- const wC = yC * strideWidth - xCCorner;
- const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
- const fltOffset = fltS0 * (filterHeight - 1 - wR) +
- fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
- for (let dm = 0; dm < chMul; ++dm) {
- const d2 = d1 * chMul + dm;
- const pixel = dyValues[dyOffset + d2];
- const weight = fltValues[fltOffset + dm];
- dotProd += pixel * weight;
- }
- }
- }
- dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
- }
- }
- }
- }
- return dx.toTensor();
- }
- depthwiseConv2DDerFilter(x, dy, convInfo) {
- assertNotComplex([x, dy], 'depthwiseConv2DDerFilter');
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dW = buffer(convInfo.filterShape, 'float32');
- const leftPad = convInfo.padInfo.left;
- const topPad = convInfo.padInfo.top;
- const chMul = convInfo.outChannels / convInfo.inChannels;
- const xBuf = this.bufferSync(x);
- const dyBuf = this.bufferSync(dy);
- for (let wR = 0; wR < filterHeight; ++wR) {
- const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
- const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
- for (let wC = 0; wC < filterWidth; ++wC) {
- const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
- const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
- for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
- const d1 = Math.trunc(d2 / chMul);
- const dm = d2 % chMul;
- let dotProd = 0;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- for (let yR = yRMin; yR < yRMax; ++yR) {
- const xR = wR + yR * strideHeight - topPad;
- for (let yC = yCMin; yC < yCMax; ++yC) {
- const xC = wC + yC * strideWidth - leftPad;
- dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
- }
- }
- }
- dW.set(dotProd, wR, wC, d1, dm);
- }
- }
- }
- return dW.toTensor();
- }
- tile(x, reps) {
- assertNotComplex(x, 'tile');
- return tile$3(this.bufferSync(x), reps);
- }
- gather(x, indices, axis) {
- assertNotComplex([x, indices], 'gather');
- const newShape = x.shape.slice();
- const indicesValues = this.readSync(indices.dataId);
- newShape[axis] = indicesValues.length;
- const result = buffer(newShape, x.dtype);
- const xBuf = this.bufferSync(x);
- for (let i = 0; i < result.size; ++i) {
- const newLoc = result.indexToLoc(i);
- const originalLoc = newLoc.slice();
- originalLoc[axis] = indicesValues[newLoc[axis]];
- const originalIndex = xBuf.locToIndex(originalLoc);
- result.values[i] = xBuf.values[originalIndex];
- }
- return result.toTensor();
- }
- batchToSpaceND(x, blockShape, crops) {
- assertNotComplex([x], 'batchToSpaceND');
- const prod = blockShape.reduce((a, b) => a * b);
- const reshaped = getReshaped(x.shape, blockShape, prod);
- const permuted = getPermuted(reshaped.length, blockShape.length);
- const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
- const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
- const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
- return transpose(x.reshape(reshaped), permuted)
- .reshape(reshapedPermuted)
- .slice(sliceBeginCoords, sliceSize);
- }
- pool3d(x, convInfo, poolType) {
- assertNotComplex(x, 'pool3d');
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = convInfo.padInfo.front;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
- Number.POSITIVE_INFINITY);
- const xValues = this.readSync(x.dataId);
- const output = buffer(convInfo.outShape, x.dtype);
- const outputVals = output.values;
- const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
- convInfo.outShape[3] * convInfo.outShape[4];
- const outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
- const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
- const outputColStrides = convInfo.outShape[4];
- for (let batch = 0; batch < convInfo.batchSize; ++batch) {
- const outputBatchOffset = batch * outputBatchStrides;
- const inputBatchOffset = batch * x.strides[0];
- for (let channel = 0; channel < convInfo.inChannels; ++channel) {
- for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
- const xDepthCorner = yDepth * strideDepth - padFront;
- let xDepthMin = xDepthCorner;
- while (xDepthMin < 0) {
- xDepthMin += dilationDepth;
- }
- const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
- const outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
- for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
- const xRowCorner = yRow * strideHeight - padTop;
- let xRowMin = xRowCorner;
- while (xRowMin < 0) {
- xRowMin += dilationHeight;
- }
- const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
- const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
- for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
- const xColCorner = yCol * strideWidth - padLeft;
- let xColMin = xColCorner;
- while (xColMin < 0) {
- xColMin += dilationWidth;
- }
- const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
- // Shader code begins
- const outputColOffset = outputRowOffset + yCol * outputColStrides;
- let minMaxValue = initialValue;
- let avgValue = 0;
- let count = 0;
- for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
- const xDepthOffset = inputBatchOffset + xDepth * x.strides[1];
- for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
- const xRowOffset = xDepthOffset + xRow * x.strides[2];
- for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
- const xColOffset = xRowOffset + xCol * x.strides[3];
- const pixel = xValues[xColOffset + channel];
- if ((poolType === 'max' && pixel > minMaxValue)) {
- minMaxValue = pixel;
- }
- else if (poolType === 'avg') {
- avgValue += pixel;
- count++;
- }
- if (isNaN(minMaxValue)) {
- break;
- }
- }
- if (isNaN(minMaxValue)) {
- break;
- }
- }
- if (isNaN(minMaxValue)) {
- break;
- }
- }
- const outputOffset = outputColOffset + channel;
- outputVals[outputOffset] =
- poolType === 'avg' ? avgValue / count : minMaxValue;
- }
- }
- }
- }
- }
- return output.toTensor();
- }
- avgPool3d(x, convInfo) {
- assertNotComplex(x, 'avgPool3d');
- return this.pool3d(x, convInfo, 'avg').toFloat();
- }
- avgPool3dBackprop(dy, x, convInfo) {
- assertNotComplex([dy, x], 'avgPool3dBackprop');
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const dx = buffer(x.shape, 'float32');
- const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
- const dyBuf = this.bufferSync(dy);
- for (let batch = 0; batch < convInfo.batchSize; ++batch) {
- for (let channel = 0; channel < convInfo.inChannels; ++channel) {
- for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
- for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
- for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
- // Shader code begins.
- const dyDepthCorner = dxDepth - padFront;
- const dyRowCorner = dxRow - padTop;
- const dyColCorner = dxCol - padLeft;
- let dotProd = 0;
- for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
- const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
- if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
- Math.floor(dyDepth) !== dyDepth) {
- continue;
- }
- for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
- const dyRow = (dyRowCorner + wRow) / strideHeight;
- if (dyRow < 0 || dyRow >= convInfo.outHeight ||
- Math.floor(dyRow) !== dyRow) {
- continue;
- }
- for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
- const dyCol = (dyColCorner + wCol) / strideWidth;
- if (dyCol < 0 || dyCol >= convInfo.outWidth ||
- Math.floor(dyCol) !== dyCol) {
- continue;
- }
- const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
- dotProd += pixel;
- }
- }
- }
- dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
- }
- }
- }
- }
- }
- return dx.toTensor();
- }
- maxPool3d(x, convInfo) {
- assertNotComplex(x, 'maxPool3d');
- return this.pool3d(x, convInfo, 'max').toFloat();
- }
- maxPool3dPositions(x, convInfo) {
- const maxPositions = buffer(convInfo.outShape, 'int32');
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = convInfo.padInfo.front;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const xBuf = this.bufferSync(x);
- for (let batch = 0; batch < convInfo.batchSize; ++batch) {
- for (let channel = 0; channel < convInfo.inChannels; ++channel) {
- for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
- const xDepthCorner = yDepth * strideDepth - padFront;
- let xDepthMin = xDepthCorner;
- while (xDepthMin < 0) {
- xDepthMin += dilationDepth;
- }
- const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
- for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
- const xRowCorner = yRow * strideHeight - padTop;
- let xRowMin = xRowCorner;
- while (xRowMin < 0) {
- xRowMin += dilationHeight;
- }
- const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
- for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
- const xColCorner = yCol * strideWidth - padLeft;
- let xColMin = xColCorner;
- while (xColMin < 0) {
- xColMin += dilationWidth;
- }
- const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
- // Shader code begins
- let maxValue = Number.NEGATIVE_INFINITY;
- let maxPosition = -1;
- for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
- const wDepth = xDepth - xDepthCorner;
- for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
- const wRow = xRow - xRowCorner;
- for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
- const wCol = xCol - xColCorner;
- const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
- if (pixel >= maxValue) {
- maxValue = pixel;
- maxPosition = wDepth * effectiveFilterHeight *
- effectiveFilterWidth +
- wRow * effectiveFilterHeight + wCol;
- }
- }
- }
- }
- maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
- }
- }
- }
- }
- }
- return maxPositions.toTensor();
- }
- maxPool3dBackprop(dy, x, y, convInfo) {
- assertNotComplex([x, y], 'maxPool3dBackprop');
- const maxPositions = this.maxPool3dPositions(x, convInfo);
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const dx = buffer(x.shape, 'float32');
- const maxPosBuf = this.bufferSync(maxPositions);
- const dyBuf = this.bufferSync(dy);
- for (let batch = 0; batch < convInfo.batchSize; ++batch) {
- for (let channel = 0; channel < convInfo.inChannels; ++channel) {
- for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
- for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
- for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
- // Shader code begins
- const dyDepthCorner = dxDepth - padFront;
- const dyRowCorner = dxRow - padTop;
- const dyColCorner = dxCol - padLeft;
- let dotProd = 0;
- for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
- const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
- if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
- Math.floor(dyDepth) !== dyDepth) {
- continue;
- }
- for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
- const dyRow = (dyRowCorner + wRow) / strideHeight;
- if (dyRow < 0 || dyRow >= convInfo.outHeight ||
- Math.floor(dyRow) !== dyRow) {
- continue;
- }
- for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
- const dyCol = (dyColCorner + wCol) / strideWidth;
- if (dyCol < 0 || dyCol >= convInfo.outWidth ||
- Math.floor(dyCol) !== dyCol) {
- continue;
- }
- const maxPos = effectiveFilterDepth *
- effectiveFilterHeight * effectiveFilterWidth -
- 1 -
- maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
- const curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth +
- wRow * effectiveFilterWidth + wCol;
- const mask = maxPos === curPos ? 1 : 0;
- if (mask === 0) {
- continue;
- }
- const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
- dotProd += pixel * mask;
- }
- }
- }
- dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
- }
- }
- }
- }
- }
- return dx.toTensor();
- }
- resizeBilinear(x, newHeight, newWidth, alignCorners) {
- assertNotComplex(x, 'resizeBilinear');
- const [batch, oldHeight, oldWidth, numChannels] = x.shape;
- const xValues = this.readSync(x.dataId);
- const result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
- const effectiveInputSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- const effectiveOutputSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- let outputIdx = 0;
- const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
- const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
- for (let b = 0; b < batch; b++) {
- for (let r = 0; r < newHeight; r++) {
- const sourceFracRow = effectiveRowSizeRatio * r;
- const sourceRowFloor = Math.floor(sourceFracRow);
- const rowFrac = sourceFracRow - sourceRowFloor;
- const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
- const topRowOffset = b * x.strides[0] + sourceRowFloor * x.strides[1];
- const botRowOffset = b * x.strides[0] + sourceRowCeil * x.strides[1];
- for (let c = 0; c < newWidth; c++) {
- const sourceFracCol = effectiveColSizeRatio * c;
- const sourceColFloor = Math.floor(sourceFracCol);
- const colFrac = sourceFracCol - sourceColFloor;
- const sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
- const topLeftOffest = topRowOffset + sourceColFloor * x.strides[2];
- const botLeftOffset = botRowOffset + sourceColFloor * x.strides[2];
- const topRightOffset = topRowOffset + sourceColCeil * x.strides[2];
- const botRightOffest = botRowOffset + sourceColCeil * x.strides[2];
- for (let d = 0; d < numChannels; d++) {
- // Begin shader.
- // Compute the fractional index of the source.
- const topLeft = xValues[topLeftOffest + d];
- const bottomLeft = xValues[botLeftOffset + d];
- const topRight = xValues[topRightOffset + d];
- const bottomRight = xValues[botRightOffest + d];
- const top = topLeft + (topRight - topLeft) * colFrac;
- const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
- const newValue = top + (bottom - top) * rowFrac;
- result[outputIdx++] = newValue;
- }
- }
- }
- }
- return tensor(result, [batch, newHeight, newWidth, numChannels]);
- }
- resizeBilinearBackprop(dy, x, alignCorners) {
- assertNotComplex([dy, x], 'resizeBilinearBackprop');
- const [batch, xHeight, xWidth, depth] = x.shape;
- const [, yHeight, yWidth] = dy.shape;
- const output = new Float32Array(batch * xHeight * xWidth * depth);
- // In the backwards pass, we want to find the pixels that were generated
- // for each pixel in the input image the forward pass and add the
- // corresponding coefficient from dy to the gradient (with some
- // interpolation).
- const effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- const effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- const heightScale = effectiveXSize[0] / effectiveYSize[0];
- const widthScale = effectiveXSize[1] / effectiveYSize[1];
- // Reference implementation
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
- const dyValues = this.readSync(dy.dataId);
- let offset = 0;
- for (let b = 0; b < batch; b++) {
- const bOffset = b * x.strides[0];
- for (let r = 0; r < yHeight; r++) {
- const dxR = r * heightScale;
- const topDxRIndex = Math.floor(dxR);
- const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
- const topDxROffset = bOffset + topDxRIndex * x.strides[1];
- const bottomDxROffset = bOffset + bottomDxRIndex * x.strides[1];
- const dxRLerp = dxR - topDxRIndex;
- const inverseDxRLerp = 1.0 - dxRLerp;
- for (let c = 0; c < yWidth; c++) {
- const dxC = c * widthScale;
- const leftDxCIndex = Math.floor(dxC);
- const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
- const dxCLerp = dxC - leftDxCIndex;
- const inverseDxCLerp = 1.0 - dxCLerp;
- const topLeftRCOffset = topDxROffset + leftDxCIndex * x.strides[2];
- const topRightRCOffset = topDxROffset + rightDxCIndex * x.strides[2];
- const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * x.strides[2];
- const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * x.strides[2];
- const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
- const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
- const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
- const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
- for (let d = 0; d < depth; d++) {
- const dyVal = dyValues[offset++];
- output[topLeftRCOffset + d] +=
- dyVal * inverseDxRLerpTimesInverseDxCLerp;
- output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
- output[bottomLeftRCOffset + d] +=
- dyVal * dxRLerpTimesInverseDxCLerp;
- output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
- }
- }
- }
- }
- return tensor4d(output, [batch, xWidth, xHeight, depth], x.dtype);
- }
- resizeNearestNeighbor(x, newHeight, newWidth, alignCorners) {
- assertNotComplex(x, 'resizeNearestNeighbor');
- const [batch, oldHeight, oldWidth, numChannels] = x.shape;
- const xValues = this.readSync(x.dataId);
- const output = new Float32Array(batch * newHeight * newWidth * numChannels);
- const effectiveInputSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- const effectiveOutputSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
- const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
- let outputOffset = 0;
- for (let b = 0; b < batch; b++) {
- const batchOffset = b * x.strides[0];
- for (let r = 0; r < newHeight; r++) {
- const sourceFracRow = effectiveRowSizeRatio * r;
- const sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) :
- Math.floor(sourceFracRow));
- const rowOffset = batchOffset + sourceNearestRow * x.strides[1];
- for (let c = 0; c < newWidth; c++) {
- const sourceFracCol = effectiveColSizeRatio * c;
- const sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
- Math.floor(sourceFracCol));
- const colOffset = rowOffset + sourceNearestCol * x.strides[2];
- for (let d = 0; d < numChannels; d++) {
- // Begin shader.
- // Compute the fractional index of the source.
- const newVal = xValues[colOffset + d];
- output[outputOffset++] = newVal;
- }
- }
- }
- }
- return tensor(output, [batch, newHeight, newWidth, numChannels], x.dtype);
- }
- resizeNearestNeighborBackprop(dy, x, alignCorners) {
- assertNotComplex([dy, x], 'resizeNearestNeighborBackprop');
- const [batch, xHeight, xWidth, depth] = x.shape;
- const [, yHeight, yWidth] = dy.shape;
- const output = new Float32Array(batch * xHeight * xWidth * depth);
- const dyValues = this.readSync(dy.dataId);
- // In the backwards pass, we want to find the pixels that were generated
- // for each pixel in the input image the forward pass
- const effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- const effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- const heightScale = effectiveXSize[0] / effectiveYSize[0];
- const widthScale = effectiveXSize[1] / effectiveYSize[1];
- const invHeightScale = 1 / heightScale;
- const invWidthScale = 1 / widthScale;
- // This defines the size of the window of values around a particular
- // index in dy that we want to search for contributions to dx.
- const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
- const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
- // Loop over the output space.
- for (let b = 0; b < batch; b++) {
- const batchOffset = b * x.strides[0];
- for (let r = 0; r < xHeight; r++) {
- const rowOffset = batchOffset + r * x.strides[1];
- // Compute bounds for where in dy we will look
- const startRLerp = Math.floor(r * invHeightScale);
- const startDyR = Math.floor(startRLerp - (winHeight / 2));
- for (let c = 0; c < xWidth; c++) {
- const colOffset = rowOffset + c * x.strides[2];
- // Compute bounds for where in dy we will look
- const startCLerp = Math.floor(c * invWidthScale);
- const startDyC = Math.floor(startCLerp - (winWidth / 2));
- for (let d = 0; d < depth; d++) {
- let accum = 0;
- // loop over dy
- for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
- const dyR = dyRIndex + startDyR;
- // Guard against the window exceeding the bounds of dy
- if (dyR < 0 || dyR >= yHeight) {
- continue;
- }
- const dyROffset = batchOffset + dyR * dy.strides[1];
- const sourceFracRow = dyR * heightScale;
- const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
- Math.floor(sourceFracRow));
- if (r !== sourceNearestRow) {
- continue;
- }
- for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
- const dyC = dyCIndex + startDyC;
- // Guard against the window exceeding the bounds of dy
- if (dyC < 0 || dyC >= yWidth) {
- continue;
- }
- const dyCOffset = dyROffset + dyC * dy.strides[2];
- const sourceFracCol = dyC * widthScale;
- const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
- Math.floor(sourceFracCol));
- if (c === sourceNearestCol) {
- accum += dyValues[dyCOffset + d];
- }
- }
- }
- output[colOffset + d] = accum;
- }
- }
- }
- }
- return tensor4d(output, x.shape, x.dtype);
- }
- localResponseNormalization4D(x, depthRadius, bias, alpha, beta) {
- assertNotComplex(x, 'localResponseNormalization4D');
- const channels = x.shape[3];
- const maxD = channels - 1;
- const xValues = this.readSync(x.dataId);
- const size = x.size;
- const result = new Float32Array(size);
- function sumAcrossChannels(offset) {
- const currentChannel = offset % channels;
- let beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
- const endSumOffset = offset - currentChannel +
- Math.min(currentChannel + depthRadius, maxD);
- let sum = 0.0;
- for (; beginSumOffset <= endSumOffset; beginSumOffset++) {
- const z = xValues[beginSumOffset];
- sum += z * z;
- }
- return sum;
- }
- for (let offset = 0; offset < size; offset++) {
- const sum = sumAcrossChannels(offset);
- const val = xValues[offset] * Math.pow(bias + alpha * sum, -beta);
- result[offset] = val;
- }
- return tensor4d(result, x.shape);
- }
- LRNGrad(dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
- assertNotComplex(dy, 'LRNGrad');
- const channels = dy.shape[3];
- const dyValues = this.readSync(dy.dataId);
- const inputImageValues = this.readSync(inputImage.dataId);
- const outputImageValues = this.readSync(outputImage.dataId);
- const result = new Float32Array(dy.size);
- const size = dy.size;
- for (let offset = 0; offset < size; offset++) {
- const currentChannel = offset % channels;
- const depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);
- const depthEnd = (offset - currentChannel) +
- Math.min(channels, currentChannel + depthRadius + 1);
- let norm = 0;
- for (let k = depthBegin; k < depthEnd; k++) {
- norm += Math.pow(inputImageValues[k], 2);
- }
- norm = alpha * norm + bias;
- for (let k = depthBegin; k < depthEnd; k++) {
- let dyi = -2 * alpha * beta * inputImageValues[k] *
- outputImageValues[offset] / norm;
- if (offset === k) {
- dyi += Math.pow(norm, -beta);
- }
- dyi *= dyValues[offset];
- result[k] += dyi;
- }
- }
- return tensor4d(result, dy.shape);
- }
- multinomial(logits, normalized, numSamples, seed) {
- assertNotComplex(logits, 'multinomial');
- const probabilities = normalized ? logits : softmax(logits);
- const batchSize = probabilities.shape[0];
- const numEvents = probabilities.shape[1];
- const res = zeros([batchSize, numSamples], 'int32');
- const resVals = this.readSync(res.dataId);
- const probVals = this.readSync(probabilities.dataId);
- for (let b = 0; b < batchSize; ++b) {
- const offset = b * numEvents;
- // The cdf won't include the last event. It will be implicit if no other
- // event happened.
- const cdf = new Float32Array(numEvents - 1);
- cdf[0] = probVals[offset];
- for (let event = 1; event < cdf.length; ++event) {
- cdf[event] = cdf[event - 1] + probVals[offset + event];
- }
- const random = seedrandom_1(seed.toString());
- const outOffset = b * numSamples;
- for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
- const r = random();
- // Assume last event happened by default.
- resVals[outOffset + sampleId] = cdf.length;
- for (let event = 0; event < cdf.length; event++) {
- if (r < cdf[event]) {
- resVals[outOffset + sampleId] = event;
- break;
- }
- }
- }
- }
- return res;
- }
- oneHot(indices, depth, onValue, offValue) {
- assertNotComplex(indices, 'oneHot');
- const res = new Float32Array(indices.size * depth);
- res.fill(offValue);
- const indicesVal = this.readSync(indices.dataId);
- for (let event = 0; event < indices.size; ++event) {
- if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
- res[event * depth + indicesVal[event]] = onValue;
- }
- }
- return tensor2d(res, [indices.size, depth], 'int32');
- }
- nonMaxSuppression(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
- assertNotComplex(boxes, 'nonMaxSuppression');
- const boxesVals = this.readSync(boxes.dataId);
- const scoresVals = this.readSync(scores.dataId);
- return nonMaxSuppressionV3Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
- }
- depthToSpace(x, blockSize, dataFormat) {
- assert(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
- assert(blockSize > 1, () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);
- const batchSize = x.shape[0];
- const inputHeight = x.shape[1];
- const inputWidth = x.shape[2];
- const inputDepth = x.shape[3];
- const outputHeight = inputHeight * blockSize;
- const outputWidth = inputWidth * blockSize;
- const outputDepth = inputDepth / (blockSize * blockSize);
- const xValues = this.readSync(x.dataId);
- const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
- let outputIdx = 0;
- for (let b = 0; b < batchSize; ++b) {
- for (let h = 0; h < outputHeight; ++h) {
- const inH = Math.floor(h / blockSize);
- const offsetH = (h % blockSize);
- for (let w = 0; w < outputWidth; ++w) {
- const inW = Math.floor(w / blockSize);
- const offsetW = (w % blockSize);
- const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
- for (let d = 0; d < outputDepth; ++d) {
- const inD = d + offsetD;
- const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
- result[outputIdx++] = xValues[inputIdx];
- }
- }
- }
- }
- return tensor4d(result, [batchSize, outputHeight, outputWidth, outputDepth]);
- }
- broadcastedBinaryOp(a, b, dtype, op) {
- const newShape = assertAndGetBroadcastShape(a.shape, b.shape);
- const result = buffer(newShape, dtype);
- const aVals = this.readSync(a.dataId);
- const bVals = this.readSync(b.dataId);
- const aBroadcastDims = getBroadcastDims(a.shape, newShape);
- const bBroadcastDims = getBroadcastDims(b.shape, newShape);
- const resVals = result.values;
- if (aBroadcastDims.length + bBroadcastDims.length === 0) {
- for (let i = 0; i < resVals.length; ++i) {
- resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
- }
- }
- else {
- const aBuf = this.bufferSync(a);
- const bBuf = this.bufferSync(b);
- for (let i = 0; i < resVals.length; ++i) {
- const loc = result.indexToLoc(i);
- const aLoc = loc.slice(-a.rank);
- aBroadcastDims.forEach(d => aLoc[d] = 0);
- const aIndex = aBuf.locToIndex(aLoc);
- const bLoc = loc.slice(-b.rank);
- bBroadcastDims.forEach(d => bLoc[d] = 0);
- const bIndex = bBuf.locToIndex(bLoc);
- resVals[i] = op(aVals[aIndex], bVals[bIndex]);
- }
- }
- return result.toTensor();
- }
- split(x, sizeSplits, axis) {
- return split$4(x, sizeSplits, axis);
- }
- dispose() { }
- floatPrecision() {
- return 32;
- }
- /** Returns the smallest representable number. */
- epsilon() {
- return super.epsilon();
- }
- cropAndResize(images, boxes, boxIndex, cropSize, method, extrapolationValue) {
- const [batch, imageHeight, imageWidth, numChannels] = images.shape;
- const numBoxes = boxes.shape[0];
- const [cropHeight, cropWidth] = cropSize;
- const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
- const boxVals = this.readSync(boxes.dataId);
- const boxIndVals = this.readSync(boxIndex.dataId);
- const imageVals = this.readSync(images.dataId);
- const inStride = images.strides; // to calculate flat indexes into image
- const outStride = output.strides; // to calculate flat indexes into output
- // Reference implementation
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
- for (let b = 0; b < numBoxes; b++) {
- const startInd = b * 4;
- const y1 = boxVals[startInd];
- const x1 = boxVals[startInd + 1];
- const y2 = boxVals[startInd + 2];
- const x2 = boxVals[startInd + 3];
- const bInd = boxIndVals[b];
- if (bInd >= batch) {
- continue;
- }
- const heightScale = (cropHeight > 1) ?
- (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) :
- 0;
- const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
- for (let y = 0; y < cropHeight; y++) {
- const yInd = (cropHeight > 1) ?
- y1 * (imageHeight - 1) + y * (heightScale) :
- 0.5 * (y1 + y2) * (imageHeight - 1);
- if (yInd < 0 || yInd > imageHeight - 1) {
- for (let x = 0; x < cropWidth; x++) {
- for (let c = 0; c < numChannels; c++) {
- const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
- output.values[ind] = extrapolationValue;
- }
- }
- continue;
- }
- if (method === 'bilinear') {
- const topInd = Math.floor(yInd);
- const bottomInd = Math.ceil(yInd);
- const yLerp = yInd - topInd;
- for (let x = 0; x < cropWidth; x++) {
- const xInd = (cropWidth > 1) ?
- x1 * (imageWidth - 1) + x * widthScale :
- 0.5 * (x1 + x2) * (imageWidth - 1);
- if (xInd < 0 || xInd > imageWidth - 1) {
- for (let c = 0; c < numChannels; c++) {
- const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
- output.values[ind] = extrapolationValue;
- }
- continue;
- }
- const leftInd = Math.floor(xInd);
- const rightInd = Math.ceil(xInd);
- const xLerp = xInd - leftInd;
- for (let c = 0; c < numChannels; c++) {
- let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
- bInd * inStride[0];
- const topLeft = imageVals[ind];
- ind = c + rightInd * inStride[2] + topInd * inStride[1] +
- bInd * inStride[0];
- const topRight = imageVals[ind];
- ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
- bInd * inStride[0];
- const bottomLeft = imageVals[ind];
- ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
- bInd * inStride[0];
- const bottomRight = imageVals[ind];
- const top = topLeft + (topRight - topLeft) * xLerp;
- const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
- ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
- output.values[ind] = top + ((bottom - top) * yLerp);
- }
- }
- }
- else { // method == "nearest"
- for (let x = 0; x < cropWidth; ++x) {
- const xInd = (cropWidth > 1) ?
- x1 * (imageWidth - 1) + x * widthScale :
- 0.5 * (x1 + x2) * (imageWidth - 1);
- if (xInd < 0 || xInd > imageWidth - 1) {
- for (let c = 0; c < numChannels; c++) {
- const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
- output.values[ind] = extrapolationValue;
- }
- continue;
- }
- const closestX = Math.round(xInd);
- const closestY = Math.round(yInd);
- for (let c = 0; c < numChannels; c++) {
- const inInd = c + closestX * inStride[2] +
- closestY * inStride[1] + bInd * inStride[0];
- const outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
- output.values[outInd] = imageVals[inInd];
- }
- }
- }
- }
- }
- return output.toTensor();
- }
- sparseToDense(sparseIndices, sparseValues, outputShape, defaultValue) {
- const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
- const sumDupeIndices = false;
- return this.scatter(sparseIndices, sparseValues, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices);
- }
- gatherND(x, indices) {
- const indicesShape = indices.shape;
- const sliceRank = indicesShape[indicesShape.length - 1];
- const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(x, indices);
- if (numSlices === 0) {
- return tensor([], resultShape, x.dtype);
- }
- const buffer = new TensorBuffer([numSlices, sliceSize], x.dtype);
- const indicesData = this.readSync(indices.dataId);
- const xData = this.readSync(x.dataId);
- for (let i = 0; i < numSlices; i++) {
- const index = [];
- let flattenIndex = 0;
- for (let j = 0; j < sliceRank; j++) {
- const dim = indicesData[i * sliceRank + j];
- flattenIndex += dim * strides[j];
- index.push(dim);
- }
- if (flattenIndex < 0 || flattenIndex >= x.size / sliceSize) {
- throw new Error(`Invalid indices: ${index} does not index into ${x.shape}`);
- }
- for (let k = 0; k < sliceSize; k++) {
- buffer.values[i * sliceSize + k] = xData[flattenIndex * sliceSize + k];
- }
- }
- return buffer.toTensor().reshape(resultShape);
- }
- scatterND(indices, updates, shape) {
- const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
- const defaultValue = scalar(0);
- const sumDupeIndices = true;
- return this.scatter(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices);
- }
- fill(shape, value, dtype) {
- dtype = dtype || inferDtype(value);
- const values = getArrayFromDType(dtype, sizeFromShape(shape));
- values.fill(value);
- return engine().makeTensor(values, shape, dtype, this);
- }
- onesLike(x) {
- if (x.dtype === 'string') {
- throw new Error('onesLike is not supported for string tensors');
- }
- else {
- return this.fill(x.shape, 1, x.dtype);
- }
- }
- zerosLike(x) {
- const values = getArrayFromDType(x.dtype, sizeFromShape(x.shape));
- return this.makeOutput(values, x.shape, x.dtype);
- }
- linspace(start, stop, num) {
- return linspaceImpl(start, stop, num);
- }
- scatter(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
- const flattenShape = [outputSize / sliceSize, sliceSize];
- const indicesData = this.readSync(indices.dataId);
- const updatesData = this.readSync(updates.dataId);
- if (outputSize === 0) {
- return tensor([], shape, updates.dtype);
- }
- const buffer = new TensorBuffer(flattenShape, updates.dtype);
- buffer.values.fill(this.readSync(defaultValue.dataId)[0]);
- for (let i = 0; i < numUpdates; i++) {
- const index = [];
- let flattenIndex = 0;
- for (let j = 0; j < sliceRank; j++) {
- const dim = indicesData[i * sliceRank + j];
- index.push(dim);
- flattenIndex += dim * strides[j];
- }
- if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
- throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
- }
- for (let k = 0; k < sliceSize; k++) {
- if (sumDupeIndices) {
- buffer.values[flattenIndex * sliceSize + k] +=
- updatesData[i * sliceSize + k];
- }
- else {
- buffer.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?
- updatesData[0] :
- updatesData[i * sliceSize + k];
- }
- }
- }
- return buffer.toTensor().reshape(shape);
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function simpleAbsImpl(vals) {
- const resultValues = new Float32Array(vals.length);
- for (let i = 0; i < vals.length; ++i) {
- resultValues[i] = Math.abs(vals[i]);
- }
- return resultValues;
- }
- const abs$1 = (args) => {
- const { x } = args.inputs;
- const cpuBackend = args.backend;
- let resultValues = new Float32Array(sizeFromShape(x.shape));
- if (x.dtype !== 'complex64') {
- const values = cpuBackend.data.get(x.dataId).values;
- resultValues = simpleAbsImpl(values);
- }
- else {
- const complexVals = cpuBackend.data.get(x.dataId);
- const real = complexVals.complexTensorInfos.real;
- const imag = complexVals.complexTensorInfos.imag;
- const realVals = cpuBackend.data.get(real.dataId).values;
- const imagVals = cpuBackend.data.get(imag.dataId).values;
- for (let i = 0; i < realVals.length; i++) {
- const real = realVals[i];
- const imag = imagVals[i];
- resultValues[i] = Math.hypot(real, imag);
- }
- }
- return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
- };
- const absConfig = {
- kernelName: Abs,
- backendName: 'cpu',
- kernelFunc: abs$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Template that creates implementation for binary ops. Supports broadcast.
- */
- function createSimpleBinaryKernelImpl(op) {
- return (aShape, bShape, aVals, bVals, dtype) => {
- const newShape = assertAndGetBroadcastShape(aShape, bShape);
- const resultRank = newShape.length;
- const resultStrides = computeStrides(newShape);
- const resultSize = sizeFromShape(newShape);
- const result = getTypedArrayFromDType(dtype, resultSize);
- const aRank = aShape.length;
- const bRank = bShape.length;
- const aStrides = computeStrides(aShape);
- const bStrides = computeStrides(bShape);
- const aBroadcastDims = getBroadcastDims(aShape, newShape);
- const bBroadcastDims = getBroadcastDims(bShape, newShape);
- if (aBroadcastDims.length + bBroadcastDims.length === 0) {
- for (let i = 0; i < result.length; ++i) {
- result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
- }
- }
- else {
- for (let i = 0; i < result.length; ++i) {
- const loc = indexToLoc(i, resultRank, resultStrides);
- const aLoc = loc.slice(-aRank);
- aBroadcastDims.forEach(d => aLoc[d] = 0);
- const aIndex = locToIndex(aLoc, aRank, aStrides);
- const bLoc = loc.slice(-bRank);
- bBroadcastDims.forEach(d => bLoc[d] = 0);
- const bIndex = locToIndex(bLoc, bRank, bStrides);
- result[i] = op(aVals[aIndex], bVals[bIndex]);
- }
- }
- return [result, newShape];
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function complex$1(args) {
- const { inputs, backend } = args;
- const { real, imag } = inputs;
- const realVals = backend.data.get(real.dataId).values;
- const imagVals = backend.data.get(imag.dataId).values;
- const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
- const complex = backend.data.get(complexInfo.dataId);
- // The complex tensor owns the underlying real and imag tensorInfos, only the
- // complex tensor tracks refCount, when complexData is disposed the
- // underlying tensorData will be disposed.
- complex.complexTensorInfos = {
- real: backend.makeTensorInfo(real.shape, 'float32', realVals),
- imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
- };
- return complexInfo;
- }
- const complexConfig = {
- kernelName: Complex,
- backendName: 'cpu',
- kernelFunc: complex$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function identity$1(args) {
- const { inputs, backend } = args;
- const { x } = inputs;
- backend.incRef(x.dataId);
- return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
- }
- const identityConfig = {
- kernelName: Identity,
- backendName: 'cpu',
- kernelFunc: identity$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function real$1(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const real = backend.data.get(input.dataId).complexTensorInfos.real;
- const realVal = backend.data.get(real.dataId).values;
- // When complex tensor is disposed, its underlying parts will be disposed too.
- // Make new tensor out of the real value of the complex. This makes sure the
- // value is still accessible even if complex tensor is disposed.
- return backend.makeTensorInfo(real.shape, real.dtype, realVal);
- }
- const realConfig = {
- kernelName: Real,
- backendName: 'cpu',
- kernelFunc: real$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function cast$2(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { dtype } = attrs;
- // Casting to complex64.
- if (dtype === 'complex64') {
- if (x.dtype === 'complex64') {
- return identity$1({ inputs: { x }, backend });
- }
- // TODO(lina128): Import kernel function once zeros is modularized.
- const zerosTensor = zeros(x.shape);
- const floatX = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
- const result = complex$1({ inputs: { real: floatX, imag: zerosTensor }, backend });
- zerosTensor.dispose();
- backend.disposeIntermediateTensorInfo(floatX);
- return result;
- }
- // Casting from complex64
- if (x.dtype === 'complex64') {
- const realPart = real$1({ inputs: { input: x }, backend });
- const result = cast$2({ inputs: { x: realPart }, backend, attrs: { dtype } });
- backend.disposeIntermediateTensorInfo(realPart);
- return result;
- }
- if (!hasEncodingLoss(x.dtype, dtype)) {
- // We don't change the underlying data, since we cast to higher
- // precision.
- const result = identity$1({ inputs: { x }, backend });
- return { dataId: result.dataId, shape: result.shape, dtype };
- }
- if (dtype === 'int32') {
- const values = backend.data.get(x.dataId).values;
- const resultValues = Int32Array.from(values);
- return backend.makeTensorInfo(x.shape, 'int32', resultValues);
- }
- if (dtype === 'bool') {
- // This is essentially the result of notEqual(x, 0). We avoid using
- // kernel notEqual to avoid circular dependency, i.e. binary_utils ->
- // cast -> notEqual -> binary_utils.
- const xVals = backend.data.get(x.dataId).values;
- const zero = toTypedArray([0], x.dtype);
- const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(x.shape, [], xVals, zero, 'bool');
- return backend.makeTensorInfo(resultShape, 'bool', resultData);
- }
- throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
- }
- const castConfig = {
- kernelName: Cast,
- backendName: 'cpu',
- kernelFunc: cast$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Template that creates a `KernelFunc` for binary ops.
- * @param name Kernel name.
- * @param binaryKernelImpl A `SimpleBinaryKernelImpl` for the kernel.
- * @param binaryKernelComplexImpl Optional. If exists, represents a
- * `ComplexBinaryKernelImpl` for the kernel, will be used when input dtype
- * is `complex64`.
- * @param dtype Optional. If set, the result has this dtype. Otherwise, the
- * result has the same dtype as the first input. This is mainly used in
- * comparison kernels, such as Equal, Less, Greater, etc.
- */
- function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
- if (complexImpl == null) {
- return ({ inputs, backend }) => {
- const { a, b } = inputs;
- const cpuBackend = backend;
- assertNotComplex([a, b], name);
- const aVals = cpuBackend.data.get(a.dataId).values;
- const bVals = cpuBackend.data.get(b.dataId).values;
- const $dtype = dtype || a.dtype;
- const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
- return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
- };
- }
- return ({ inputs, backend }) => {
- const { a, b } = inputs;
- const cpuBackend = backend;
- if (a.dtype === 'complex64' || b.dtype === 'complex64') {
- const $aComplex = cast$2({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
- const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
- const aReal = $aComplexVals.complexTensorInfos.real;
- const aImag = $aComplexVals.complexTensorInfos.imag;
- const aRealVals = cpuBackend.data.get(aReal.dataId).values;
- const aImagVals = cpuBackend.data.get(aImag.dataId).values;
- const $bComplex = cast$2({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
- const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
- const bReal = $bComplexVals.complexTensorInfos.real;
- const bImag = $bComplexVals.complexTensorInfos.imag;
- const bRealVals = cpuBackend.data.get(bReal.dataId).values;
- const bImagVals = cpuBackend.data.get(bImag.dataId).values;
- const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
- const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
- const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
- const result = complex$1({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
- cpuBackend.disposeIntermediateTensorInfo($aComplex);
- cpuBackend.disposeIntermediateTensorInfo($bComplex);
- cpuBackend.disposeIntermediateTensorInfo(resultReal);
- cpuBackend.disposeIntermediateTensorInfo(resultImag);
- return result;
- }
- else {
- const aVals = cpuBackend.data.get(a.dataId).values;
- const bVals = cpuBackend.data.get(b.dataId).values;
- const $dtype = dtype || a.dtype;
- const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
- return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
- }
- };
- }
- /**
- * Template that creates the complex type implementation for binary ops.
- * Supports broadcast.
- */
- function createComplexBinaryKernelImpl(op) {
- return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
- const resultShape = assertAndGetBroadcastShape(aShape, bShape);
- const resultSize = sizeFromShape(resultShape);
- const resultRank = resultShape.length;
- const resultStrides = computeStrides(resultShape);
- const resultRealVals = getTypedArrayFromDType('float32', resultSize);
- const resultImagVals = getTypedArrayFromDType('float32', resultSize);
- const aBroadcastDims = getBroadcastDims(aShape, resultShape);
- const bBroadcastDims = getBroadcastDims(bShape, resultShape);
- const aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
- const bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
- const aRank = aShape.length;
- const aStrides = computeStrides(aShape);
- const bRank = bShape.length;
- const bStrides = computeStrides(bShape);
- if (aBroadcastDims.length + bBroadcastDims.length === 0) {
- for (let i = 0; i < resultRealVals.length; i++) {
- const aIdx = i % aVals.length;
- const bIdx = i % bVals.length;
- const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
- resultRealVals[i] = result.real;
- resultImagVals[i] = result.imag;
- }
- }
- else {
- for (let i = 0; i < resultRealVals.length; i++) {
- const loc = indexToLoc(i, resultRank, resultStrides);
- const aLoc = loc.slice(-aRank);
- aBroadcastDims.forEach(d => aLoc[d] = 0);
- const aIndex = locToIndex(aLoc, aRank, aStrides);
- const bLoc = loc.slice(-bRank);
- bBroadcastDims.forEach(d => bLoc[d] = 0);
- const bIndex = locToIndex(bLoc, bRank, bStrides);
- const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
- resultRealVals[i] = opResult.real;
- resultImagVals[i] = opResult.imag;
- }
- }
- return [resultRealVals, resultImagVals, resultShape];
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b));
- const addComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
- return { real: aReal + bReal, imag: aImag + bImag };
- }));
- const add$4 = binaryKernelFunc(Add, addImpl, addComplexImpl);
- const addConfig = {
- kernelName: Add,
- backendName: 'cpu',
- kernelFunc: add$4
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Template that creates implementation for unary op.
- */
- function createSimpleUnaryImpl(op) {
- return (values, dtype, attrs) => {
- const newValues = getTypedArrayFromDType(dtype, values.length);
- for (let i = 0; i < values.length; ++i) {
- newValues[i] = op(values[i], attrs);
- }
- return newValues;
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Template that creates a `KernelFunc` for unary ops.
- * @param name Kernel name.
- * @param op A `SimpleUnaryOperation` for the kernel.
- * @param dtype Optional. If set, the result has this dtype. Otherwise, the
- * result has the same dtype as the input. This is mainly used in certain
- * kernels that return bool type, such as isFinite, isInf, etc.
- */
- function unaryKernelFunc(name, op, dtype) {
- return ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- assertNotComplex(x, name);
- if (x.dtype === 'string' || dtype === 'string') {
- throw new Error('unaryKernelFunc does not support string input/output');
- }
- const cpuBackend = backend;
- const values = cpuBackend.data.get(x.dataId).values;
- const xSize = sizeFromShape(x.shape);
- const $dtype = dtype || x.dtype;
- const newValues = getArrayFromDType($dtype, xSize);
- for (let i = 0; i < xSize; ++i) {
- newValues[i] = op(values[i], attrs);
- }
- return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
- };
- }
- /**
- * Template that creates a `KernelFunc` for unary ops from the given
- * `SimpleUnaryImpl`..
- * @param name Kernel name.
- * @param unaryImpl A `SimpleUnaryImpl` that implements the op.
- * @param dtype Optional. If set, the result has this dtype. Otherwise, the
- * result has the same dtype as the input. This is mainly used in certain
- * kernels that return bool type, such as isFinite, isInf, etc.
- */
- function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
- return ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- assertNotComplex(x, name);
- if (x.dtype === 'string' || dtype === 'string') {
- throw new Error('unaryKernelFunc does not support string input/output');
- }
- const cpuBackend = backend;
- const values = cpuBackend.data.get(x.dataId).values;
- const $dtype = dtype || x.dtype;
- const newValues = unaryImpl(values, $dtype, attrs);
- return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi));
- const ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
- const ceilConfig = {
- kernelName: Ceil,
- backendName: 'cpu',
- kernelFunc: ceil$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi));
- const exp$1 = unaryKernelFuncFromImpl(Exp, expImpl);
- const expConfig = {
- kernelName: Exp,
- backendName: 'cpu',
- kernelFunc: exp$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi));
- const expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
- const expm1Config = {
- kernelName: Expm1,
- backendName: 'cpu',
- kernelFunc: expm1$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi));
- const floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
- const floorConfig = {
- kernelName: Floor,
- backendName: 'cpu',
- kernelFunc: floor$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi));
- const log$2 = unaryKernelFuncFromImpl(Log, logImpl);
- const logConfig = {
- kernelName: Log,
- backendName: 'cpu',
- kernelFunc: log$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxImpl(aVals, reduceSize, outShape, dtype) {
- const vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
- for (let i = 0; i < vals.length; ++i) {
- const offset = i * reduceSize;
- let max = aVals[offset];
- for (let j = 0; j < reduceSize; ++j) {
- const value = aVals[offset + j];
- if (value > max) {
- max = value;
- }
- }
- vals[i] = max;
- }
- return vals;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue));
- const multiplyComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
- return {
- real: aReal * bReal - aImag * bImag,
- imag: aReal * bImag + aImag * bReal
- };
- }));
- const multiply$2 = binaryKernelFunc(Multiply, multiplyImpl, multiplyComplexImpl);
- const multiplyConfig = {
- kernelName: Multiply,
- backendName: 'cpu',
- kernelFunc: multiply$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 1 : 0));
- const notEqual$1 = binaryKernelFunc(NotEqual, notEqualImpl, null /* complexOp */, 'bool');
- const notEqualConfig = {
- kernelName: NotEqual,
- backendName: 'cpu',
- kernelFunc: notEqual$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
- const rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
- const rsqrtConfig = {
- kernelName: Rsqrt,
- backendName: 'cpu',
- kernelFunc: rsqrt$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function sliceImpl(vals, begin, size, shape, dtype) {
- const isContinous = isSliceContinous(shape, begin, size);
- const length = sizeFromShape(size);
- const xStrides = computeStrides(shape);
- if (isContinous) {
- const flatOffset = computeFlatOffset(begin, xStrides);
- return vals.subarray(flatOffset, flatOffset + length);
- }
- const outVals = getTypedArrayFromDType(dtype, length);
- for (let i = 0; i < length; ++i) {
- const rank = size.length;
- const strides = computeStrides(size);
- const loc = indexToLoc(i, rank, strides);
- const xLoc = loc.map((idx, j) => idx + begin[j]);
- const xIndex = locToIndex(xLoc, shape.length, xStrides);
- outVals[i] = vals[xIndex];
- }
- return outVals;
- }
- function slice$1(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { begin, size } = attrs;
- assertNotComplex(x, 'slice');
- const [$begin, $size] = parseSliceParams(x, begin, size);
- assertParamsValid(x, $begin, $size);
- const vals = backend.data.get(x.dataId).values;
- const outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
- return backend.makeTensorInfo($size, x.dtype, outVals);
- }
- const sliceConfig = {
- kernelName: Slice,
- backendName: 'cpu',
- kernelFunc: slice$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
- const diff = a - b;
- return diff * diff;
- }));
- const squaredDifference$1 = binaryKernelFunc(SquaredDifference, squaredDifferenceImpl);
- const squaredDifferenceConfig = {
- kernelName: SquaredDifference,
- backendName: 'cpu',
- kernelFunc: squaredDifference$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const subImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
- const subComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
- return { real: aReal - bReal, imag: aImag - bImag };
- }));
- const sub$1 = binaryKernelFunc(Sub, subImpl, subComplexImpl);
- const subConfig = {
- kernelName: Sub,
- backendName: 'cpu',
- kernelFunc: sub$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function transposeImpl(xVals, xShape, dtype, perm, newShape) {
- const xRank = xShape.length;
- const xSize = sizeFromShape(xShape);
- const xStrides = computeStrides(xShape);
- const newStrides = computeStrides(newShape);
- const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
- for (let i = 0; i < xSize; ++i) {
- const loc = indexToLoc(i, xRank, xStrides);
- // Permute location.
- const newLoc = new Array(loc.length);
- for (let i = 0; i < newLoc.length; i++) {
- newLoc[i] = loc[perm[i]];
- }
- const newIndex = locToIndex(newLoc, xRank, newStrides);
- result[newIndex] = xVals[i];
- }
- return result;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function uniqueImpl(values, axis, shape, dtype) {
- // Normalize and validate axis.
- const $axis = parseAxisParam(axis, shape)[0];
- // Calculate the new shape that is suitable for extracting data along the
- // given axis.
- //
- // The rank is 3.
- // The size of the 1st dimension is the size of all the axes < the given axis.
- // The size of the 2nd dimension is the same as the size of the given axis.
- // The size of the 3rd dimension is the size of all the axes > the given axis.
- //
- // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the
- // newShape would be: [2*3, 5, 4].
- //
- // Note that this is not the final output shape. This will be the shape for an
- // intermediate TensorBuffer (see inputBuffer below) to allow us to extract
- // values along the given axis. To demonstrate how it works, consider the
- // following example:
- //
- // Input: a 3D tensor, with shape [1, 2, 3]
- // [
- // [
- // [1,2,3],
- // [4,5,6]
- // ]
- // ]
- // Axis: 2 (the last axis).
- // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6].
- //
- // For this example, newShape would be: [2, 3, 1], where 2 is calculated from
- // 1*2. The re-shaped data would look like:
- //
- // [
- // [
- // [1], [2], [3]
- // ],
- // [
- // [4], [5], [6]
- // ]
- // ]
- //
- // Then, we can construct a 3-level nested loop by the following dimension
- // order to extract the values along the axis (dimension1):
- // i: dimension1 // 0,1,2 (newShape[1])
- // m: dimension0 // 0,1 (newShape[0])
- // n: dimension2 // 0 (newShape[2])
- //
- // m, i, n
- // ---------
- // Iteration 0: data at [0, 0, 0] => "1"
- // Iteration 1: data at [1, 0, 0] => "4"
- // We got [1,4].
- // Iteration 2: data at [0, 1, 0] => "2"
- // Iteration 3: data at [1, 1, 0] => "5"
- // We got [2,5].
- // Iteration 4: data at [0, 2, 0] => "3"
- // Iteration 5: data at [1, 2, 0] => "6"
- // We got [3,6].
- const newShape = [1, shape[0], 1];
- for (let i = 0; i < $axis; i++) {
- newShape[0] *= shape[i];
- }
- newShape[1] = shape[$axis];
- for (let i = $axis + 1; i < shape.length; i++) {
- newShape[2] *= shape[i];
- }
- // A map from unique elements (their string representations) to their values
- // in "indices" (below).
- const uniqueElements = {};
- // The indices of each unique element in the original tensor along the given
- // axis. It is 1D and has the same size as the given axis.
- const indices = new Int32Array(shape[$axis]);
- // Create a buffer so we can easily extract value at a given location.
- const inputBuffer = new TensorBuffer(newShape, dtype, values);
- // The indices along the given axis that have unique elements. This is a
- // de-duped version of "indices" above.
- const uniqueIndices = [];
- const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
- for (let i = 0; i < shape[$axis]; i++) {
- // Extract values along the axis.
- let element;
- if (is1DTensor) {
- // Fast path for 1D tensor input.
- element = values[i].toString();
- }
- else {
- const axisValues = [];
- for (let m = 0; m < newShape[0]; m++) {
- for (let n = 0; n < newShape[2]; n++) {
- axisValues.push(inputBuffer.get(m, i, n));
- }
- }
- element = axisValues.join(',');
- }
- // Dedup and update various indices.
- if (uniqueElements[element] !== undefined) {
- indices[i] = uniqueElements[element];
- }
- else {
- const uniqueIndex = Object.keys(uniqueElements).length;
- uniqueElements[element] = uniqueIndex;
- indices[i] = uniqueIndex;
- uniqueIndices.push(i);
- }
- }
- // Now we know where each of the unique elements are located along the axis
- // (uniqueIndices). Extract them from input buffer and store them in the
- // output buffer.
- const outputTmpShape = newShape.slice();
- outputTmpShape[1] = Object.keys(uniqueElements).length;
- const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
- uniqueIndices.forEach((uniqueElementIndex, i) => {
- for (let m = 0; m < newShape[0]; m++) {
- for (let n = 0; n < newShape[2]; n++) {
- outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
- }
- }
- });
- // The output shape can be calculated from the input shape with the size of
- // the given axis replaced by the number of unique elements along that axis.
- const outputShape = shape.slice();
- outputShape[$axis] = outputTmpShape[1];
- return {
- outputValues: outputBuffer.values,
- outputShape,
- indices,
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- var shared = /*#__PURE__*/Object.freeze({
- __proto__: null,
- simpleAbsImpl: simpleAbsImpl,
- addImpl: addImpl,
- ceilImpl: ceilImpl,
- expImpl: expImpl,
- expm1Impl: expm1Impl,
- floorImpl: floorImpl,
- logImpl: logImpl,
- maxImpl: maxImpl,
- multiplyImpl: multiplyImpl,
- notEqualImpl: notEqualImpl,
- rsqrtImpl: rsqrtImpl,
- sliceImpl: sliceImpl,
- squaredDifferenceImpl: squaredDifferenceImpl,
- subImpl: subImpl,
- transposeImpl: transposeImpl,
- uniqueImpl: uniqueImpl
- });
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$4 = '0.0.0';
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Side effects for default initialization of MathBackendCPU
- registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const elu$3 = unaryKernelFunc(Elu, (xi) => xi >= 0 ? xi : (Math.exp(xi) - 1));
- const eluConfig = {
- kernelName: Elu,
- backendName: 'cpu',
- kernelFunc: elu$3,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => xValue < 0 ? aValue * xValue : xValue);
- function prelu$2(args) {
- const { inputs, backend } = args;
- const { x, alpha } = inputs;
- assertNotComplex([x, alpha], 'prelu');
- const aVals = backend.data.get(x.dataId).values;
- const bVals = backend.data.get(alpha.dataId).values;
- const [resultData, resultShape] = preluImpl(x.shape, alpha.shape, aVals, bVals, x.dtype);
- return backend.makeTensorInfo(resultShape, x.dtype, resultData);
- }
- const preluConfig = {
- kernelName: Prelu,
- backendName: 'cpu',
- kernelFunc: prelu$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const relu$1 = unaryKernelFunc(Relu, (xi) => Math.max(0, xi));
- const reluConfig = {
- kernelName: Relu,
- backendName: 'cpu',
- kernelFunc: relu$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const relu6$1 = unaryKernelFunc(Relu6, (xi) => Math.min(Math.max(0, xi), 6));
- const relu6Config = {
- kernelName: Relu6,
- backendName: 'cpu',
- kernelFunc: relu6$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function applyActivation$1(backend, x, activation, preluActivationWeights) {
- if (activation === 'linear') {
- return identity$1({ inputs: { x }, backend });
- }
- else if (activation === 'relu') {
- return relu$1({ inputs: { x }, backend });
- }
- else if (activation === 'elu') {
- return elu$3({ inputs: { x }, backend });
- }
- else if (activation === 'relu6') {
- return relu6$1({ inputs: { x }, backend });
- }
- else if (activation === 'prelu') {
- return prelu$2({ inputs: { x, alpha: preluActivationWeights }, backend });
- }
- throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function reshape$2(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { shape } = attrs;
- const xSize = sizeFromShape(x.shape);
- const $shape = inferFromImplicitShape(shape, xSize);
- const $xSize = sizeFromShape($shape);
- assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
- `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
- `shape must have the same number of elements.`);
- backend.incRef(x.dataId);
- const xData = backend.data.get(x.dataId);
- if (xData.complexTensorInfos != null) {
- const real = xData.complexTensorInfos.real;
- const imag = xData.complexTensorInfos.imag;
- real.shape = $shape;
- imag.shape = $shape;
- }
- return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
- }
- const reshapeConfig = {
- kernelName: Reshape,
- backendName: 'cpu',
- kernelFunc: reshape$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function batchMatMul(args) {
- const { inputs, backend, attrs } = args;
- const { a, b } = inputs;
- const { transposeA, transposeB } = attrs;
- assertNotComplex([a, b], 'matMul');
- const aRank = a.shape.length;
- const bRank = b.shape.length;
- const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
- const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
- const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
- const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
- const outerDimsA = a.shape.slice(0, -2);
- const outerDimsB = b.shape.slice(0, -2);
- assert(arraysEqual(outerDimsA, outerDimsB), () => `Error in matMul: outer dimensions (${outerDimsA}) and (` +
- `${outerDimsB}) of Tensors with shapes ${a.shape} and ` +
- `${b.shape} must match.`);
- assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
- `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
- `${b.shape} and transposeA=${transposeA}` +
- ` and transposeB=${transposeB} must match.`);
- const outShape = a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
- const batchDimA = sizeFromShape(outerDimsA);
- const batchDimB = sizeFromShape(outerDimsB);
- const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
- [batchDimA, outerShapeA, innerShapeA];
- const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
- [batchDimB, innerShapeB, outerShapeB];
- // The rest of the implementation is designed to operate on rank-3 tensors
- const a3d = reshape$2({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
- const b3d = reshape$2({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
- const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
- const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
- const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
- const batchDim = a3d.shape[0];
- const a3dValues = backend.data.get(a3d.dataId).values;
- const b3dValues = backend.data.get(b3d.dataId).values;
- const a3dStrides = computeStrides(a3d.shape);
- const b3dStrides = computeStrides(b3d.shape);
- const [aBatch, aOuterStep, aInnerStep] = transposeA ?
- [a3dStrides[0], 1, a3dStrides[1]] :
- [a3dStrides[0], a3dStrides[1], 1];
- const [bInnerStep, bOuterStep, bBatch] = transposeB ?
- [1, b3dStrides[1], b3dStrides[0]] :
- [b3dStrides[1], 1, b3dStrides[0]];
- const size = leftDim * rightDim;
- const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
- const resVals = result.values;
- const blockSize = backend.blockSize;
- for (let bi = 0; bi < batchDim; bi++) {
- for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
- for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
- for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
- // for when blockSize doesn't evenly divide the input
- const iBlock = Math.min(i0 + blockSize, leftDim);
- const jBlock = Math.min(j0 + blockSize, rightDim);
- const kBlock = Math.min(k0 + blockSize, sharedDim);
- for (let i = i0; i < iBlock; i++) {
- for (let j = j0; j < jBlock; j++) {
- let sum = 0.0;
- for (let k = k0; k < kBlock; k++) {
- sum +=
- a3dValues[bi * aBatch + i * aOuterStep + k * aInnerStep] *
- b3dValues[k * bInnerStep + j * bOuterStep + bi * bBatch];
- }
- resVals[bi * size + (i * rightDim + j)] += sum;
- }
- }
- }
- }
- }
- }
- backend.disposeIntermediateTensorInfo(a3d);
- backend.disposeIntermediateTensorInfo(b3d);
- // set correct shape on output.
- return backend.makeTensorInfo(outShape, result.dtype, result.values);
- }
- const batchMatMulConfig = {
- kernelName: BatchMatMul,
- backendName: 'cpu',
- kernelFunc: batchMatMul,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function _fusedMatMul(args) {
- const { inputs, backend, attrs } = args;
- const { a, b, bias, preluActivationWeights } = inputs;
- const { transposeA, transposeB, activation } = attrs;
- let current;
- let addRes;
- let activationRes;
- const intermediates = [];
- const matMulRes = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
- current = matMulRes;
- if (bias) {
- addRes = add$4({ inputs: { a: current, b: bias }, backend });
- intermediates.push(current);
- current = addRes;
- }
- if (activation) {
- activationRes =
- applyActivation$1(backend, current, activation, preluActivationWeights);
- intermediates.push(current);
- current = activationRes;
- }
- for (const i of intermediates) {
- backend.disposeIntermediateTensorInfo(i);
- }
- return current;
- }
- const _fusedMatMulConfig = {
- kernelName: _FusedMatMul,
- backendName: 'cpu',
- kernelFunc: _fusedMatMul,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const acos$1 = unaryKernelFunc(Acos, (xi) => Math.acos(xi));
- const acosConfig = {
- kernelName: Acos,
- backendName: 'cpu',
- kernelFunc: acos$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const acosh$1 = unaryKernelFunc(Acosh, (xi) => Math.acosh(xi));
- const acoshConfig = {
- kernelName: Acosh,
- backendName: 'cpu',
- kernelFunc: acosh$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const asin$1 = unaryKernelFunc(Asin, (xi) => Math.asin(xi));
- const asinConfig = {
- kernelName: Asin,
- backendName: 'cpu',
- kernelFunc: asin$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const asinh$1 = unaryKernelFunc(Asinh, (xi) => Math.asinh(xi));
- const asinhConfig = {
- kernelName: Asinh,
- backendName: 'cpu',
- kernelFunc: asinh$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const atan$1 = unaryKernelFunc(Atan, (xi) => Math.atan(xi));
- const atanConfig = {
- kernelName: Atan,
- backendName: 'cpu',
- kernelFunc: atan$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const atanh$1 = unaryKernelFunc(Atanh, (xi) => Math.atanh(xi));
- const atanhConfig = {
- kernelName: Atanh,
- backendName: 'cpu',
- kernelFunc: atanh$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function pool$1(xValues, xShape, dtype, strides, convInfo, poolType) {
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
- Number.POSITIVE_INFINITY);
- const output = buffer(convInfo.outShape, dtype);
- const outputVals = output.values;
- const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
- const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
- const outputColStrides = convInfo.outShape[3];
- for (let b = 0; b < convInfo.batchSize; ++b) {
- const outputBatchOffset = b * outputBatchStrides;
- const inputBatchOffset = b * strides[0];
- for (let d = 0; d < convInfo.inChannels; ++d) {
- for (let yR = 0; yR < convInfo.outHeight; ++yR) {
- const xRCorner = yR * strideHeight - padTop;
- const xRMin = Math.max(0, xRCorner);
- const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
- const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
- for (let yC = 0; yC < convInfo.outWidth; ++yC) {
- const xCCorner = yC * strideWidth - padLeft;
- const xCMin = Math.max(0, xCCorner);
- const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
- let minMaxValue = initialValue;
- let avgValue = 0;
- let count = 0;
- for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
- const xROffset = inputBatchOffset + xR * strides[1];
- for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
- const xCOffset = xROffset + xC * strides[2];
- const pixel = xValues[xCOffset + d];
- if ((poolType === 'max' && pixel > minMaxValue)) {
- minMaxValue = pixel;
- }
- else if (poolType === 'avg') {
- avgValue += pixel;
- count++;
- }
- }
- if (isNaN(minMaxValue)) {
- break;
- }
- }
- const outputOffset = outputRowOffset + yC * outputColStrides + d;
- outputVals[outputOffset] =
- poolType === 'avg' ? avgValue / count : minMaxValue;
- }
- }
- }
- }
- return output;
- }
- function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
- const maxPositions = buffer(convInfo.outShape, 'int32');
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const xBuf = buffer(xShape, dtype, xValues);
- for (let b = 0; b < convInfo.batchSize; ++b) {
- for (let d = 0; d < convInfo.inChannels; ++d) {
- for (let yR = 0; yR < convInfo.outHeight; ++yR) {
- const xRCorner = yR * strideHeight - padTop;
- let xRMin = xRCorner;
- while (xRMin < 0) {
- xRMin += dilationHeight;
- }
- // const xRMin = Math.max(0, xRCorner);
- const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
- for (let yC = 0; yC < convInfo.outWidth; ++yC) {
- const xCCorner = yC * strideWidth - padLeft;
- let xCMin = xCCorner;
- while (xCMin < 0) {
- xCMin += dilationWidth;
- }
- const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
- let maxValue = Number.NEGATIVE_INFINITY;
- let maxPosition = -1;
- for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
- const wR = xR - xRCorner;
- for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
- const wC = xC - xCCorner;
- const pixel = xBuf.get(b, xR, xC, d);
- if (pixel > maxValue) {
- maxValue = pixel;
- if (flattenPositions) {
- maxPosition = includeBatchInIndex ?
- ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *
- convInfo.inChannels +
- d :
- (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
- }
- else {
- maxPosition = wR * effectiveFilterWidth + wC;
- }
- }
- }
- }
- maxPositions.set(maxPosition, b, yR, yC, d);
- }
- }
- }
- }
- return maxPositions;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function avgPool$1(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- assertNotComplex(x, 'avgPool');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const dilations = 1;
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- let res;
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- res = identity$1({ inputs: { x }, backend });
- }
- else {
- const xValues = backend.data.get(x.dataId).values;
- const strides = computeStrides(x.shape);
- const buffer = pool$1(xValues, x.shape, x.dtype, strides, convInfo, 'avg');
- res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
- }
- return res;
- }
- const avgPoolConfig = {
- kernelName: AvgPool,
- backendName: 'cpu',
- kernelFunc: avgPool$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function avgPoolBackprop$1(args) {
- const { inputs, backend, attrs } = args;
- const { dy, input } = inputs;
- const x = input;
- assertNotComplex([dy, input], 'avgPoolBackprop');
- const { filterSize, strides, pad } = attrs;
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const dx = buffer(x.shape, 'float32');
- const avgMultiplier = 1 / (filterHeight * filterWidth);
- const dyData = backend.data.get(dy.dataId).values;
- const dyBuf = buffer(dy.shape, 'float32', dyData);
- for (let b = 0; b < convInfo.batchSize; ++b) {
- for (let d = 0; d < convInfo.inChannels; ++d) {
- for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
- for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
- // Shader code begins.
- const dyRCorner = dxR - padTop;
- const dyCCorner = dxC - padLeft;
- let dotProd = 0;
- for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
- const dyR = (dyRCorner + wR) / strideHeight;
- if (dyR < 0 || dyR >= convInfo.outHeight ||
- Math.floor(dyR) !== dyR) {
- continue;
- }
- for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
- const dyC = (dyCCorner + wC) / strideWidth;
- if (dyC < 0 || dyC >= convInfo.outWidth ||
- Math.floor(dyC) !== dyC) {
- continue;
- }
- const pixel = dyBuf.get(b, dyR, dyC, d);
- dotProd += pixel;
- }
- }
- dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
- }
- }
- }
- }
- return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
- }
- const avgPoolBackpropConfig = {
- kernelName: AvgPoolBackprop,
- backendName: 'cpu',
- kernelFunc: avgPoolBackprop$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function batchNorm$1(args) {
- const { inputs, backend, attrs } = args;
- const { x, scale, offset, mean, variance } = inputs;
- assert(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
- 'equal ranks.');
- assert(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
- 'equal ranks.');
- assert(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
- 'equal ranks.');
- assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
- let { varianceEpsilon } = attrs;
- if (varianceEpsilon == null) {
- varianceEpsilon = 0.001;
- }
- const xVals = backend.data.get(x.dataId).values;
- const mVals = backend.data.get(mean.dataId).values;
- const varVals = backend.data.get(variance.dataId).values;
- const sVals = scale ? backend.data.get(scale.dataId).values :
- new Float32Array([1]);
- const offVals = offset ?
- backend.data.get(offset.dataId).values :
- new Float32Array([0]);
- const outVals = new Float32Array(xVals.length);
- const offValsLength = offVals.length;
- const sValsLength = sVals.length;
- const varValsLength = varVals.length;
- const mValsLength = mVals.length;
- let offi = 0;
- let mi = 0;
- let si = 0;
- let vi = 0;
- for (let i = 0; i < xVals.length; ++i) {
- outVals[i] = offVals[offi++] +
- (xVals[i] - mVals[mi++]) * sVals[si++] /
- Math.sqrt(varVals[vi++] + varianceEpsilon);
- if (offi >= offValsLength) {
- offi = 0;
- }
- if (mi >= mValsLength) {
- mi = 0;
- }
- if (si >= sValsLength) {
- si = 0;
- }
- if (vi >= varValsLength) {
- vi = 0;
- }
- }
- return backend.makeTensorInfo(x.shape, x.dtype, outVals);
- }
- const batchNormConfig = {
- kernelName: FusedBatchNorm,
- backendName: 'cpu',
- kernelFunc: batchNorm$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const clip = unaryKernelFunc(ClipByValue, (xi, attrs) => {
- const clipAttrs = attrs;
- if (xi > clipAttrs.clipValueMax) {
- return clipAttrs.clipValueMax;
- }
- return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
- });
- const clipConfig = {
- kernelName: ClipByValue,
- backendName: 'cpu',
- kernelFunc: clip,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function imag$1(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const imag = backend.data.get(input.dataId).complexTensorInfos.imag;
- const imagVal = backend.data.get(imag.dataId).values;
- // When complex tensor is disposed, its underlying parts will be disposed too.
- // Make new tensor out of the imag value of the complex. This makes sure the
- // value is still accessible even if complex tensor is disposed.
- return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
- }
- const imagConfig = {
- kernelName: Imag,
- backendName: 'cpu',
- kernelFunc: imag$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function concat$1(args) {
- const { inputs, backend, attrs } = args;
- const { axis } = attrs;
- const $axis = parseAxisParam(axis, inputs[0].shape)[0];
- let outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
- if (sizeFromShape(outShape) === 0) {
- return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
- }
- // Keep only non-empty tensors (ignore tensors with 0 in their shape).
- const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
- if ($inputs.length === 1) {
- return $inputs[0];
- }
- const shapes = $inputs.map(t => t.shape);
- assertParamsConsistent(shapes, $axis);
- if ($inputs[0].dtype === 'complex64') {
- const reals = $inputs.map((t) => real$1({ inputs: { input: t }, backend }));
- const imags = $inputs.map((t) => imag$1({ inputs: { input: t }, backend }));
- const realConcated = concat$1({ inputs: reals, backend, attrs: { axis: $axis } });
- const imagConcated = concat$1({ inputs: imags, backend, attrs: { axis: $axis } });
- const result = complex$1({ inputs: { real: realConcated, imag: imagConcated }, backend });
- reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
- imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
- backend.disposeIntermediateTensorInfo(realConcated);
- backend.disposeIntermediateTensorInfo(imagConcated);
- return result;
- }
- // Any concat of n-dimensional tensors across any axis can be reduced to
- // a concatenation of two-dimensional tensors across the axis 1 by first
- // partitioning the axes of the original tensors into those less than the
- // axis to be concatenated and the rest. Then reshape the tensors
- // into a two-dimensional tensor by collapsing these two sets of axes and
- // concatenate the resulting matrices across the axis 1, finally reshaping
- // the result to have the proper shape.
- const inputs2D = $inputs.map(t => {
- const innerSize = sizeFromShape(t.shape.slice($axis));
- const shape = [-1, innerSize];
- return reshape$2({ inputs: { x: t }, backend, attrs: { shape } });
- });
- // Concats 2d tensors along axis=1.
- outShape =
- computeOutShape$1(inputs2D.map(t => t.shape), 1 /* axis */);
- const outVals = getTypedArrayFromDType($inputs[0].dtype, sizeFromShape(outShape));
- if (inputs2D[0].shape[0] === 1) {
- // Use built-in TypedArray.set() method for speed.
- let offset = 0;
- inputs2D.forEach(t => {
- const val = backend.data.get(t.dataId).values;
- const size = sizeFromShape(t.shape);
- outVals.set(val, offset);
- offset += size;
- });
- }
- else {
- let colOffset = 0;
- inputs2D.forEach(t => {
- const tVals = backend.data.get(t.dataId).values;
- let tIdx = 0;
- for (let row = 0; row < t.shape[0]; ++row) {
- const resIdx = row * outShape[1] + colOffset;
- for (let col = 0; col < t.shape[1]; ++col) {
- outVals[resIdx + col] = tVals[tIdx++];
- }
- }
- colOffset += t.shape[1];
- });
- }
- const finalOutShape = computeOutShape$1($inputs.map(t => t.shape), $axis);
- const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
- inputs2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
- return outInfo;
- }
- const concatConfig = {
- kernelName: Concat,
- backendName: 'cpu',
- kernelFunc: concat$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function conv2D(args) {
- const { inputs, backend, attrs } = args;
- const { x, filter } = inputs;
- const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
- assertNotComplex([x, filter], 'conv2d');
- const $dataFormat = convertConv2DDataFormat(dataFormat);
- const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const padLeft = convInfo.padInfo.left;
- const padTop = convInfo.padInfo.top;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- const y = new TensorBuffer(convInfo.outShape, x.dtype);
- const xStrides = computeStrides(x.shape);
- const filterStrides = computeStrides(filter.shape);
- const xBatchStride = xStrides[0];
- const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
- const xColStride = isChannelsLast ? xStrides[2] : 1;
- const xChannelStride = isChannelsLast ? 1 : xStrides[1];
- const yBatchStride = y.strides[0];
- const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
- const yColStride = isChannelsLast ? y.strides[2] : 1;
- const yChannelStride = isChannelsLast ? 1 : y.strides[1];
- const xVals = backend.data.get(x.dataId).values;
- const wVals = backend.data.get(filter.dataId).values;
- const yVals = y.values;
- for (let b = 0; b < convInfo.batchSize; ++b) {
- const xOffset1 = b * xBatchStride;
- const yOffset1 = b * yBatchStride;
- for (let yR = 0; yR < convInfo.outHeight; ++yR) {
- const yOffset2 = yOffset1 + yR * yRowStride;
- const xRCorner = yR * convInfo.strideHeight - padTop;
- for (let wR = 0; wR < filterHeight; ++wR) {
- const xR = xRCorner + wR * dilationHeight;
- if (xR < 0 || xR >= convInfo.inHeight) {
- continue;
- }
- const wOffset1 = wR * filterStrides[0];
- const xOffset2 = xOffset1 + xR * xRowStride;
- for (let yC = 0; yC < convInfo.outWidth; ++yC) {
- const yOffset3 = yOffset2 + yC * yColStride;
- const xCCorner = yC * convInfo.strideWidth - padLeft;
- for (let wC = 0; wC < filterWidth; ++wC) {
- const xC = xCCorner + wC * dilationWidth;
- if (xC < 0 || xC >= convInfo.inWidth) {
- continue;
- }
- const wOffset2 = wOffset1 + wC * filterStrides[1];
- const xOffset3 = xOffset2 + xC * xColStride;
- let wOffset3 = wOffset2;
- for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
- const xVal = xVals[xOffset3 + d1 * xChannelStride];
- for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
- yVals[yOffset3 + d2 * yChannelStride] +=
- xVal * wVals[wOffset3 + d2];
- }
- wOffset3 += convInfo.outChannels;
- }
- }
- }
- }
- }
- }
- return backend.makeTensorInfo(y.shape, y.dtype, yVals);
- }
- const conv2DConfig = {
- kernelName: Conv2D,
- backendName: 'cpu',
- kernelFunc: conv2D
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const cos$1 = unaryKernelFunc(Cos, (xi) => Math.cos(xi));
- const cosConfig = {
- kernelName: Cos,
- backendName: 'cpu',
- kernelFunc: cos$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const cosh$1 = unaryKernelFunc(Cosh, (xi) => Math.cosh(xi));
- const coshConfig = {
- kernelName: Cosh,
- backendName: 'cpu',
- kernelFunc: cosh$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const dilation2dConfig = {
- kernelName: Dilation2D,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend, attrs }) => {
- const { x, filter } = inputs;
- const { strides, pad, dilations } = attrs;
- const cpuBackend = backend;
- const xVals = cpuBackend.data.get(x.dataId).values;
- const xRank = x.shape.length;
- const filterVals = cpuBackend.data.get(filter.dataId).values;
- const filterRank = filter.shape.length;
- const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
- const outSize = sizeFromShape(outShape);
- const outRank = outShape.length;
- const outputVals = getArrayFromDType(x.dtype, outSize);
- // Upsampling the input by fill in `dilation size - 1` values between each
- // input value.
- // This implementation follows the TF c++ implementation:
- // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
- for (let b = 0; b < batchSize; ++b) {
- for (let hOut = 0; hOut < outHeight; ++hOut) {
- const hBeg = hOut * strideHeight - padInfo.top;
- for (let wOut = 0; wOut < outWidth; ++wOut) {
- const wBeg = wOut * strideWidth - padInfo.left;
- for (let d = 0; d < inChannels; ++d) {
- let curVal = Number.MIN_SAFE_INTEGER;
- for (let h = 0; h < filterHeight; ++h) {
- const hIn = hBeg + h * dilationHeight;
- if (hIn >= 0 && hIn < inHeight) {
- for (let w = 0; w < filterWidth; ++w) {
- const wIn = wBeg + w * dilationWidth;
- if (wIn >= 0 && wIn < inWidth) {
- const xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
- const filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
- const val = xVals[xIndex] + filterVals[filterIndex];
- if (val > curVal) {
- curVal = val;
- }
- }
- }
- }
- }
- const outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
- outputVals[outputIndex] = curVal;
- }
- }
- }
- }
- const dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
- return { dataId, shape: outShape, dtype: x.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const dilation2dBackpropFilterConfig = {
- kernelName: Dilation2DBackpropFilter,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend, attrs }) => {
- const { x, filter, dy } = inputs;
- const { strides, pad, dilations } = attrs;
- const cpuBackend = backend;
- const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
- const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
- const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
- assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
- `must have the same rank as output ${outShape.length}, but got ` +
- `${dy.rank}`);
- const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
- // The computed filter gradients has the same dimensions as the filter:
- // [filterHeight, filterWidth, depth]
- const gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
- // In the case of multiple argmax branches, we only back-propagate along the
- // last branch, i.e., the one with largest value of `h * filter_cols + w`,
- // similarly to the max-pooling backward routines.
- // This implementation follows the TF c++ implementation:
- // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
- for (let b = 0; b < batchSize; ++b) {
- for (let hOut = 0; hOut < outHeight; ++hOut) {
- const hBeg = hOut * strideHeight - padInfo.top;
- for (let wOut = 0; wOut < outWidth; ++wOut) {
- const wBeg = wOut * strideWidth - padInfo.left;
- for (let d = 0; d < inChannels; ++d) {
- let curVal = Number.MIN_SAFE_INTEGER;
- let hMax = 0;
- let wMax = 0;
- for (let h = 0; h < filterHeight; ++h) {
- const hIn = hBeg + h * dilationHeight;
- if (hIn >= 0 && hIn < inHeight) {
- for (let w = 0; w < filterWidth; ++w) {
- const wIn = wBeg + w * dilationWidth;
- if (wIn >= 0 && wIn < inWidth) {
- const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
- if (val > curVal) {
- curVal = val;
- hMax = h;
- wMax = w;
- }
- }
- }
- }
- }
- gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
- }
- }
- }
- }
- const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
- return { dataId, shape: filter.shape, dtype: filter.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const dilation2dBackpropInputConfig = {
- kernelName: Dilation2DBackpropInput,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend, attrs }) => {
- const { x, filter, dy } = inputs;
- const { strides, pad, dilations } = attrs;
- const cpuBackend = backend;
- const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
- const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
- const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
- assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
- `must have the same rank as output ${outShape.length}, but got ` +
- `${dy.rank}`);
- const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
- // The computed gradients has the same dimensions as the input:
- // [batch, inputHeight, inputCols, inChannel]
- const gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
- // In the case of multiple argmax branches, we only back-propagate along the
- // last branch, i.e., the one with largest value of `h * filter_cols + w`,
- // similarly to the max-pooling backward routines.
- // This implementation follows the TF c++ implementation:
- // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
- for (let b = 0; b < batchSize; ++b) {
- for (let hOut = 0; hOut < outHeight; ++hOut) {
- const hBeg = hOut * strideHeight - padInfo.top;
- for (let wOut = 0; wOut < outWidth; ++wOut) {
- const wBeg = wOut * strideWidth - padInfo.left;
- for (let d = 0; d < inChannels; ++d) {
- let curVal = Number.MIN_SAFE_INTEGER;
- let hInMax = (hBeg < 0) ? 0 : hBeg;
- let wInMax = (wBeg < 0) ? 0 : wBeg;
- for (let h = 0; h < filterHeight; ++h) {
- const hIn = hBeg + h * dilationHeight;
- if (hIn >= 0 && hIn < inHeight) {
- for (let w = 0; w < filterWidth; ++w) {
- const wIn = wBeg + w * dilationWidth;
- if (wIn >= 0 && wIn < inWidth) {
- const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
- if (val > curVal) {
- curVal = val;
- hInMax = hIn;
- wInMax = wIn;
- }
- }
- }
- }
- }
- gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
- }
- }
- }
- }
- const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
- return { dataId, shape: x.shape, dtype: x.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const divImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
- const div$1 = binaryKernelFunc(Div, divImpl);
- const divConfig = {
- kernelName: Div,
- backendName: 'cpu',
- kernelFunc: div$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const p = ERF_P;
- const a1 = ERF_A1;
- const a2 = ERF_A2;
- const a3 = ERF_A3;
- const a4 = ERF_A4;
- const a5 = ERF_A5;
- const erf$1 = unaryKernelFunc(Erf, (xi) => {
- const sign = Math.sign(xi);
- const v = Math.abs(xi);
- const t = 1.0 / (1.0 + p * v);
- return sign *
- (1.0 -
- (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
- Math.exp(-v * v));
- });
- const erfConfig = {
- kernelName: Erf,
- backendName: 'cpu',
- kernelFunc: erf$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Calculate FFT of inner most elements of batch tensor.
- */
- function fftBatch(input, inverse, cpuBackend) {
- const inputShape = input.shape;
- const batch = inputShape[0];
- const innerDim = inputShape[1];
- const inputVals = cpuBackend.data.get(input.dataId);
- const real2D = inputVals.complexTensorInfos.real;
- const imag2D = inputVals.complexTensorInfos.imag;
- // Collects real and imaginary values separately.
- const resultShape = [batch, innerDim];
- const resultSize = sizeFromShape(resultShape);
- const resultReal = getTypedArrayFromDType('float32', resultSize);
- const resultImag = getTypedArrayFromDType('float32', resultSize);
- for (let b = 0; b < batch; b++) {
- // TODO: Support slice ops for complex type.
- const r = slice$1({
- inputs: { x: real2D },
- backend: cpuBackend,
- attrs: { begin: [b, 0], size: [1, innerDim] }
- });
- const i = slice$1({
- inputs: { x: imag2D },
- backend: cpuBackend,
- attrs: { begin: [b, 0], size: [1, innerDim] }
- });
- const input = complex$1({ inputs: { real: r, imag: i }, backend: cpuBackend });
- // Run FFT by batch element.
- const { real, imag } = fftImpl(input, inverse, cpuBackend);
- const res = mergeRealAndImagArrays(real, imag);
- for (let d = 0; d < innerDim; d++) {
- const c = getComplexWithIndex(res, d);
- resultReal[b * innerDim + d] = c.real;
- resultImag[b * innerDim + d] = c.imag;
- }
- cpuBackend.disposeIntermediateTensorInfo(r);
- cpuBackend.disposeIntermediateTensorInfo(i);
- cpuBackend.disposeIntermediateTensorInfo(input);
- }
- const $realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
- const $imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
- const result = complex$1({ inputs: { real: $realInfo, imag: $imagInfo }, backend: cpuBackend });
- cpuBackend.disposeIntermediateTensorInfo($realInfo);
- cpuBackend.disposeIntermediateTensorInfo($imagInfo);
- return result;
- }
- function fftImpl(input, inverse, cpuBackend) {
- const inputSize = sizeFromShape(input.shape);
- const inputVals = cpuBackend.data.get(input.dataId);
- const realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
- const imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
- if (isExponentOf2(inputSize)) {
- const result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
- const resultShape = [input.shape[0], input.shape[1]];
- if (inverse) {
- const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
- const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
- const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
- const sizeInfoCopy = identity$1({ inputs: { x: sizeInfo }, backend: cpuBackend });
- const divRealInfo = divConfig.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
- const divImagInfo = divConfig.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
- const divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
- const divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
- cpuBackend.disposeIntermediateTensorInfo(realInfo);
- cpuBackend.disposeIntermediateTensorInfo(imagInfo);
- cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
- cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
- cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
- cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
- return { real: divRealVals, imag: divImagVals };
- }
- return result;
- }
- else {
- const data = mergeRealAndImagArrays(realVals, imagVals);
- const rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
- return splitRealAndImagArrays(rawOutput);
- }
- }
- function isExponentOf2(size) {
- return (size & size - 1) === 0;
- }
- // FFT using Cooley-Tukey algorithm on radix 2 dimensional input.
- function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
- if (size === 1) {
- return { real: realVals, imag: imagVals };
- }
- const data = mergeRealAndImagArrays(realVals, imagVals);
- const half = size / 2;
- const evenComplex = complexWithEvenIndex(data);
- const evenRealVals = evenComplex.real;
- const evenImagVals = evenComplex.imag;
- const evenShape = [evenRealVals.length];
- const evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
- const evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
- const evenTensorInfo = complex$1({ inputs: { real: evenRealInfo, imag: evenImagInfo }, backend: cpuBackend });
- const oddComplex = complexWithOddIndex(data);
- const oddRealVals = oddComplex.real;
- const oddImagVals = oddComplex.imag;
- const oddShape = [oddRealVals.length];
- const oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
- const oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
- const oddTensorInfo = complex$1({ inputs: { real: oddRealInfo, imag: oddImagInfo }, backend: cpuBackend });
- // Recursive call for half part of original input.
- const $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
- const $evenRealVals = $evenComplex.real;
- const $evenImagVals = $evenComplex.imag;
- const $evenShape = [$evenRealVals.length];
- const $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
- const $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
- const $evenTensorInfo = complex$1({
- inputs: { real: $evenRealInfo, imag: $evenImagInfo },
- backend: cpuBackend
- });
- const $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
- const $oddRealVals = $oddComplex.real;
- const $oddImagVals = $oddComplex.imag;
- const $oddShape = [$oddRealVals.length];
- const $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
- const $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
- const $oddTensorInfo = complex$1({ inputs: { real: $oddRealInfo, imag: $oddImagInfo }, backend: cpuBackend });
- const e = exponents(size, inverse);
- const eShape = [e.real.length];
- const eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
- const eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
- const complexInfo = complex$1({ inputs: { real: eRealInfo, imag: eImagInfo }, backend: cpuBackend });
- const exponentInfo = multiply$2({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend });
- const addPart = add$4({
- inputs: { a: $evenTensorInfo, b: exponentInfo },
- backend: cpuBackend
- });
- const subPart = sub$1({
- inputs: { a: $evenTensorInfo, b: exponentInfo },
- backend: cpuBackend
- });
- const addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
- const subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
- const addPartImag = imag$1({ inputs: { input: addPart }, backend: cpuBackend });
- const subPartImag = imag$1({ inputs: { input: subPart }, backend: cpuBackend });
- const $real = concat$1({
- inputs: [addPartReal, subPartReal],
- backend: cpuBackend,
- attrs: { axis: 0 }
- });
- const $imag = concat$1({
- inputs: [addPartImag, subPartImag],
- backend: cpuBackend,
- attrs: { axis: 0 }
- });
- const $realVals = cpuBackend.data.get($real.dataId).values;
- const $imagVals = cpuBackend.data.get($imag.dataId).values;
- cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
- cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
- cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
- cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
- cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
- cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
- cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
- cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
- cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
- cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
- cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
- cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
- cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
- cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
- cpuBackend.disposeIntermediateTensorInfo(complexInfo);
- cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
- cpuBackend.disposeIntermediateTensorInfo(addPart);
- cpuBackend.disposeIntermediateTensorInfo(subPart);
- cpuBackend.disposeIntermediateTensorInfo(addPartReal);
- cpuBackend.disposeIntermediateTensorInfo(addPartImag);
- cpuBackend.disposeIntermediateTensorInfo(subPartReal);
- cpuBackend.disposeIntermediateTensorInfo(subPartImag);
- cpuBackend.disposeIntermediateTensorInfo($real);
- cpuBackend.disposeIntermediateTensorInfo($imag);
- return { real: $realVals, imag: $imagVals };
- }
- // Calculate fourier transform by multplying sinusoid matrix.
- function fourierTransformByMatmul(data, size, inverse) {
- const ret = new Float32Array(size * 2);
- // TODO: Use matmul instead once it supports complex64 type.
- for (let r = 0; r < size; r++) {
- let real = 0.0;
- let imag = 0.0;
- for (let c = 0; c < size; c++) {
- const e = exponent(r * c, size, inverse);
- const term = getComplexWithIndex(data, c);
- real += term.real * e.real - term.imag * e.imag;
- imag += term.real * e.imag + term.imag * e.real;
- }
- if (inverse) {
- real /= size;
- imag /= size;
- }
- assignToTypedArray(ret, real, imag, r);
- }
- return ret;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fft$1(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const inputSize = sizeFromShape(input.shape);
- // Collapse all outer dimensions to a single batch dimension.
- const innerDimensionSize = input.shape[input.shape.length - 1];
- const batch = inputSize / innerDimensionSize;
- const input2D = reshape$2({
- inputs: { x: input },
- backend,
- attrs: { shape: [batch, innerDimensionSize] }
- });
- const result = fftBatch(input2D, false, backend);
- const resultReshaped = reshape$2({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
- backend.disposeIntermediateTensorInfo(input2D);
- backend.disposeIntermediateTensorInfo(result);
- return resultReshaped;
- }
- const fftConfig = {
- kernelName: FFT,
- backendName: 'cpu',
- kernelFunc: fft$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const flipLeftRightConfig = {
- kernelName: FlipLeftRight,
- backendName: 'cpu',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { image } = inputs;
- const cpuBackend = backend;
- const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
- const [batch, imageHeight, imageWidth, numChannels] = image.shape;
- const imageVals = cpuBackend.data.get(image.dataId).values;
- for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
- const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
- for (let row = 0; row < imageHeight; row++) {
- const rowOffset = row * (imageWidth * numChannels);
- for (let col = 0; col < imageWidth; col++) {
- const colOffset = col * numChannels;
- for (let channel = 0; channel < numChannels; channel++) {
- const coords = [batch, row, col, channel];
- const x = coords[2];
- const coordX = Math.round(imageWidth - x);
- const outIdx = batchOffset + rowOffset + colOffset + channel;
- let outputValue = imageVals[outIdx];
- // If the coordinate position falls within the image boundaries...
- if (coordX >= 0 && coordX < imageWidth) {
- // set the output to the image value at the coordinate position.
- const rotatedColOffset = coordX * numChannels;
- const imageIdx = batchOffset + rowOffset + rotatedColOffset + channel;
- outputValue = imageVals[imageIdx];
- }
- output[outIdx] = outputValue;
- }
- }
- }
- }
- const dataId = cpuBackend.write(output, image.shape, image.dtype);
- return { dataId, shape: image.shape, dtype: image.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fusedConv2D(args) {
- const { inputs, backend, attrs } = args;
- const { x, filter, bias, preluActivationWeights } = inputs;
- const { strides, pad, dataFormat, dilations, dimRoundingMode, activation } = attrs;
- let result = conv2D({
- inputs: { x, filter },
- backend,
- attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
- });
- if (bias) {
- const resultOld = result;
- result = add$4({ inputs: { a: result, b: bias }, backend });
- backend.disposeIntermediateTensorInfo(resultOld);
- }
- if (activation) {
- const resultOld = result;
- result =
- applyActivation$1(backend, result, activation, preluActivationWeights);
- backend.disposeIntermediateTensorInfo(resultOld);
- }
- return result;
- }
- const fusedConv2DConfig = {
- kernelName: FusedConv2D,
- backendName: 'cpu',
- kernelFunc: fusedConv2D
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function ifft$1(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const inputSize = sizeFromShape(input.shape);
- // Collapse all outer dimensions to a single batch dimension.
- const innerDimensionSize = input.shape[input.shape.length - 1];
- const batch = inputSize / innerDimensionSize;
- const input2D = reshape$2({
- inputs: { x: input },
- backend,
- attrs: { shape: [batch, innerDimensionSize] }
- });
- const result = fftBatch(input2D, true, backend);
- const resultReshaped = reshape$2({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
- backend.disposeIntermediateTensorInfo(input2D);
- backend.disposeIntermediateTensorInfo(result);
- return resultReshaped;
- }
- const ifftConfig = {
- kernelName: IFFT,
- backendName: 'cpu',
- kernelFunc: ifft$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isFinite$2 = unaryKernelFunc(IsFinite, (xi) => Number.isFinite(xi) ? 1 : 0, 'bool');
- const isFiniteConfig = {
- kernelName: IsFinite,
- backendName: 'cpu',
- kernelFunc: isFinite$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isInf$1 = unaryKernelFunc(IsInf, (xi) => Math.abs(xi) === Infinity ? 1 : 0, 'bool');
- const isInfConfig = {
- kernelName: IsInf,
- backendName: 'cpu',
- kernelFunc: isInf$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const isNaN$2 = unaryKernelFunc(IsNan, (xi) => Number.isNaN(xi) ? 1 : 0, 'bool');
- const isNaNConfig = {
- kernelName: IsNan,
- backendName: 'cpu',
- kernelFunc: isNaN$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const log1p$1 = unaryKernelFunc(Log1p, (xi) => Math.log1p(xi));
- const log1pConfig = {
- kernelName: Log1p,
- backendName: 'cpu',
- kernelFunc: log1p$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const logicalNot$1 = unaryKernelFunc(LogicalNot, (xi) => xi ? 0 : 1, 'bool');
- const logicalNotConfig = {
- kernelName: LogicalNot,
- backendName: 'cpu',
- kernelFunc: logicalNot$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxConfig = {
- kernelName: Max,
- backendName: 'cpu',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { reductionIndices, keepDims } = attrs;
- const cpuBackend = backend;
- let xShape = x.shape;
- const xRank = xShape.length;
- const origAxes = parseAxisParam(reductionIndices, xShape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, xRank);
- let xVals = cpuBackend.data.get(x.dataId).values;
- if (permutedAxes != null) {
- const newShape = new Array(xRank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = xShape[permutedAxes[i]];
- }
- xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape);
- axes = getInnerMostAxes(axes.length, xRank);
- xShape = newShape;
- }
- assertNotComplex(x, 'max');
- assertAxesAreInnerMostDims('max', axes, xRank);
- const [maxOutShape, reduceShape] = computeOutAndReduceShapes(xShape, axes);
- const reduceSize = sizeFromShape(reduceShape);
- const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype);
- const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
- let outShape = maxOutShape;
- if (keepDims) {
- // reshape
- const newShape = expandShapeToKeepDim(maxOutShape, origAxes);
- outShape = newShape;
- }
- return { dataId, shape: outShape, dtype: x.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPool$1(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- assertNotComplex(x, 'maxPool');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const dilations = 1;
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- let res;
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- res = identity$1({ inputs: { x }, backend });
- }
- else {
- const xValues = backend.data.get(x.dataId).values;
- const strides = computeStrides(x.shape);
- const buffer = pool$1(xValues, x.shape, x.dtype, strides, convInfo, 'max');
- res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
- }
- return res;
- }
- const maxPoolConfig = {
- kernelName: MaxPool,
- backendName: 'cpu',
- kernelFunc: maxPool$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPoolBackprop$1(args) {
- const { inputs, backend, attrs } = args;
- const { dy, input, output } = inputs;
- const x = input;
- assertNotComplex([input, output], 'maxPoolBackprop');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
- const xValues = backend.data.get(x.dataId).values;
- const maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const dx = buffer(x.shape, 'float32');
- const dyData = backend.data.get(dy.dataId).values;
- const dyBuf = buffer(dy.shape, 'float32', dyData);
- for (let b = 0; b < convInfo.batchSize; ++b) {
- for (let d = 0; d < convInfo.inChannels; ++d) {
- for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
- for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
- // Shader code begins.
- const dyRCorner = dxR - padTop;
- const dyCCorner = dxC - padLeft;
- let dotProd = 0;
- for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
- const dyR = (dyRCorner + wR) / strideHeight;
- if (dyR < 0 || dyR >= convInfo.outHeight ||
- Math.floor(dyR) !== dyR) {
- continue;
- }
- for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
- const dyC = (dyCCorner + wC) / strideWidth;
- if (dyC < 0 || dyC >= convInfo.outWidth ||
- Math.floor(dyC) !== dyC) {
- continue;
- }
- const maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 -
- maxPosBuf.get(b, dyR, dyC, d);
- const curPos = wR * effectiveFilterWidth + wC;
- const mask = maxPos === curPos ? 1 : 0;
- if (mask === 0) {
- continue;
- }
- const pixel = dyBuf.get(b, dyR, dyC, d);
- dotProd += pixel * mask;
- }
- }
- dx.set(dotProd, b, dxR, dxC, d);
- }
- }
- }
- }
- return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
- }
- const maxPoolBackpropConfig = {
- kernelName: MaxPoolBackprop,
- backendName: 'cpu',
- kernelFunc: maxPoolBackprop$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
- const strides = computeStrides(xShape);
- const maxPools = pool$1(xValues, xShape, dtype, strides, convInfo, 'max');
- const maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
- return [maxPools.values, maxPositions.values];
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxPoolWithArgmaxConfig = {
- kernelName: MaxPoolWithArgmax,
- backendName: 'cpu',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { filterSize, strides, pad, includeBatchInIndex } = attrs;
- const cpuBackend = backend;
- assertNotComplex(x, 'MaxPoolWithArgmax');
- const values = cpuBackend.data.get(x.dataId).values;
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
- const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
- const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
- const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
- return [
- { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
- { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
- ];
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function mirrorPad$1(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { paddings, mode } = attrs;
- assertNotComplex(x, 'mirrorPad');
- const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
- const start = paddings.map(p => p[0]);
- const end = paddings.map((p, i) => p[0] + x.shape[i]);
- const offset = mode === 'reflect' ? 0 : 1;
- const xVals = backend.data.get(x.dataId).values;
- const xRank = x.shape.length;
- const xStrides = computeStrides(x.shape);
- const resultSize = sizeFromShape(outShape);
- const resultRank = outShape.length;
- const resultStrides = computeStrides(outShape);
- const resVals = getTypedArrayFromDType(x.dtype, resultSize);
- for (let i = 0; i < resultSize; i++) {
- let coords = indexToLoc(i, resultRank, resultStrides);
- for (let i = 0; i < resultRank; i++) {
- if (coords[i] < start[i]) {
- coords[i] = start[i] * 2 - coords[i] - offset;
- }
- else if (coords[i] >= end[i]) {
- coords[i] = (end[i] - 1) * 2 - coords[i] + offset;
- }
- }
- coords = coords.map((c, i) => c - start[i]);
- const inIndex = locToIndex(coords, xRank, xStrides);
- resVals[i] = xVals[inIndex];
- }
- const outId = backend.write(resVals, outShape, x.dtype);
- return { dataId: outId, shape: outShape, dtype: x.dtype };
- }
- const mirrorPadConfig = {
- kernelName: MirrorPad,
- backendName: 'cpu',
- kernelFunc: mirrorPad$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl;
- const nonMaxSuppressionV4Config = {
- kernelName: NonMaxSuppressionV4,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend, attrs }) => {
- const { boxes, scores } = inputs;
- const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
- const cpuBackend = backend;
- assertNotComplex(boxes, 'NonMaxSuppressionPadded');
- const boxesVals = cpuBackend.data.get(boxes.dataId).values;
- const scoresVals = cpuBackend.data.get(scores.dataId).values;
- const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
- return [selectedIndices, validOutputs];
- }
- };
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl;
- const nonMaxSuppressionV5Config = {
- kernelName: NonMaxSuppressionV5,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend, attrs }) => {
- const { boxes, scores } = inputs;
- const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
- const cpuBackend = backend;
- assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
- const boxesVals = cpuBackend.data.get(boxes.dataId).values;
- const scoresVals = cpuBackend.data.get(scores.dataId).values;
- const maxOutputSizeVal = maxOutputSize;
- const iouThresholdVal = iouThreshold;
- const scoreThresholdVal = scoreThreshold;
- const softNmsSigmaVal = softNmsSigma;
- const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$1(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
- return [selectedIndices, selectedScores];
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function padV2(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { paddings, constantValue } = attrs;
- assertNotComplex(x, 'pad');
- const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
- const start = paddings.map(p => p[0]);
- const xVals = backend.data.get(x.dataId).values;
- const xSize = sizeFromShape(x.shape);
- const xRank = x.shape.length;
- const xStrides = computeStrides(x.shape);
- const resultSize = sizeFromShape(outShape);
- const resultRank = outShape.length;
- const resultStrides = computeStrides(outShape);
- const resVals = getTypedArrayFromDType(x.dtype, resultSize);
- if (constantValue !== 0) {
- resVals.fill(constantValue);
- }
- for (let i = 0; i < xSize; i++) {
- const coords = indexToLoc(i, xRank, xStrides);
- const outCoords = coords.map((c, i) => c + start[i]);
- const outIndex = locToIndex(outCoords, resultRank, resultStrides);
- resVals[outIndex] = xVals[i];
- }
- const outId = backend.write(resVals, outShape, x.dtype);
- return { dataId: outId, shape: outShape, dtype: x.dtype };
- }
- const padV2Config = {
- kernelName: PadV2,
- backendName: 'cpu',
- kernelFunc: padV2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const reciprocal$1 = unaryKernelFunc(Reciprocal, (xi) => 1 / xi);
- const reciprocalConfig = {
- kernelName: Reciprocal,
- backendName: 'cpu',
- kernelFunc: reciprocal$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const rotateWithOffsetConfig = {
- kernelName: RotateWithOffset,
- backendName: 'cpu',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { image } = inputs;
- const { radians, fillValue, center } = attrs;
- const cpuBackend = backend;
- const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
- const [batch, imageHeight, imageWidth, numChannels] = image.shape;
- const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
- const fullOpacityValue = 255;
- const sinFactor = Math.sin(radians);
- const cosFactor = Math.cos(radians);
- const imageVals = cpuBackend.data.get(image.dataId).values;
- for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
- const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
- for (let row = 0; row < imageHeight; row++) {
- const rowOffset = row * (imageWidth * numChannels);
- for (let col = 0; col < imageWidth; col++) {
- const colOffset = col * numChannels;
- for (let channel = 0; channel < numChannels; channel++) {
- const coords = [batch, row, col, channel];
- const x = coords[2];
- const y = coords[1];
- // coordX/coordY are the result of rotating and translating x/y.
- let coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
- let coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
- coordX = Math.round(coordX + centerX);
- coordY = Math.round(coordY + centerY);
- let outputValue = fillValue;
- if (typeof fillValue !== 'number') {
- if (channel === 3) {
- outputValue = fullOpacityValue;
- }
- else {
- outputValue = fillValue[channel];
- }
- }
- // If the coordinate position falls within the image boundaries...
- if (coordX >= 0 && coordX < imageWidth && coordY >= 0 &&
- coordY < imageHeight) {
- // set the output to the image value at the coordinate position.
- const rotatedRowOffset = coordY * (imageWidth * numChannels);
- const rotatedColOffset = coordX * numChannels;
- const imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
- outputValue = imageVals[imageIdx];
- }
- const outIdx = batchOffset + rowOffset + colOffset + channel;
- output[outIdx] = outputValue;
- }
- }
- }
- }
- const dataId = cpuBackend.write(output, image.shape, image.dtype);
- return { dataId, shape: image.shape, dtype: image.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const round$1 = unaryKernelFunc(Round, (xi) => {
- // The algorithm is based on banker's rounding.
- const base = Math.floor(xi);
- if (xi - base < 0.5) {
- return Math.floor(xi);
- }
- else if (xi - base > 0.5) {
- return Math.ceil(xi);
- }
- else {
- if (base % 2.0 === 0.0) {
- return base;
- }
- else {
- return base + 1.0;
- }
- }
- });
- const roundConfig = {
- kernelName: Round,
- backendName: 'cpu',
- kernelFunc: round$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const scaleAlpha = SELU_SCALEALPHA;
- const scale = SELU_SCALE;
- const selu$1 = unaryKernelFunc(Selu, (xi) => {
- if (xi >= 0) {
- return scale * xi;
- }
- else {
- return scaleAlpha * (Math.exp(xi) - 1);
- }
- });
- const seluConfig = {
- kernelName: Selu,
- backendName: 'cpu',
- kernelFunc: selu$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sigmoid$1 = unaryKernelFunc(Sigmoid, (xi) => 1 / (1 + Math.exp(-xi)));
- const sigmoidConfig = {
- kernelName: Sigmoid,
- backendName: 'cpu',
- kernelFunc: sigmoid$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sign$2 = unaryKernelFunc(Sign, (xi) => {
- if (xi < 0) {
- return -1;
- }
- else if (xi > 0) {
- return 1;
- }
- else {
- return 0;
- }
- });
- const signConfig = {
- kernelName: Sign,
- backendName: 'cpu',
- kernelFunc: sign$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sin$1 = unaryKernelFunc(Sin, (xi) => Math.sin(xi));
- const sinConfig = {
- kernelName: Sin,
- backendName: 'cpu',
- kernelFunc: sin$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sinh$1 = unaryKernelFunc(Sinh, (xi) => Math.sinh(xi));
- const sinhConfig = {
- kernelName: Sinh,
- backendName: 'cpu',
- kernelFunc: sinh$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
- // epsilon is the difference between 1.0 and the next representable float.
- // For a single precision 32 bit float this should be 2^-23, see:
- // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
- const epsilon$1 = 1.1920928955078125e-7;
- const threshold = Math.log(epsilon$1) + 2.0;
- const softplus$1 = unaryKernelFunc(Softplus, (xi) => {
- // Value above which exp(x) may overflow, but softplus(x) == x
- // is within machine epsilon.
- const tooLarge = xi > -threshold;
- // Value below which exp(x) may underflow, but softplus(x) == exp(x)
- // is within machine epsilon.
- const tooSmall = xi < threshold;
- const expX = Math.exp(xi);
- let result;
- if (tooSmall) {
- result = expX;
- }
- else if (tooLarge) {
- result = xi;
- }
- else {
- result = Math.log(1.0 + expX);
- }
- return result;
- });
- const softplusConfig = {
- kernelName: Softplus,
- backendName: 'cpu',
- kernelFunc: softplus$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function transpose$1(args) {
- const { inputs, attrs, backend } = args;
- const { x } = inputs;
- const { perm } = attrs;
- assertNotComplex(x, 'transpose');
- const xRank = x.shape.length;
- const newShape = new Array(xRank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = x.shape[perm[i]];
- }
- const values = backend.data.get(x.dataId).values;
- const result = transposeImpl(values, x.shape, x.dtype, perm, newShape);
- const dataId = backend.write(result, newShape, x.dtype);
- return { dataId, shape: newShape, dtype: x.dtype };
- }
- const transposeConfig = {
- kernelName: Transpose,
- backendName: 'cpu',
- kernelFunc: transpose$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function spaceToBatchND$1(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { blockShape, paddings } = attrs;
- assertNotComplex([x], 'spaceToBatchND');
- const prod = sizeFromShape(blockShape);
- const completePaddings = [[0, 0]];
- completePaddings.push(...paddings);
- for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
- completePaddings.push([0, 0]);
- }
- const paddedX = padV2Config.kernelFunc({
- inputs: { x },
- backend,
- attrs: { paddings: completePaddings, constantValue: 0 }
- });
- const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
- const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
- const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
- const reshapeInputs = { x: paddedX };
- const reshapeAttrs = { shape: reshapedPaddedShape };
- const paddedXReshaped = reshape$2({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
- const transposeInputs = { x: paddedXReshaped };
- const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
- const paddedXT = transpose$1({ inputs: transposeInputs, backend, attrs: transposeAttrs });
- const resultReshapeInputs = { x: paddedXT };
- const resultReshapeAttrs = { shape: flattenShape };
- const result = reshape$2({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
- backend.disposeIntermediateTensorInfo(paddedX);
- backend.disposeIntermediateTensorInfo(paddedXReshaped);
- backend.disposeIntermediateTensorInfo(paddedXT);
- return result;
- }
- const spaceToBatchNDConfig = {
- kernelName: SpaceToBatchND,
- backendName: 'cpu',
- kernelFunc: spaceToBatchND$1
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const sqrt$1 = unaryKernelFunc(Sqrt, (xi) => Math.sqrt(xi));
- const sqrtConfig = {
- kernelName: Sqrt,
- backendName: 'cpu',
- kernelFunc: sqrt$1,
- };
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const squareConfig = {
- kernelName: Square,
- backendName: 'cpu',
- kernelFunc: ({ inputs, backend }) => {
- const { x } = inputs;
- const cpuBackend = backend;
- assertNotComplex(x, 'square');
- const values = cpuBackend.data.get(x.dataId).values;
- const newValues = new Float32Array(values.length);
- for (let i = 0; i < values.length; ++i) {
- const value = values[i];
- newValues[i] = value * value;
- }
- const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
- return { dataId, shape: x.shape, dtype: x.dtype };
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const step$1 = unaryKernelFunc(Step, (xi, attrs) => {
- const stepAttrs = attrs;
- if (isNaN(xi)) {
- return NaN;
- }
- else {
- return xi > 0 ? 1 : stepAttrs.alpha;
- }
- });
- const stepConfig = {
- kernelName: Step,
- backendName: 'cpu',
- kernelFunc: step$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const tan$1 = unaryKernelFunc(Tan, (xi) => Math.tan(xi));
- const tanConfig = {
- kernelName: Tan,
- backendName: 'cpu',
- kernelFunc: tan$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const tanh$2 = unaryKernelFunc(Tanh, (xi) => Math.tanh(xi));
- const tanhConfig = {
- kernelName: Tanh,
- backendName: 'cpu',
- kernelFunc: tanh$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function unique$2(args) {
- const { inputs, attrs, backend } = args;
- const { axis } = attrs;
- const { x } = inputs;
- assertNotComplex(x, 'unique');
- const values = backend.data.get(x.dataId).values;
- const { outputValues, outputShape, indices } = uniqueImpl(values, axis, x.shape, x.dtype);
- return [
- backend.makeTensorInfo(outputShape, x.dtype, outputValues),
- backend.makeTensorInfo([indices.length], 'int32', indices),
- ];
- }
- const uniqueConfig = {
- kernelName: Unique,
- backendName: 'cpu',
- kernelFunc: unique$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // List all kernel configs here
- const kernelConfigs = [
- _fusedMatMulConfig,
- fusedConv2DConfig,
- absConfig,
- acosConfig,
- acoshConfig,
- addConfig,
- asinConfig,
- asinhConfig,
- atanConfig,
- atanhConfig,
- avgPoolConfig,
- avgPoolBackpropConfig,
- batchMatMulConfig,
- batchNormConfig,
- castConfig,
- ceilConfig,
- clipConfig,
- complexConfig,
- concatConfig,
- conv2DConfig,
- cosConfig,
- coshConfig,
- dilation2dConfig,
- dilation2dBackpropInputConfig,
- dilation2dBackpropFilterConfig,
- divConfig,
- eluConfig,
- erfConfig,
- expConfig,
- expm1Config,
- fftConfig,
- flipLeftRightConfig,
- floorConfig,
- identityConfig,
- ifftConfig,
- imagConfig,
- isFiniteConfig,
- isInfConfig,
- isNaNConfig,
- logConfig,
- log1pConfig,
- logicalNotConfig,
- maxPoolConfig,
- maxPoolBackpropConfig,
- maxPoolWithArgmaxConfig,
- maxConfig,
- mirrorPadConfig,
- multiplyConfig,
- nonMaxSuppressionV4Config,
- nonMaxSuppressionV5Config,
- notEqualConfig,
- padV2Config,
- preluConfig,
- realConfig,
- reciprocalConfig,
- reluConfig,
- relu6Config,
- reshapeConfig,
- rotateWithOffsetConfig,
- roundConfig,
- rsqrtConfig,
- seluConfig,
- sigmoidConfig,
- signConfig,
- sinConfig,
- sinhConfig,
- sliceConfig,
- softplusConfig,
- spaceToBatchNDConfig,
- sqrtConfig,
- squareConfig,
- squaredDifferenceConfig,
- stepConfig,
- subConfig,
- tanConfig,
- tanhConfig,
- transposeConfig,
- uniqueConfig,
- ];
- for (const kernelConfig of kernelConfigs) {
- registerKernel(kernelConfig);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const contexts = {};
- const WEBGL_ATTRIBUTES = {
- alpha: false,
- antialias: false,
- premultipliedAlpha: false,
- preserveDrawingBuffer: false,
- depth: false,
- stencil: false,
- failIfMajorPerformanceCaveat: true
- };
- function clearWebGLContext(webGLVersion) {
- delete contexts[webGLVersion];
- }
- function setWebGLContext(webGLVersion, gl) {
- contexts[webGLVersion] = gl;
- }
- function getWebGLContext(webGLVersion) {
- if (!(webGLVersion in contexts)) {
- const newCtx = getWebGLRenderingContext(webGLVersion);
- if (newCtx !== null) {
- contexts[webGLVersion] = newCtx;
- }
- else {
- console.log('Could not get context for WebGL version', webGLVersion);
- return null;
- }
- }
- const gl = contexts[webGLVersion];
- if (gl.isContextLost()) {
- delete contexts[webGLVersion];
- return getWebGLContext(webGLVersion);
- }
- gl.disable(gl.DEPTH_TEST);
- gl.disable(gl.STENCIL_TEST);
- gl.disable(gl.BLEND);
- gl.disable(gl.DITHER);
- gl.disable(gl.POLYGON_OFFSET_FILL);
- gl.disable(gl.SAMPLE_COVERAGE);
- gl.enable(gl.SCISSOR_TEST);
- gl.enable(gl.CULL_FACE);
- gl.cullFace(gl.BACK);
- return contexts[webGLVersion];
- }
- function createCanvas(webGLVersion) {
- if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {
- return new OffscreenCanvas(300, 150);
- }
- else if (typeof document !== 'undefined') {
- return document.createElement('canvas');
- }
- else {
- throw new Error('Cannot create a canvas in this context');
- }
- }
- function getWebGLRenderingContext(webGLVersion) {
- if (webGLVersion !== 1 && webGLVersion !== 2) {
- throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
- }
- const canvas = createCanvas(webGLVersion);
- canvas.addEventListener('webglcontextlost', (ev) => {
- ev.preventDefault();
- delete contexts[webGLVersion];
- }, false);
- if (webGLVersion === 1) {
- return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
- canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES));
- }
- return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PackingScheme;
- (function (PackingScheme) {
- /**
- * All values in a single texel are densely packed without any constraints.
- *
- * This is how the shader encodes a tensor with shape = [2, 3, 4]
- * (indices are [batch, row, col]).
- *
- * 000|001 010|011 020|021
- * ------- ------- -------
- * 002|003 012|013 022|023
- *
- * 100|101 110|111 120|121
- * ------- ------- -------
- * 102|103 112|113 122|123
- *
- */
- PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
- /**
- * Single texels contain only values from the same batch, and from adjacent
- * rows and columns.
- *
- * This is how the shader encodes a tensor with shape = [2, 3, 5]
- * (indices are [batch, row, col]).
- *
- * 000|001 002|003 004|xxx 020|021 022|023 024|xxx
- * ------- ------- ------- ------- ------- -------
- * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
- *
- * 100|101 102|103 104|xxx 120|121 122|123 124|xxx
- * ------- ------- ------- ------- ------- -------
- * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
- *
- */
- PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
- })(PackingScheme || (PackingScheme = {}));
- var TextureUsage;
- (function (TextureUsage) {
- TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
- TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
- TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
- TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
- })(TextureUsage || (TextureUsage = {}));
- var PhysicalTextureType;
- (function (PhysicalTextureType) {
- PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
- PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
- PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
- PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
- PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
- })(PhysicalTextureType || (PhysicalTextureType = {}));
- function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
- return [columns, rows];
- }
- function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
- return matrixSize * channelsPerTexture;
- }
- function getColorMatrixTextureShapeWidthHeight(rows, columns) {
- return [columns * 4, rows];
- }
- /**
- * Get shape for densely packed RGBA texture.
- */
- function getDenseTexShape(shape) {
- const size = sizeFromShape(shape);
- const texelsNeeded = Math.ceil(size / 4);
- return sizeToSquarishShape(texelsNeeded);
- }
- function getMatrixSizeFromUnpackedArraySize(unpackedSize, channelsPerTexture) {
- if (unpackedSize % channelsPerTexture !== 0) {
- throw new Error(`unpackedSize (${unpackedSize}) must be a multiple of ` +
- `${channelsPerTexture}`);
- }
- return unpackedSize / channelsPerTexture;
- }
- function decodeMatrixFromUnpackedColorRGBAArray(unpackedArray, matrix, channels) {
- const requiredSize = unpackedArray.length * channels / 4;
- if (matrix.length < requiredSize) {
- throw new Error(`matrix length (${matrix.length}) must be >= ${requiredSize}`);
- }
- let dst = 0;
- for (let src = 0; src < unpackedArray.length; src += 4) {
- for (let c = 0; c < channels; c++) {
- matrix[dst++] = unpackedArray[src + c];
- }
- }
- }
- function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
- return [
- Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))
- ];
- }
- function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
- const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
- return w * h * 4;
- }
- function getTextureConfig(
- // tslint:disable-next-line:no-any
- gl, textureHalfFloatExtension) {
- // tslint:disable-next-line:no-any
- const glany = gl;
- let internalFormatFloat;
- let internalFormatHalfFloat;
- let internalFormatPackedHalfFloat;
- let internalFormatPackedFloat;
- let textureFormatFloat;
- let downloadTextureFormat;
- let downloadUnpackNumChannels;
- let defaultNumChannels;
- let textureTypeHalfFloat;
- let textureTypeFloat;
- if (env().getNumber('WEBGL_VERSION') === 2) {
- internalFormatFloat = glany.R32F;
- internalFormatHalfFloat = glany.R16F;
- internalFormatPackedHalfFloat = glany.RGBA16F;
- internalFormatPackedFloat = glany.RGBA32F;
- textureFormatFloat = glany.RED;
- downloadUnpackNumChannels = 4;
- defaultNumChannels = 1;
- textureTypeHalfFloat = glany.HALF_FLOAT;
- textureTypeFloat = glany.FLOAT;
- }
- else {
- internalFormatFloat = gl.RGBA;
- internalFormatHalfFloat = gl.RGBA;
- internalFormatPackedHalfFloat = gl.RGBA;
- internalFormatPackedFloat = glany.RGBA;
- textureFormatFloat = gl.RGBA;
- downloadUnpackNumChannels = 4;
- defaultNumChannels = 4;
- textureTypeHalfFloat = textureHalfFloatExtension != null ?
- textureHalfFloatExtension.HALF_FLOAT_OES :
- null;
- textureTypeFloat = gl.FLOAT;
- }
- downloadTextureFormat = gl.RGBA;
- return {
- internalFormatFloat,
- internalFormatHalfFloat,
- internalFormatPackedHalfFloat,
- internalFormatPackedFloat,
- textureFormatFloat,
- downloadTextureFormat,
- downloadUnpackNumChannels,
- defaultNumChannels,
- textureTypeHalfFloat,
- textureTypeFloat
- };
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function callAndCheck(gl, func) {
- const returnValue = func();
- if (env().getBool('DEBUG')) {
- checkWebGLError(gl);
- }
- return returnValue;
- }
- function checkWebGLError(gl) {
- const error = gl.getError();
- if (error !== gl.NO_ERROR) {
- throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
- }
- }
- // https://en.wikipedia.org/wiki/Half-precision_floating-point_format
- const MIN_FLOAT16 = 5.96e-8;
- const MAX_FLOAT16 = 65504;
- function canBeRepresented(num) {
- if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||
- (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {
- return true;
- }
- return false;
- }
- function getWebGLErrorMessage(gl, status) {
- switch (status) {
- case gl.NO_ERROR:
- return 'NO_ERROR';
- case gl.INVALID_ENUM:
- return 'INVALID_ENUM';
- case gl.INVALID_VALUE:
- return 'INVALID_VALUE';
- case gl.INVALID_OPERATION:
- return 'INVALID_OPERATION';
- case gl.INVALID_FRAMEBUFFER_OPERATION:
- return 'INVALID_FRAMEBUFFER_OPERATION';
- case gl.OUT_OF_MEMORY:
- return 'OUT_OF_MEMORY';
- case gl.CONTEXT_LOST_WEBGL:
- return 'CONTEXT_LOST_WEBGL';
- default:
- return `Unknown error code ${status}`;
- }
- }
- function getExtensionOrThrow(gl, extensionName) {
- return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.');
- }
- function createVertexShader(gl, vertexShaderSource) {
- const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.');
- callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource));
- callAndCheck(gl, () => gl.compileShader(vertexShader));
- if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
- console.log(gl.getShaderInfoLog(vertexShader));
- throw new Error('Failed to compile vertex shader.');
- }
- return vertexShader;
- }
- function createFragmentShader(gl, fragmentShaderSource) {
- const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.');
- callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource));
- callAndCheck(gl, () => gl.compileShader(fragmentShader));
- if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
- logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
- throw new Error('Failed to compile fragment shader.');
- }
- return fragmentShader;
- }
- const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
- function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
- const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
- if (lineNumberRegexResult == null) {
- console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);
- console.log(shaderSource);
- return;
- }
- const lineNumber = +lineNumberRegexResult[1];
- const shaderLines = shaderSource.split('\n');
- const pad = shaderLines.length.toString().length + 2;
- const linesWithLineNumbers = shaderLines.map((line, lineNumber) => rightPad((lineNumber + 1).toString(), pad) + line);
- let maxLineLength = 0;
- for (let i = 0; i < linesWithLineNumbers.length; i++) {
- maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
- }
- const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
- const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
- const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
- console.log(beforeErrorLines.join('\n'));
- console.log(shaderInfoLog.split('\n')[0]);
- console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
- console.log(afterErrorLines.join('\n'));
- }
- function createProgram(gl) {
- return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.');
- }
- function linkProgram(gl, program) {
- callAndCheck(gl, () => gl.linkProgram(program));
- if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
- console.log(gl.getProgramInfoLog(program));
- throw new Error('Failed to link vertex and fragment shaders.');
- }
- }
- function validateProgram(gl, program) {
- callAndCheck(gl, () => gl.validateProgram(program));
- if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
- console.log(gl.getProgramInfoLog(program));
- throw new Error('Shader program validation failed.');
- }
- }
- function createStaticVertexBuffer(gl, data) {
- const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
- callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
- callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));
- return buffer;
- }
- function createStaticIndexBuffer(gl, data) {
- const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
- callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));
- callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));
- return buffer;
- }
- function getNumChannels() {
- if (env().getNumber('WEBGL_VERSION') === 2) {
- return 1;
- }
- return 4;
- }
- function createTexture(gl) {
- return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.');
- }
- function validateTextureSize(width, height) {
- const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
- if ((width <= 0) || (height <= 0)) {
- const requested = `[${width}x${height}]`;
- throw new Error('Requested texture size ' + requested + ' is invalid.');
- }
- if ((width > maxTextureSize) || (height > maxTextureSize)) {
- const requested = `[${width}x${height}]`;
- const max = `[${maxTextureSize}x${maxTextureSize}]`;
- throw new Error('Requested texture size ' + requested +
- ' greater than WebGL maximum on this browser / GPU ' + max + '.');
- }
- }
- function createFramebuffer(gl) {
- return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.');
- }
- function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
- const loc = gl.getAttribLocation(program, attribute);
- if (loc === -1) {
- // The GPU compiler decided to strip out this attribute because it's unused,
- // thus no need to bind.
- return false;
- }
- callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
- callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
- callAndCheck(gl, () => gl.enableVertexAttribArray(loc));
- return true;
- }
- function bindTextureUnit(gl, texture, textureUnit) {
- validateTextureUnit(gl, textureUnit);
- callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
- }
- function unbindTextureUnit(gl, textureUnit) {
- validateTextureUnit(gl, textureUnit);
- callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
- }
- function getProgramUniformLocationOrThrow(gl, program, uniformName) {
- return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.');
- }
- function getProgramUniformLocation(gl, program, uniformName) {
- return gl.getUniformLocation(program, uniformName);
- }
- function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
- callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
- callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
- }
- function bindCanvasToFramebuffer(gl) {
- callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
- callAndCheck(gl, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height));
- callAndCheck(gl, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height));
- }
- function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
- callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
- callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
- }
- function unbindColorTextureFromFramebuffer(gl, framebuffer) {
- callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
- callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
- }
- function validateFramebuffer(gl) {
- const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
- if (status !== gl.FRAMEBUFFER_COMPLETE) {
- throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
- }
- }
- function getFramebufferErrorMessage(gl, status) {
- switch (status) {
- case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
- return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
- case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
- return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
- case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
- return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
- case gl.FRAMEBUFFER_UNSUPPORTED:
- return 'FRAMEBUFFER_UNSUPPORTED';
- default:
- return `unknown error ${status}`;
- }
- }
- function throwIfNull(gl, returnTOrNull, failureMessage) {
- const tOrNull = callAndCheck(gl, () => returnTOrNull());
- if (tOrNull == null) {
- throw new Error(failureMessage);
- }
- return tOrNull;
- }
- function validateTextureUnit(gl, textureUnit) {
- const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
- const glTextureUnit = textureUnit + gl.TEXTURE0;
- if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
- const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
- throw new Error(`textureUnit must be in ${textureUnitRange}.`);
- }
- }
- function getBatchDim(shape, dimsToSkip = 2) {
- return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
- }
- function getRowsCols(shape) {
- if (shape.length === 0) {
- throw Error('Cannot get rows and columns of an empty shape array.');
- }
- return [
- shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
- ];
- }
- function getShapeAs3D(shape) {
- let shapeAs3D = [1, 1, 1];
- const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
- if (!isScalar) {
- shapeAs3D =
- [getBatchDim(shape), ...getRowsCols(shape)];
- }
- return shapeAs3D;
- }
- function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
- let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
- if (isPacked) {
- maxTexSize = maxTexSize * 2;
- // This logic ensures we accurately count the number of packed texels needed
- // to accommodate the tensor. We can only pack values in the same texel if
- // they are from adjacent pairs of rows/cols within the same batch. So if a
- // tensor has 3 rows, we pretend it has 4 rows in order to account for the
- // fact that the texels containing the third row are half empty.
- logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
- nearestLargerEven(logShape[i]) :
- logShape[i]);
- // Packed texture height is at least 2 (the channel height of a single
- // texel).
- if (logShape.length === 1) {
- logShape = [2, logShape[0]];
- }
- }
- // If logical shape is 2, we don't squeeze, since we want to match physical.
- if (logShape.length !== 2) {
- const squeezeResult = squeezeShape(logShape);
- logShape = squeezeResult.newShape;
- }
- let size = sizeFromShape(logShape);
- if (logShape.length <= 1 && size <= maxTexSize) {
- return [1, size];
- }
- else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
- logShape[1] <= maxTexSize) {
- return logShape;
- }
- else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
- logShape[2] <= maxTexSize) {
- return [logShape[0] * logShape[1], logShape[2]];
- }
- else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
- logShape[1] * logShape[2] <= maxTexSize) {
- return [logShape[0], logShape[1] * logShape[2]];
- }
- else if (logShape.length === 4 &&
- logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
- logShape[3] <= maxTexSize) {
- return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
- }
- else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
- logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
- return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
- }
- else {
- if (isPacked) {
- // For packed textures size equals the number of channels required to
- // accommodate the texture data. However in order to squarify such that
- // inner dimensions stay even, we rewrite size to equal the number of
- // texels. Then in the return statement we rehydrate the squarified
- // dimensions to channel units.
- const batchDim = getBatchDim(logShape);
- let rows = 2, cols = 2;
- if (logShape.length) {
- [rows, cols] = getRowsCols(logShape);
- }
- size = batchDim * (rows / 2) * (cols / 2);
- return sizeToSquarishShape(size).map(d => d * 2);
- }
- return sizeToSquarishShape(size);
- }
- }
- function isEven(n) {
- return n % 2 === 0;
- }
- /**
- * This determines whether reshaping a packed texture requires rearranging
- * the data within the texture, assuming 2x2 packing.
- */
- function isReshapeFree(shape1, shape2) {
- shape1 = shape1.slice(-2);
- shape2 = shape2.slice(-2);
- if (arraysEqual(shape1, shape2)) {
- return true;
- }
- if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.
- return true;
- }
- if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
- shape2[1] === 0) {
- return true;
- }
- if (shape1.length !== shape2.length) { // One of the shapes is a vector.
- const shape1Cols = shape1.slice(-1)[0];
- const shape2Cols = shape2.slice(-1)[0];
- if (shape1Cols === shape2Cols) {
- return true;
- }
- if (isEven(shape1Cols) && isEven(shape2Cols) &&
- (shape1[0] === 1 || shape2[0] === 1)) {
- return true;
- }
- }
- return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
- }
- // We cache webgl params because the environment gets reset between
- // unit tests and we don't want to constantly query the WebGLContext for
- // MAX_TEXTURE_SIZE.
- let MAX_TEXTURE_SIZE;
- let MAX_TEXTURES_IN_SHADER;
- function getWebGLMaxTextureSize(webGLVersion) {
- if (MAX_TEXTURE_SIZE == null) {
- const gl = getWebGLContext(webGLVersion);
- MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
- }
- return MAX_TEXTURE_SIZE;
- }
- function resetMaxTextureSize() {
- MAX_TEXTURE_SIZE = null;
- }
- function resetMaxTexturesInShader() {
- MAX_TEXTURES_IN_SHADER = null;
- }
- function getMaxTexturesInShader(webGLVersion) {
- if (MAX_TEXTURES_IN_SHADER == null) {
- const gl = getWebGLContext(webGLVersion);
- MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
- }
- // We cap at 16 to avoid spurious runtime "memory exhausted" error.
- return Math.min(16, MAX_TEXTURES_IN_SHADER);
- }
- function getWebGLDisjointQueryTimerVersion(webGLVersion) {
- if (webGLVersion === 0) {
- return 0;
- }
- let queryTimerVersion;
- const gl = getWebGLContext(webGLVersion);
- if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
- webGLVersion === 2) {
- queryTimerVersion = 2;
- }
- else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
- queryTimerVersion = 1;
- }
- else {
- queryTimerVersion = 0;
- }
- return queryTimerVersion;
- }
- function hasExtension(gl, extensionName) {
- const ext = gl.getExtension(extensionName);
- return ext != null;
- }
- function isWebGLVersionEnabled(webGLVersion) {
- try {
- const gl = getWebGLContext(webGLVersion);
- if (gl != null) {
- return true;
- }
- }
- catch (e) {
- console.log('Error when getting WebGL context: ', e);
- return false;
- }
- return false;
- }
- function isCapableOfRenderingToFloatTexture(webGLVersion) {
- if (webGLVersion === 0) {
- return false;
- }
- const gl = getWebGLContext(webGLVersion);
- if (webGLVersion === 1) {
- if (!hasExtension(gl, 'OES_texture_float')) {
- return false;
- }
- }
- else {
- if (!hasExtension(gl, 'EXT_color_buffer_float')) {
- return false;
- }
- }
- const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
- return isFrameBufferComplete;
- }
- /**
- * Check if we can download values from a float/half-float texture.
- *
- * Note that for performance reasons we use binding a texture to a framebuffer
- * as a proxy for ability to download float values later using readPixels. The
- * texture params of this texture will not match those in readPixels exactly
- * but if we are unable to bind some kind of float texture to the frameBuffer
- * then we definitely will not be able to read float values from it.
- */
- function isDownloadFloatTextureEnabled(webGLVersion) {
- if (webGLVersion === 0) {
- return false;
- }
- const gl = getWebGLContext(webGLVersion);
- if (webGLVersion === 1) {
- if (!hasExtension(gl, 'OES_texture_float')) {
- return false;
- }
- if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
- return false;
- }
- }
- else {
- if (hasExtension(gl, 'EXT_color_buffer_float')) {
- return createFloatTextureAndBindToFramebuffer(gl);
- }
- const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
- if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
- const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
- return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
- }
- return false;
- }
- const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
- return isFrameBufferComplete;
- }
- function createFloatTextureAndBindToFramebuffer(gl) {
- const texConfig = getTextureConfig(gl);
- const texture = gl.createTexture();
- gl.bindTexture(gl.TEXTURE_2D, texture);
- const width = 1;
- const height = 1;
- gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
- const frameBuffer = gl.createFramebuffer();
- gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
- gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
- const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
- gl.bindTexture(gl.TEXTURE_2D, null);
- gl.bindFramebuffer(gl.FRAMEBUFFER, null);
- gl.deleteTexture(texture);
- gl.deleteFramebuffer(frameBuffer);
- return isFrameBufferComplete;
- }
- function createHalfFloatTextureAndBindToFramebuffer(
- // tslint:disable-next-line:no-any
- gl, textureHalfFloatExtension) {
- const texConfig = getTextureConfig(gl, textureHalfFloatExtension);
- const texture = gl.createTexture();
- gl.bindTexture(gl.TEXTURE_2D, texture);
- const width = 1;
- const height = 1;
- gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
- const frameBuffer = gl.createFramebuffer();
- gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
- gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
- const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
- gl.bindTexture(gl.TEXTURE_2D, null);
- gl.bindFramebuffer(gl.FRAMEBUFFER, null);
- gl.deleteTexture(texture);
- gl.deleteFramebuffer(frameBuffer);
- return isFrameBufferComplete;
- }
- function isWebGLFenceEnabled(webGLVersion) {
- if (webGLVersion !== 2) {
- return false;
- }
- const gl = getWebGLContext(webGLVersion);
- // tslint:disable-next-line:no-any
- const isEnabled = gl.fenceSync != null;
- return isEnabled;
- }
- function assertNotComplex$1(tensor, opName) {
- if (!Array.isArray(tensor)) {
- tensor = [tensor];
- }
- tensor.forEach(t => {
- if (t != null) {
- assert(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` +
- 'in the WebGL backend.');
- }
- });
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ENV$1 = env();
- /**
- * This file contains WebGL-specific flag registrations.
- */
- /**
- * True if WebGL is supported.
- */
- ENV$1.registerFlag('HAS_WEBGL', () => ENV$1.getNumber('WEBGL_VERSION') > 0);
- /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
- ENV$1.registerFlag('WEBGL_VERSION', () => {
- if (isWebGLVersionEnabled(2)) {
- return 2;
- }
- else if (isWebGLVersionEnabled(1)) {
- return 1;
- }
- return 0;
- });
- /** Whether to check for numerical representation problems. */
- ENV$1.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
- ENV$1.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV$1.get('WEBGL_VERSION') === 2);
- /** Whether the WebGL backend will sometimes forward ops to the CPU. */
- ENV$1.registerFlag('WEBGL_CPU_FORWARD', () => true);
- /** Whether the WebGL backend will always use f16 textures for rendering. */
- ENV$1.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
- /** Whether to turn all packing related flags on. */
- ENV$1.registerFlag('WEBGL_PACK', () => ENV$1.getBool('HAS_WEBGL'));
- /** Whether we will pack the batchnormalization op. */
- ENV$1.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack the clip op. */
- ENV$1.registerFlag('WEBGL_PACK_CLIP', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack the depthwise conv op. */
- // TODO: https://github.com/tensorflow/tfjs/issues/1679
- ENV$1.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => false);
- /** Whether we will pack binary ops. */
- ENV$1.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack unary ops. */
- ENV$1.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack array ops. */
- ENV$1.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack image ops. */
- ENV$1.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will pack reduce ops. */
- ENV$1.registerFlag('WEBGL_PACK_REDUCE', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether packed WebGL kernels lazily unpack their outputs. */
- ENV$1.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV$1.getBool('WEBGL_PACK'));
- /** Whether we will use the im2col algorithm to speed up convolutions. */
- ENV$1.registerFlag('WEBGL_CONV_IM2COL', () => ENV$1.getBool('WEBGL_PACK'));
- /** The maximum texture dimension. */
- ENV$1.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV$1.getNumber('WEBGL_VERSION')));
- /** The maximum texture dimension. */
- ENV$1.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV$1.getNumber('WEBGL_VERSION')));
- /**
- * The disjoint_query_timer extension version.
- * 0: disabled, 1: EXT_disjoint_timer_query, 2:
- * EXT_disjoint_timer_query_webgl2.
- * In Firefox with WebGL 2.0,
- * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
- * WebGL 1.0 extension.
- */
- ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
- const webGLVersion = ENV$1.getNumber('WEBGL_VERSION');
- if (webGLVersion === 0) {
- return 0;
- }
- return getWebGLDisjointQueryTimerVersion(webGLVersion);
- });
- /**
- * Whether the timer object from the disjoint_query_timer extension gives
- * timing information that is reliable.
- */
- ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV$1.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
- !isMobile());
- /**
- * Whether the device is physically capable of rendering to float32 textures.
- */
- ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV$1.getNumber('WEBGL_VERSION')));
- /**
- * Whether rendering to float32 textures is enabled. If disabled, renders to
- * float16 textures.
- */
- ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
- return ENV$1.getBool('WEBGL_FORCE_F16_TEXTURES') ?
- false :
- ENV$1.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
- });
- /**
- * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
- * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
- */
- ENV$1.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV$1.getNumber('WEBGL_VERSION')));
- /** Whether the fence API is available. */
- ENV$1.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV$1.getNumber('WEBGL_VERSION')));
- /**
- * Tensors with size <= than this will be uploaded as uniforms, not textures.
- */
- ENV$1.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
- // Use uniform uploads only when 32bit floats are supported. In
- // 16bit
- // environments there are problems with comparing a 16bit texture value
- // with a 32bit uniform value.
- const useUniforms = ENV$1.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
- return useUniforms ? 4 : 0;
- });
- /**
- * If the total number of bytes allocated on the GPU is greater than this
- * number, we will aggressively delete textures upon disposal with
- * gl.deleteMatrixTexture, rather than making them available for reuse.
- *
- * Default value -1 indicates that we will never aggressively delete textures.
- */
- ENV$1.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
- return -1;
- }, threshold => {
- if (threshold < 0 && threshold !== -1) {
- throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
- `delete) or at least 0, but got ${threshold}.`);
- }
- });
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const { simpleAbsImpl: simpleAbsImplCPU, addImpl: addImplCPU, ceilImpl: ceilImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, multiplyImpl: multiplyImplCPU, rsqrtImpl: rsqrtImplCPU, sliceImpl: sliceImplCPU, subImpl: subImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared;
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class AddNProgram {
- constructor(outputShape, shapes) {
- this.outputShape = [];
- this.outputShape = outputShape;
- this.variableNames = shapes.map((_, i) => `T${i}`);
- const snippets = [];
- // Get target elements from every input tensor.
- this.variableNames.forEach(variable => {
- snippets.push(`float v${variable} = get${variable}AtOutCoords();`);
- });
- // Calculate the sum of all elements.
- const operation = this.variableNames
- .map(variable => {
- return `v${variable}`;
- })
- .join(' + ');
- this.userCode = `
- void main() {
- ${snippets.join('\n ')}
-
- float result = ${operation};
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class AddNPackedProgram {
- constructor(outputShape, shapes) {
- this.outputShape = [];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- this.variableNames = shapes.map((_, i) => `T${i}`);
- const snippets = [];
- // Get target elements from every input tensor.
- this.variableNames.forEach(variable => {
- snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`);
- });
- // Calculate the sum of all elements.
- const operation = this.variableNames
- .map(variable => {
- return `v${variable}`;
- })
- .join(' + ');
- this.userCode = `
- void main() {
- ${snippets.join('\n ')}
-
- vec4 result = ${operation};
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ArgMinMaxProgram {
- constructor(reduceInfo, op, firstPass) {
- this.variableNames = ['A'];
- const { windowSize, batchSize, outSize } = reduceInfo;
- if (!firstPass) {
- this.variableNames.push('bestIndicesA');
- }
- this.outputShape = [batchSize, outSize];
- const compOp = (op === 'max') ? '>' : '<';
- const indexSnippet = firstPass ?
- 'inOffset + i;' :
- 'round(getBestIndicesA(batch, inOffset + i));';
- this.userCode = `
- void main() {
- ivec2 coords = getOutputCoords();
- int batch = coords[0];
- int outIdx = coords[1];
- int inOffset = outIdx * ${windowSize};
-
- int bestIndex = inOffset;
- float bestValue = getA(batch, bestIndex);
-
- for (int i = 0; i < ${windowSize}; i++) {
- int inIdx = ${indexSnippet};
- float candidate = getA(batch, inIdx);
- if (candidate ${compOp} bestValue) {
- bestValue = candidate;
- bestIndex = inIdx;
- }
- }
- setOutput(float(bestIndex));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function getVecChannels(name, rank) {
- return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
- }
- function getChannels(name, rank) {
- if (rank === 1) {
- return [name];
- }
- return getVecChannels(name, rank);
- }
- function getSourceCoords(rank, dims) {
- if (rank === 1) {
- return 'rc';
- }
- let coords = '';
- for (let i = 0; i < rank; i++) {
- coords += dims[i];
- if (i < rank - 1) {
- coords += ',';
- }
- }
- return coords;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function getGlslDifferences() {
- let version;
- let attribute;
- let varyingVs;
- let varyingFs;
- let texture2D;
- let output;
- let defineOutput;
- let defineSpecialNaN;
- let defineSpecialInf;
- let defineRound;
- if (env().getNumber('WEBGL_VERSION') === 2) {
- version = '#version 300 es';
- attribute = 'in';
- varyingVs = 'out';
- varyingFs = 'in';
- texture2D = 'texture';
- output = 'outputColor';
- defineOutput = 'out vec4 outputColor;';
- // Use custom isnan definition to work across differences between
- // implementations on various platforms. While this should happen in ANGLE
- // we still see differences between android and windows (on chrome) when
- // using isnan directly.
- defineSpecialNaN = `
- bool isnan_custom(float val) {
- return (val > 0.0 || val < 0.0) ? false : val != 0.0;
- }
-
- bvec4 isnan_custom(vec4 val) {
- return bvec4(isnan_custom(val.x),
- isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
- }
-
- #define isnan(value) isnan_custom(value)
- `;
- // In webgl 2 we do not need to specify a custom isinf so there is no
- // need for a special INFINITY constant.
- defineSpecialInf = ``;
- defineRound = `
- #define round(value) newRound(value)
- int newRound(float value) {
- return int(floor(value + 0.5));
- }
-
- ivec4 newRound(vec4 value) {
- return ivec4(floor(value + vec4(0.5)));
- }
- `;
- }
- else {
- version = '';
- attribute = 'attribute';
- varyingVs = 'varying';
- varyingFs = 'varying';
- texture2D = 'texture2D';
- output = 'gl_FragColor';
- defineOutput = '';
- // WebGL1 has no built in isnan so we define one here.
- defineSpecialNaN = `
- #define isnan(value) isnan_custom(value)
- bool isnan_custom(float val) {
- return (val > 0. || val < 1. || val == 0.) ? false : true;
- }
- bvec4 isnan_custom(vec4 val) {
- return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
- }
- `;
- defineSpecialInf = `
- uniform float INFINITY;
-
- bool isinf(float val) {
- return abs(val) == INFINITY;
- }
- bvec4 isinf(vec4 val) {
- return equal(abs(val), vec4(INFINITY));
- }
- `;
- defineRound = `
- int round(float value) {
- return int(floor(value + 0.5));
- }
-
- ivec4 round(vec4 value) {
- return ivec4(floor(value + vec4(0.5)));
- }
- `;
- }
- return {
- version,
- attribute,
- varyingVs,
- varyingFs,
- texture2D,
- output,
- defineOutput,
- defineSpecialNaN,
- defineSpecialInf,
- defineRound
- };
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Produces GLSL code that derives logical coordinates from a flat
- * index. The code performs integer division with each stride and decrements
- * the index until the index equals the final dimension coordinate.
- */
- function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') {
- const strides = computeStrides(shape);
- return strides
- .map((stride, i) => {
- const line1 = `int ${coords[i]} = ${index} / ${stride}`;
- const line2 = i === strides.length - 1 ?
- `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
- `index -= ${coords[i]} * ${stride}`;
- return `${line1}; ${line2};`;
- })
- .join('');
- }
- function buildVec(x) {
- if (x.length === 1) {
- return `${x[0]}`;
- }
- return `vec${x.length}(${x.join(',')})`;
- }
- /**
- * Produces GLSL code that computes the dot product of the input x and y
- * vectors. Handles splitting inputs into increments of vec4s when necessary.
- */
- function dotify(x, y) {
- if (x.length !== y.length) {
- throw new Error(`Vectors to be dotted must be of the same length -` +
- `got ${x.length} and ${y.length}`);
- }
- const slices = [];
- const nearestVec4 = Math.floor(x.length / 4);
- const nearestVec4Remainder = x.length % 4;
- for (let i = 0; i < nearestVec4; i++) {
- const xSlice = x.slice(i * 4, i * 4 + 4);
- const ySlice = y.slice(i * 4, i * 4 + 4);
- slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
- }
- if (nearestVec4Remainder !== 0) {
- let xSlice = x.slice(nearestVec4 * 4);
- let ySlice = y.slice(nearestVec4 * 4);
- if (xSlice.length === 1) {
- xSlice = xSlice.map(d => `float(${d})`);
- ySlice = ySlice.map(d => `float(${d})`);
- }
- slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
- }
- return slices.map((d, i) => `dot(${d})`).join('+');
- }
- /**
- * Produces GLSL that computes the flat index from 3D coordinates.
- */
- function getFlatIndexFrom3D(shape) {
- const strides = computeStrides(shape).map(d => d.toString());
- return `
- int getFlatIndex(ivec3 coords) {
- return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
- }
-`;
- }
- const ENCODE_FLOAT_SNIPPET = `
- const float FLOAT_MAX = 1.70141184e38;
- const float FLOAT_MIN = 1.17549435e-38;
-
- lowp vec4 encode_float(highp float v) {
- if (isnan(v)) {
- return vec4(255, 255, 255, 255);
- }
-
- highp float av = abs(v);
-
- if(av < FLOAT_MIN) {
- return vec4(0.0, 0.0, 0.0, 0.0);
- } else if(v > FLOAT_MAX) {
- return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
- } else if(v < -FLOAT_MAX) {
- return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
- }
-
- highp vec4 c = vec4(0,0,0,0);
-
- highp float e = floor(log2(av));
- highp float m = exp2(fract(log2(av))) - 1.0;
-
- c[2] = floor(128.0 * m);
- m -= c[2] / 128.0;
- c[1] = floor(32768.0 * m);
- m -= c[1] / 32768.0;
- c[0] = floor(8388608.0 * m);
-
- highp float ebias = e + 127.0;
- c[3] = floor(ebias / 2.0);
- ebias -= c[3] * 2.0;
- c[2] += floor(ebias) * 128.0;
-
- c[3] += 128.0 * step(0.0, -v);
-
- return c / 255.0;
- }
-`;
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const { getBroadcastDims: getBroadcastDims$1 } = backend_util;
- function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) {
- const prefixSnippets = [];
- inputsInfo.forEach(x => {
- const size = sizeFromShape(x.shapeInfo.logicalShape);
- // Snippet when we decided to upload the values as uniform.
- if (x.shapeInfo.isUniform) {
- prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);
- }
- else {
- prefixSnippets.push(`uniform sampler2D ${x.name};`);
- prefixSnippets.push(`uniform int offset${x.name};`);
- }
- });
- const inputPrefixSnippet = prefixSnippets.join('\n');
- const inputSamplingSnippet = inputsInfo
- .map(x => getInputSamplingSnippet(x, outputShape, usesPackedTextures))
- .join('\n');
- const outTexShape = outputShape.texShape;
- const glsl = getGlslDifferences();
- const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
- let outputSamplingSnippet;
- let floatTextureSetOutputSnippet;
- let shaderPrefix = getShaderPrefix(glsl);
- if (outputShape.isPacked) {
- outputSamplingSnippet =
- getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
- floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
- }
- else {
- outputSamplingSnippet =
- getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
- floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
- }
- if (usesPackedTextures) {
- shaderPrefix += SHADER_PACKED_PREFIX;
- }
- const source = [
- shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
- inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, userCode
- ].join('\n');
- return source;
- }
- function getSamplerFromInInfo(inInfo) {
- const shape = inInfo.shapeInfo.logicalShape;
- switch (shape.length) {
- case 0:
- return getSamplerScalar(inInfo);
- case 1:
- return getSampler1D(inInfo);
- case 2:
- return getSampler2D(inInfo);
- case 3:
- return getSampler3D(inInfo);
- case 4:
- return getSampler4D(inInfo);
- case 5:
- return getSampler5D(inInfo);
- case 6:
- return getSampler6D(inInfo);
- default:
- throw new Error(`${shape.length}-D input sampling` +
- ` is not yet supported`);
- }
- }
- function getPackedSamplerFromInInfo(inInfo) {
- const shape = inInfo.shapeInfo.logicalShape;
- switch (shape.length) {
- case 0:
- return getPackedSamplerScalar(inInfo);
- case 1:
- return getPackedSampler1D(inInfo);
- case 2:
- return getPackedSampler2D(inInfo);
- case 3:
- return getPackedSampler3D(inInfo);
- default:
- return getPackedSamplerND(inInfo);
- }
- }
- function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false) {
- let res = '';
- if (usesPackedTextures) {
- res += getPackedSamplerFromInInfo(inInfo);
- }
- else {
- res += getSamplerFromInInfo(inInfo);
- }
- const inShape = inInfo.shapeInfo.logicalShape;
- const outShape = outShapeInfo.logicalShape;
- if (inShape.length <= outShape.length) {
- if (usesPackedTextures) {
- res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
- }
- else {
- res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
- }
- }
- return res;
- }
- function getPackedOutputSamplingSnippet(outShape, outTexShape) {
- switch (outShape.length) {
- case 0:
- return getOutputScalarCoords();
- case 1:
- return getOutputPacked1DCoords(outShape, outTexShape);
- case 2:
- return getOutputPacked2DCoords(outShape, outTexShape);
- case 3:
- return getOutputPacked3DCoords(outShape, outTexShape);
- default:
- return getOutputPackedNDCoords(outShape, outTexShape);
- }
- }
- function getOutputSamplingSnippet(outShape, outTexShape) {
- switch (outShape.length) {
- case 0:
- return getOutputScalarCoords();
- case 1:
- return getOutput1DCoords(outShape, outTexShape);
- case 2:
- return getOutput2DCoords(outShape, outTexShape);
- case 3:
- return getOutput3DCoords(outShape, outTexShape);
- case 4:
- return getOutput4DCoords(outShape, outTexShape);
- case 5:
- return getOutput5DCoords(outShape, outTexShape);
- case 6:
- return getOutput6DCoords(outShape, outTexShape);
- default:
- throw new Error(`${outShape.length}-D output sampling is not yet supported`);
- }
- }
- function getFloatTextureSampleSnippet(glsl) {
- return `
- float sampleTexture(sampler2D textureSampler, vec2 uv) {
- return ${glsl.texture2D}(textureSampler, uv).r;
- }
- `;
- }
- function getFloatTextureSetRSnippet(glsl) {
- return `
- void setOutput(float val) {
- ${glsl.output} = vec4(val, 0, 0, 0);
- }
- `;
- }
- function getFloatTextureSetRGBASnippet(glsl) {
- return `
- void setOutput(vec4 val) {
- ${glsl.output} = val;
- }
- `;
- }
- function getShaderPrefix(glsl) {
- const SHADER_PREFIX = `${glsl.version}
- precision highp float;
- precision highp int;
- precision highp sampler2D;
- ${glsl.varyingFs} vec2 resultUV;
- ${glsl.defineOutput}
- const vec2 halfCR = vec2(0.5, 0.5);
-
- struct ivec5
- {
- int x;
- int y;
- int z;
- int w;
- int u;
- };
-
- struct ivec6
- {
- int x;
- int y;
- int z;
- int w;
- int u;
- int v;
- };
-
- uniform float NAN;
- ${glsl.defineSpecialNaN}
- ${glsl.defineSpecialInf}
- ${glsl.defineRound}
-
- int imod(int x, int y) {
- return x - y * (x / y);
- }
-
- int idiv(int a, int b, float sign) {
- int res = a / b;
- int mod = imod(a, b);
- if (sign < 0. && mod != 0) {
- res -= 1;
- }
- return res;
- }
-
- //Based on the work of Dave Hoskins
- //https://www.shadertoy.com/view/4djSRW
- #define HASHSCALE1 443.8975
- float random(float seed){
- vec2 p = resultUV * seed;
- vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
- p3 += dot(p3, p3.yzx + 19.19);
- return fract((p3.x + p3.y) * p3.z);
- }
-
- ${SAMPLE_1D_SNIPPET}
- ${SAMPLE_2D_SNIPPET}
- ${SAMPLE_3D_SNIPPET}
- `;
- return SHADER_PREFIX;
- }
- const SAMPLE_1D_SNIPPET = `
-vec2 uvFromFlat(int texNumR, int texNumC, int index) {
- int texR = index / texNumC;
- int texC = index - texR * texNumC;
- return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
-}
-vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
- int texelIndex = index / 2;
- int texR = texelIndex / texNumC;
- int texC = texelIndex - texR * texNumC;
- return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
-}
-`;
- const SAMPLE_2D_SNIPPET = `
-vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
- int texNumC, int row, int col) {
- int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
- int texR = texelIndex / texNumC;
- int texC = texelIndex - texR * texNumC;
- return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
-}
-`;
- const SAMPLE_3D_SNIPPET = `
-vec2 packedUVfrom3D(int texNumR, int texNumC,
- int texelsInBatch, int texelsInLogicalRow, int b,
- int row, int col) {
- int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
- int texR = index / texNumC;
- int texC = index - texR * texNumC;
- return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
-}
-`;
- const SHADER_PACKED_PREFIX = `
- float getChannel(vec4 frag, vec2 innerDims) {
- vec2 modCoord = mod(innerDims, 2.);
- return modCoord.x == 0. ?
- (modCoord.y == 0. ? frag.r : frag.g) :
- (modCoord.y == 0. ? frag.b : frag.a);
- }
- float getChannel(vec4 frag, int dim) {
- float modCoord = mod(float(dim), 2.);
- return modCoord == 0. ? frag.r : frag.g;
- }
-`;
- function getOutputScalarCoords() {
- return `
- int getOutputCoords() {
- return 0;
- }
- `;
- }
- function getOutputPacked1DCoords(shape, texShape) {
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (packedTexShape[0] === 1) {
- return `
- int getOutputCoords() {
- return 2 * int(resultUV.x * ${packedTexShape[1]}.0);
- }
- `;
- }
- if (packedTexShape[1] === 1) {
- return `
- int getOutputCoords() {
- return 2 * int(resultUV.y * ${packedTexShape[0]}.0);
- }
- `;
- }
- return `
- int getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
- return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
- }
- `;
- }
- function getOutput1DCoords(shape, texShape) {
- if (texShape[0] === 1) {
- return `
- int getOutputCoords() {
- return int(resultUV.x * ${texShape[1]}.0);
- }
- `;
- }
- if (texShape[1] === 1) {
- return `
- int getOutputCoords() {
- return int(resultUV.y * ${texShape[0]}.0);
- }
- `;
- }
- return `
- int getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- return resTexRC.x * ${texShape[1]} + resTexRC.y;
- }
- `;
- }
- function getOutputPacked3DCoords(shape, texShape) {
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- const texelsInLogicalRow = Math.ceil(shape[2] / 2);
- const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
- return `
- ivec3 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
- int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
-
- int b = index / ${texelsInBatch};
- index -= b * ${texelsInBatch};
-
- int r = 2 * (index / ${texelsInLogicalRow});
- int c = imod(index, ${texelsInLogicalRow}) * 2;
-
- return ivec3(b, r, c);
- }
- `;
- }
- function getOutput3DCoords(shape, texShape) {
- const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
- return `
- ivec3 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
- ${coordsFromIndexSnippet}
- return ivec3(r, c, d);
- }
- `;
- }
- function getOutputPackedNDCoords(shape, texShape) {
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
- const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
- let texelsInBatchN = texelsInBatch;
- let batches = ``;
- let coords = 'b, r, c';
- for (let b = 2; b < shape.length - 1; b++) {
- texelsInBatchN *= shape[shape.length - b - 1];
- batches = `
- int b${b} = index / ${texelsInBatchN};
- index -= b${b} * ${texelsInBatchN};
- ` + batches;
- coords = `b${b}, ` + coords;
- }
- return `
- ivec${shape.length} getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
- int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
-
- ${batches}
-
- int b = index / ${texelsInBatch};
- index -= b * ${texelsInBatch};
-
- int r = 2 * (index / ${texelsInLogicalRow});
- int c = imod(index, ${texelsInLogicalRow}) * 2;
-
- return ivec${shape.length}(${coords});
- }
- `;
- }
- function getOutput4DCoords(shape, texShape) {
- const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
- return `
- ivec4 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
- ${coordsFromIndexSnippet}
- return ivec4(r, c, d, d2);
- }
- `;
- }
- function getOutput5DCoords(shape, texShape) {
- const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
- return `
- ivec5 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
- ${texShape[1]}));
-
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
-
- ${coordsFromIndexSnippet}
-
- ivec5 outShape = ivec5(r, c, d, d2, d3);
- return outShape;
- }
- `;
- }
- function getOutput6DCoords(shape, texShape) {
- const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
- return `
- ivec6 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
-
- ${coordsFromIndexSnippet}
-
- ivec6 result = ivec6(r, c, d, d2, d3, d4);
- return result;
- }
- `;
- }
- function getOutputPacked2DCoords(shape, texShape) {
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (arraysEqual(shape, texShape)) {
- return `
- ivec2 getOutputCoords() {
- return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
- }
- `;
- }
- // texels needed to accommodate a logical row
- const texelsInLogicalRow = Math.ceil(shape[1] / 2);
- /**
- * getOutputCoords
- *
- * resTexRC: The rows and columns of the texels. If you move over one
- * texel to the right in the packed texture, you are moving over one column
- * (not two).
- *
- * index: The texel index
- */
- return `
- ivec2 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
-
- int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
- int r = 2 * (index / ${texelsInLogicalRow});
- int c = imod(index, ${texelsInLogicalRow}) * 2;
-
- return ivec2(r, c);
- }
- `;
- }
- function getOutput2DCoords(shape, texShape) {
- if (arraysEqual(shape, texShape)) {
- return `
- ivec2 getOutputCoords() {
- return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
- }
- `;
- }
- if (shape[1] === 1) {
- return `
- ivec2 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
- return ivec2(index, 0);
- }
- `;
- }
- if (shape[0] === 1) {
- return `
- ivec2 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
- return ivec2(0, index);
- }
- `;
- }
- return `
- ivec2 getOutputCoords() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
- int r = index / ${shape[1]};
- int c = index - r * ${shape[1]};
- return ivec2(r, c);
- }
- `;
- }
- function getFlatOffsetUniformName(texName) {
- return `offset${texName}`;
- }
- function getPackedSamplerScalar(inputInfo) {
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const glsl = getGlslDifferences();
- return `
- vec4 ${funcName}() {
- return ${glsl.texture2D}(${texName}, halfCR);
- }
- `;
- }
- function getSamplerScalar(inputInfo) {
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- if (inputInfo.shapeInfo.isUniform) {
- return `float ${funcName}() {return ${texName};}`;
- }
- const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;
- if (texNumR === 1 && texNumC === 1) {
- return `
- float ${funcName}() {
- return sampleTexture(${texName}, halfCR);
- }
- `;
- }
- const [tNumR, tNumC] = inputInfo.shapeInfo.texShape;
- const offset = getFlatOffsetUniformName(texName);
- return `
- float ${funcName}() {
- vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getPackedSampler1D(inputInfo) {
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const texShape = inputInfo.shapeInfo.texShape;
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- const glsl = getGlslDifferences();
- return `
- vec4 ${funcName}(int index) {
- vec2 uv = packedUVfrom1D(
- ${packedTexShape[0]}, ${packedTexShape[1]}, index);
- return ${glsl.texture2D}(${texName}, uv);
- }
- `;
- }
- function getSampler1D(inputInfo) {
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int index) {
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const texShape = inputInfo.shapeInfo.texShape;
- const tNumR = texShape[0];
- const tNumC = texShape[1];
- if (tNumC === 1 && tNumR === 1) {
- return `
- float ${funcName}(int index) {
- return sampleTexture(${texName}, halfCR);
- }
- `;
- }
- const offset = getFlatOffsetUniformName(texName);
- if (tNumC === 1) {
- return `
- float ${funcName}(int index) {
- vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (tNumR === 1) {
- return `
- float ${funcName}(int index) {
- vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- return `
- float ${funcName}(int index) {
- vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getPackedSampler2D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const texShape = inputInfo.shapeInfo.texShape;
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- const glsl = getGlslDifferences();
- if (texShape != null && arraysEqual(shape, texShape)) {
- return `
- vec4 ${funcName}(int row, int col) {
- vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
-
- return ${glsl.texture2D}(${texName}, uv);
- }
- `;
- }
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- const valuesPerRow = Math.ceil(shape[1] / 2);
- return `
- vec4 ${funcName}(int row, int col) {
- vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
- return ${glsl.texture2D}(${texName}, uv);
- }
- `;
- }
- function getSampler2D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const texShape = inputInfo.shapeInfo.texShape;
- if (texShape != null && arraysEqual(shape, texShape)) {
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- return `
- float ${funcName}(int row, int col) {
- vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- const { newShape, keptDims } = squeezeShape(shape);
- const squeezedShape = newShape;
- if (squeezedShape.length < shape.length) {
- const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- const params = ['row', 'col'];
- return `
- ${getSamplerFromInInfo(newInputInfo)}
- float ${funcName}(int row, int col) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col) {
- int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- const offset = getFlatOffsetUniformName(texName);
- if (texNumC === 1) {
- // index is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col) {
- float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
- vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (texNumR === 1) {
- // index is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col) {
- float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
- vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- return `
- float ${funcName}(int row, int col) {
- // Explicitly use integer operations as dot() only works on floats.
- int index = row * ${shape[1]} + col + ${offset};
- vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
- return sampleTexture(${texName}, uv);
- }
-`;
- }
- function getPackedSampler3D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const texShape = inputInfo.shapeInfo.texShape;
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (shape[0] === 1) {
- const squeezedShape = shape.slice(1);
- const keptDims = [1, 2];
- const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- const params = ['b', 'row', 'col'];
- return `
- ${getPackedSamplerFromInInfo(newInputInfo)}
- vec4 ${funcName}(int b, int row, int col) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- const texNumR = packedTexShape[0];
- const texNumC = packedTexShape[1];
- const valuesPerRow = Math.ceil(shape[2] / 2);
- const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
- const glsl = getGlslDifferences();
- return `
- vec4 ${funcName}(int b, int row, int col) {
- vec2 uv = packedUVfrom3D(
- ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
- return ${glsl.texture2D}(${texName}, uv);
- }
- `;
- }
- function getSampler3D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const stride0 = shape[1] * shape[2];
- const stride1 = shape[2];
- const { newShape, keptDims } = squeezeShape(shape);
- const squeezedShape = newShape;
- if (squeezedShape.length < shape.length) {
- const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- const params = ['row', 'col', 'depth'];
- return `
- ${getSamplerFromInInfo(newInputInfo)}
- float ${funcName}(int row, int col, int depth) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth) {
- int index = round(dot(vec3(row, col, depth),
- vec3(${stride0}, ${stride1}, 1)));
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const texShape = inputInfo.shapeInfo.texShape;
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- const flatOffset = inputInfo.shapeInfo.flatOffset;
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth) {
- float texR = float(row);
- float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (texNumC === stride1 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth) {
- float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
- float texC = float(depth);
- vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- const offset = getFlatOffsetUniformName(texName);
- return `
- float ${funcName}(int row, int col, int depth) {
- // Explicitly use integer operations as dot() only works on floats.
- int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
- vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getPackedSamplerND(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const rank = shape.length;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const texShape = inputInfo.shapeInfo.texShape;
- const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- const texNumR = packedTexShape[0];
- const texNumC = packedTexShape[1];
- const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
- let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
- let params = `int b, int row, int col`;
- let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
- for (let b = 2; b < rank - 1; b++) {
- params = `int b${b}, ` + params;
- texelsInBatch *= shape[rank - b - 1];
- index = `b${b} * ${texelsInBatch} + ` + index;
- }
- const glsl = getGlslDifferences();
- return `
- vec4 ${funcName}(${params}) {
- int index = ${index};
- int texR = index / ${texNumC};
- int texC = index - texR * ${texNumC};
- vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
- return ${glsl.texture2D}(${texName}, uv);
- }
- `;
- }
- function getSampler4D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const stride2 = shape[3];
- const stride1 = shape[2] * stride2;
- const stride0 = shape[1] * stride1;
- const { newShape, keptDims } = squeezeShape(shape);
- if (newShape.length < shape.length) {
- const newInputInfo = squeezeInputInfo(inputInfo, newShape);
- const params = ['row', 'col', 'depth', 'depth2'];
- return `
- ${getSamplerFromInInfo(newInputInfo)}
- float ${funcName}(int row, int col, int depth, int depth2) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2) {
- int index = round(dot(vec4(row, col, depth, depth2),
- vec4(${stride0}, ${stride1}, ${stride2}, 1)));
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const flatOffset = inputInfo.shapeInfo.flatOffset;
- const texShape = inputInfo.shapeInfo.texShape;
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2) {
- float texR = float(row);
- float texC =
- dot(vec3(col, depth, depth2),
- vec3(${stride1}, ${stride2}, 1));
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (texNumC === stride2 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2) {
- float texR = dot(vec3(row, col, depth),
- vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));
- float texC = float(depth2);
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- const offset = getFlatOffsetUniformName(texName);
- return `
- float ${funcName}(int row, int col, int depth, int depth2) {
- // Explicitly use integer operations as dot() only works on floats.
- int index = row * ${stride0} + col * ${stride1} +
- depth * ${stride2} + depth2;
- vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getSampler5D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const stride3 = shape[4];
- const stride2 = shape[3] * stride3;
- const stride1 = shape[2] * stride2;
- const stride0 = shape[1] * stride1;
- const { newShape, keptDims } = squeezeShape(shape);
- if (newShape.length < shape.length) {
- const newInputInfo = squeezeInputInfo(inputInfo, newShape);
- const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
- return `
- ${getSamplerFromInInfo(newInputInfo)}
- float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
- float index = dot(
- vec4(row, col, depth, depth2),
- vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
- depth3;
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const flatOffset = inputInfo.shapeInfo.flatOffset;
- const texShape = inputInfo.shapeInfo.texShape;
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
- int texR = row;
- float texC = dot(vec4(col, depth, depth2, depth3),
- vec4(${stride1}, ${stride2}, ${stride3}, 1));
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (texNumC === stride3 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
- float texR = dot(
- vec4(row, col, depth, depth2),
- vec4(${shape[1] * shape[2] * shape[3]},
- ${shape[2] * shape[3]}, ${shape[3]}, 1));
- int texC = depth3;
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- const offset = getFlatOffsetUniformName(texName);
- return `
- float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
- // Explicitly use integer operations as dot() only works on floats.
- int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
- depth2 * ${stride3} + depth3 + ${offset};
- vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getSampler6D(inputInfo) {
- const shape = inputInfo.shapeInfo.logicalShape;
- const texName = inputInfo.name;
- const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- const { newShape, keptDims } = squeezeShape(shape);
- if (newShape.length < shape.length) {
- const newInputInfo = squeezeInputInfo(inputInfo, newShape);
- const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
- return `
- ${getSamplerFromInInfo(newInputInfo)}
- float ${funcName}(int row, int col, int depth,
- int depth2, int depth3, int depth4) {
- return ${funcName}(${getSqueezedParams(params, keptDims)});
- }
- `;
- }
- const stride4 = shape[5];
- const stride3 = shape[4] * stride4;
- const stride2 = shape[3] * stride3;
- const stride1 = shape[2] * stride2;
- const stride0 = shape[1] * stride1;
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth,
- int depth2, int depth3, int depth4) {
- int index = round(dot(
- vec4(row, col, depth, depth2),
- vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
- dot(
- vec2(depth3, depth4),
- vec2(${stride4}, 1)));
- ${getUniformSampler(inputInfo)}
- }
- `;
- }
- const flatOffset = inputInfo.shapeInfo.flatOffset;
- const texShape = inputInfo.shapeInfo.texShape;
- const texNumR = texShape[0];
- const texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth,
- int depth2, int depth3, int depth4) {
- int texR = row;
- float texC = dot(vec4(col, depth, depth2, depth3),
- vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
- float(depth4);
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- if (texNumC === stride4 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return `
- float ${funcName}(int row, int col, int depth,
- int depth2, int depth3, int depth4) {
- float texR = dot(vec4(row, col, depth, depth2),
- vec4(${shape[1] * shape[2] * shape[3] * shape[4]},
- ${shape[2] * shape[3] * shape[4]},
- ${shape[3] * shape[4]},
- ${shape[4]})) + float(depth3);
- int texC = depth4;
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${texNumC}.0, ${texNumR}.0);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- const offset = getFlatOffsetUniformName(texName);
- return `
- float ${funcName}(int row, int col, int depth,
- int depth2, int depth3, int depth4) {
- // Explicitly use integer operations as dot() only works on floats.
- int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
- depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
- vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
- return sampleTexture(${texName}, uv);
- }
- `;
- }
- function getUniformSampler(inputInfo) {
- const texName = inputInfo.name;
- const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
- if (inSize < 2) {
- return `return ${texName};`;
- }
- return `
- for (int i = 0; i < ${inSize}; i++) {
- if (i == index) {
- return ${texName}[i];
- }
- }
- `;
- }
- function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
- const texName = inputInfo.name;
- const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
- const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
- const inRank = inputInfo.shapeInfo.logicalShape.length;
- const outRank = outShapeInfo.logicalShape.length;
- const broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
- const type = getCoordsDataType(outRank);
- const rankDiff = outRank - inRank;
- let coordsSnippet;
- const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
- if (inRank === 0) {
- coordsSnippet = '';
- }
- else if (outRank < 2 && broadcastDims.length >= 1) {
- coordsSnippet = 'coords = 0;';
- }
- else {
- coordsSnippet =
- broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
- .join('\n');
- }
- let unpackedCoordsSnippet = '';
- if (outRank < 2 && inRank > 0) {
- unpackedCoordsSnippet = 'coords';
- }
- else {
- unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
- .map((s, i) => `coords.${fields[i + rankDiff]}`)
- .join(', ');
- }
- let output = `return outputValue;`;
- const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
- const isInputScalar = inSize === 1;
- const outSize = sizeFromShape(outShapeInfo.logicalShape);
- const isOutputScalar = outSize === 1;
- if (inRank === 1 && !isInputScalar && !isOutputScalar) {
- output = `
- return vec4(outputValue.xy, outputValue.xy);
- `;
- }
- else if (isInputScalar && !isOutputScalar) {
- if (outRank === 1) {
- output = `
- return vec4(outputValue.x, outputValue.x, 0., 0.);
- `;
- }
- else {
- output = `
- return vec4(outputValue.x);
- `;
- }
- }
- else if (broadcastDims.length) {
- const rows = inRank - 2;
- const cols = inRank - 1;
- if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
- output = `return vec4(outputValue.x);`;
- }
- else if (broadcastDims.indexOf(rows) > -1) {
- output = `return vec4(outputValue.x, outputValue.y, ` +
- `outputValue.x, outputValue.y);`;
- }
- else if (broadcastDims.indexOf(cols) > -1) {
- output = `return vec4(outputValue.xx, outputValue.zz);`;
- }
- }
- return `
- vec4 ${funcName}() {
- ${type} coords = getOutputCoords();
- ${coordsSnippet}
- vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
- ${output}
- }
- `;
- }
- function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
- const texName = inputInfo.name;
- const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
- const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
- const outTexShape = outShapeInfo.texShape;
- const inTexShape = inputInfo.shapeInfo.texShape;
- const inRank = inputInfo.shapeInfo.logicalShape.length;
- const outRank = outShapeInfo.logicalShape.length;
- if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
- inputInfo.shapeInfo.flatOffset == null &&
- arraysEqual(inTexShape, outTexShape)) {
- return `
- float ${funcName}() {
- return sampleTexture(${texName}, resultUV);
- }
- `;
- }
- const type = getCoordsDataType(outRank);
- const broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
- const rankDiff = outRank - inRank;
- let coordsSnippet;
- const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
- if (inRank === 0) {
- coordsSnippet = '';
- }
- else if (outRank < 2 && broadcastDims.length >= 1) {
- coordsSnippet = 'coords = 0;';
- }
- else {
- coordsSnippet =
- broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
- .join('\n');
- }
- let unpackedCoordsSnippet = '';
- if (outRank < 2 && inRank > 0) {
- unpackedCoordsSnippet = 'coords';
- }
- else {
- unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
- .map((s, i) => `coords.${fields[i + rankDiff]}`)
- .join(', ');
- }
- return `
- float ${funcName}() {
- ${type} coords = getOutputCoords();
- ${coordsSnippet}
- return get${texFuncSnippet}(${unpackedCoordsSnippet});
- }
- `;
- }
- function getCoordsDataType(rank) {
- if (rank <= 1) {
- return 'int';
- }
- else if (rank === 2) {
- return 'ivec2';
- }
- else if (rank === 3) {
- return 'ivec3';
- }
- else if (rank === 4) {
- return 'ivec4';
- }
- else if (rank === 5) {
- return 'ivec5';
- }
- else if (rank === 6) {
- return 'ivec6';
- }
- else {
- throw Error(`GPU for rank ${rank} is not yet supported`);
- }
- }
- /** Returns a new input info (a copy) that has a squeezed logical shape. */
- function squeezeInputInfo(inInfo, squeezedShape) {
- // Deep copy.
- const newInputInfo = JSON.parse(JSON.stringify(inInfo));
- newInputInfo.shapeInfo.logicalShape = squeezedShape;
- return newInputInfo;
- }
- function getSqueezedParams(params, keptDims) {
- return keptDims.map(d => params[d]).join(', ');
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ArgMinMaxPackedProgram {
- constructor(shape, windowSize, op, firstPass) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- assert(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() +
- op.slice(1)} supports only inputs with rank above 2.`);
- const inSize = shape[shape.length - 1];
- const outSize = Math.ceil(inSize / windowSize);
- this.outputShape = shape.slice(0, -1);
- if (outSize > 1) {
- this.outputShape.push(outSize);
- }
- if (!firstPass) {
- this.variableNames.push('bestIndicesA');
- }
- const outShape = this.outputShape;
- const rank = outShape.length;
- const dtype = getCoordsDataType(rank);
- const coords = getChannels('coords', rank);
- let sourceLocSetup;
- let sourceRank;
- if (outSize === 1) {
- sourceRank = rank + 1;
- const sourceLocDType = getCoordsDataType(sourceRank);
- sourceLocSetup = `
- ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
- ++${coords[rank - 1]};
- ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
- ++${coords[rank - 2]};
- ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
- --${coords[rank - 1]};
- ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
- --${coords[rank - 2]};`;
- }
- else {
- sourceRank = rank;
- sourceLocSetup = `
- ${dtype} sourceLocR = coords;
- ++${coords[rank - 1]};
- ${dtype} sourceLocG = coords;
- ++${coords[rank - 2]};
- ${dtype} sourceLocA = coords;
- --${coords[rank - 1]};
- ${dtype} sourceLocB = coords;
- --${coords[rank - 2]};`;
- }
- const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
- const inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
- const intChannels = channels.map(x => 'int ' + x);
- const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
- const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
- const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
- const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
- const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
- const fetchCandidateIdx = firstPass ? '' : `
- inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
- getBestIndicesAChannel(${srcGCoords.join()}),
- getBestIndicesAChannel(${srcBCoords.join()}),
- getBestIndicesAChannel(${srcACoords.join()})));`;
- const fetchValue = `vec4(
- getAChannel(${srcRCoords.join()}),
- hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
- hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
- hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
- const getBestIndicesAChannelSnippet = firstPass ? '' : `
- float getBestIndicesAChannel(${intChannels.join()}) {
- return getChannel(getBestIndicesA(${channels.join()}),
- vec2(${channels.slice(-2).join()}));
- }`;
- this.userCode = `
- float getAChannel(${intChannels.join()}) {
- return getChannel(getA(${channels.join()}),
- vec2(${channels.slice(-2).join()}));
- }
- ${getBestIndicesAChannelSnippet}
- void main() {
- ${dtype} coords = getOutputCoords();
- bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
- bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
- ${sourceLocSetup}
- ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
- sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
- ivec4 inIdx = srcIdx;
- vec4 bestIndex = vec4(inIdx);
- vec4 bestValue = ${fetchValue};
-
- for (int i = 0; i < ${windowSize}; i++) {
- inIdx = srcIdx;
- ${fetchCandidateIdx}
- vec4 candidate = ${fetchValue};
- bvec4 nan = isnan(candidate);
- bvec4 replace = bvec4(
- vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));
-
- bestValue = vec4(replace.x ? candidate.x : bestValue.x,
- replace.y ? candidate.y : bestValue.y,
- replace.z ? candidate.z : bestValue.z,
- replace.w ? candidate.w : bestValue.w);
- bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
- srcIdx++;
- }
- setOutput(bestIndex);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class AvgPool2DBackpropProgram {
- constructor(convInfo) {
- this.variableNames = ['dy'];
- this.outputShape = convInfo.inShape;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const avgMultiplier = 1 / (filterHeight * filterWidth);
- this.userCode = `
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
- const float avgMultiplier = float(${avgMultiplier});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
-
- ivec2 dyRCCorner = coords.yz - pads;
- int dyRCorner = dyRCCorner.x;
- int dyCCorner = dyRCCorner.y;
-
- // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- for (int wC = 0; wC < ${effectiveFilterWidth};
- wC+= ${dilationWidth}) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- float dyValue = getDy(b, idyR, idyC, d);
-
- dotProd += dyValue * avgMultiplier;
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class AvgPool3DBackpropProgram {
- constructor(convInfo) {
- this.variableNames = ['dy'];
- this.outputShape = convInfo.inShape;
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
- this.userCode = `
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
- const float avgMultiplier = float(${avgMultiplier});
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int ch = coords.u;
-
- ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
- int dyDCorner = dyCorner.x;
- int dyRCorner = dyCorner.y;
- int dyCCorner = dyCorner.z;
-
- // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get
- // dx(xD, xR, xC, ch).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
-
- for (int wD = 0; wD < ${effectiveFilterDepth};
- wD += ${dilationDepth}) {
- float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
-
- if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
- continue;
- }
- int idyD = int(dyD);
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
- fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- for (int wC = 0; wC < ${effectiveFilterWidth};
- wC += ${dilationWidth}) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- float dyValue = getDy(batch, idyD, idyR, idyC, ch);
-
- dotProd += dyValue * avgMultiplier;
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const CHECK_NAN_SNIPPET = `
- if (isnan(a)) return a;
- if (isnan(b)) return b;
-`;
- // We use native integer division to deal with floating point imprecision. Since
- // we implement floor division and glsl implements truncated division, we
- // correct for this by subtracting 1 from result when the result is negative and
- // there is a remainder.
- const INT_DIV = `
- float s = sign(a) * sign(b);
- int ia = round(a);
- int ib = round(b);
- if (ib != 0) {
- // Windows (D3D) wants guaranteed non-zero int division at compile-time.
- return float(idiv(ia, ib, s));
- } else {
- return NAN;
- }
-`;
- const POW = `
-if(a < 0.0 && floor(b) < b){
- return NAN;
-}
-if (b == 0.0) {
- return 1.0;
-}
-return (round(mod(b, 2.0)) != 1) ?
- pow(abs(a), b) : sign(a) * pow(abs(a), b);
-`;
- const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
- const EQUAL = `return float(a == b);`;
- const LESS = `return float(a < b);`;
- const LESS_EQUAL = `return float(a <= b);`;
- const GREATER = `return float(a > b);`;
- const GREATER_EQUAL = `return float(a >= b);`;
- const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
- const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
- const MAX = CHECK_NAN_SNIPPET + `
- return max(a, b);
-`;
- const MIN = CHECK_NAN_SNIPPET + `
- return min(a, b);
-`;
- const MOD = `if (b == 0.0) return NAN;
- return mod(a, b);`;
- const ELU_DER = `return (b >= 1.0) ? a : a * (b + 1.0);`;
- const PRELU = `return (a < 0.) ? b * a : a;`;
- class BinaryOpProgram {
- constructor(op, aShape, bShape) {
- this.variableNames = ['A', 'B'];
- this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
- this.userCode = `
- float binaryOperation(float a, float b) {
- ${op}
- }
-
- void main() {
- float a = getAAtOutCoords();
- float b = getBAtOutCoords();
- setOutput(binaryOperation(a, b));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const CHECK_NAN_SNIPPET$1 = `
- result.r = isNaN.r > 0. ? NAN : result.r;
- result.g = isNaN.g > 0. ? NAN : result.g;
- result.b = isNaN.b > 0. ? NAN : result.b;
- result.a = isNaN.a > 0. ? NAN : result.a;
-`;
- const INT_DIV$1 = `
- ivec4 ia = round(a);
- ivec4 ib = round(b);
- bvec4 cond = notEqual(ib, ivec4(0));
- ivec4 result = ivec4(0);
- vec4 s = sign(a) * sign(b);
-
- // Windows (D3D) wants guaranteed non-zero int division at compile-time.
- if (cond[0]) {
- result[0] = idiv(ia[0], ib[0], s[0]);
- }
- if (cond[1]) {
- result[1] = idiv(ia[1], ib[1], s[1]);
- }
- if (cond[2]) {
- result[2] = idiv(ia[2], ib[2], s[2]);
- }
- if (cond[3]) {
- result[3] = idiv(ia[3], ib[3], s[3]);
- }
- return vec4(result);
-`;
- const POW$1 = `
- // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.
- vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
- vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
- vec4 result = multiplier * pow(abs(a), b);
-
- // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS
- bvec4 isExpZero = equal(b, vec4(0.0));
- result.r = isExpZero.r ? 1.0 : result.r;
- result.g = isExpZero.g ? 1.0 : result.g;
- result.b = isExpZero.b ? 1.0 : result.b;
- result.a = isExpZero.a ? 1.0 : result.a;
-
- vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));
- ` +
- CHECK_NAN_SNIPPET$1 + `
- return result;
-`;
- const PRELU$1 = `
- vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
- return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
-`;
- const ELU_DER$1 = `
- vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
- return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
-`;
- const EQUAL$1 = `
- return vec4(equal(a, b));
-`;
- const NOT_EQUAL = `
- return vec4(notEqual(a, b));
-`;
- const LESS$1 = `
- return vec4(lessThan(a, b));
-`;
- const LESS_EQUAL$1 = `
- return vec4(lessThanEqual(a, b));
-`;
- const GREATER$1 = `
- return vec4(greaterThan(a, b));
-`;
- const GREATER_EQUAL$1 = `
- return vec4(greaterThanEqual(a, b));
-`;
- const LOGICAL_AND$1 = `
- return vec4(
- vec4(greaterThanEqual(a, vec4(1.0))) *
- vec4(greaterThanEqual(b, vec4(1.0))));
-`;
- const LOGICAL_OR$1 = `
- return min(
- vec4(greaterThanEqual(a, vec4(1.0))) +
- vec4(greaterThanEqual(b, vec4(1.0))),
- vec4(1.0));
-`;
- const MAX$1 = `
- vec4 result = vec4(max(a, b));
- vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
- ` +
- CHECK_NAN_SNIPPET$1 + `
- return result;
-`;
- const MIN$1 = `
- vec4 result = vec4(min(a, b));
- vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
- ` +
- CHECK_NAN_SNIPPET$1 + `
- return result;
-`;
- const MOD$1 = `
- vec4 result = mod(a, b);
- vec4 isNaN = vec4(equal(b, vec4(0.0)));
- ` +
- CHECK_NAN_SNIPPET$1 + `
- return result;
-`;
- class BinaryOpPackedProgram {
- constructor(op, aShape, bShape, checkOutOfBounds = false) {
- this.variableNames = ['A', 'B'];
- this.supportsBroadcasting = true;
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
- const rank = this.outputShape.length;
- let checkOutOfBoundsString = '';
- if (checkOutOfBounds) {
- if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
- checkOutOfBoundsString = `
- result.y = 0.;
- result.z = 0.;
- result.w = 0.;
- `;
- }
- else {
- const dtype = getCoordsDataType(rank);
- checkOutOfBoundsString = `
- ${dtype} coords = getOutputCoords();
- `;
- if (rank === 1) {
- checkOutOfBoundsString += `
- result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
- result.z = 0.;
- result.w = 0.;
- `;
- }
- else {
- const channels = getChannels('coords', rank);
- checkOutOfBoundsString += `
- bool nextRowOutOfBounds =
- (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
- bool nextColOutOfBounds =
- (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};
- result.y = nextColOutOfBounds ? 0. : result.y;
- result.z = nextRowOutOfBounds ? 0. : result.z;
- result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
- `;
- }
- }
- }
- this.userCode = `
- vec4 binaryOperation(vec4 a, vec4 b) {
- ${op}
- }
-
- void main() {
- vec4 a = getAAtOutCoords();
- vec4 b = getBAtOutCoords();
-
- vec4 result = binaryOperation(a, b);
- ${checkOutOfBoundsString}
-
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ClipProgram {
- constructor(aShape) {
- this.variableNames = ['A'];
- this.outputShape = aShape;
- this.userCode = `
- uniform float minVal;
- uniform float maxVal;
-
- void main() {
- float value = getAAtOutCoords();
- if (isnan(value)) {
- setOutput(value);
- return;
- }
-
- setOutput(clamp(value, minVal, maxVal));
- }
- `;
- }
- getCustomSetupFunc(min, max) {
- return (gpgpu, webGLProgram) => {
- if (this.minLoc == null) {
- this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
- this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
- }
- gpgpu.gl.uniform1f(this.minLoc, min);
- gpgpu.gl.uniform1f(this.maxLoc, max);
- };
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ClipPackedProgram {
- constructor(aShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = aShape;
- this.userCode = `
- uniform float minVal;
- uniform float maxVal;
-
- void main() {
- vec4 value = getAAtOutCoords();
-
- if (any(isnan(value))) {
- setOutput(value);
- return;
- }
-
- setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
- }
- `;
- }
- getCustomSetupFunc(min, max) {
- return (gpgpu, webGLProgram) => {
- if (this.minLoc == null) {
- this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
- this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
- }
- gpgpu.gl.uniform1f(this.minLoc, min);
- gpgpu.gl.uniform1f(this.maxLoc, max);
- };
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ComplexAbsProgram {
- constructor(shape) {
- this.variableNames = ['real', 'imag'];
- this.outputShape = shape;
- this.userCode = `
- void main() {
- float re = abs(getRealAtOutCoords());
- float im = abs(getImagAtOutCoords());
- float mx = max(re, im);
-
- // sadly the length function in glsl is not underflow-safe
- // (at least not on Intel GPUs). So the safe solution is
- // to ensure underflow-safety in all cases.
- setOutput(
- mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
- );
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class Conv2DDerFilterProgram {
- constructor(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int wR = coords.x;
- int wC = coords.y;
- int d1 = coords.z;
- int d2 = coords.w;
-
- // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
-
- for (int b = 0; b < ${convInfo.batchSize}; b++) {
- for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
- int xR = wR + yR * ${strideHeight} - ${padTop};
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
- int xC = wC + yC * ${strideWidth} - ${padLeft};
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- if (${isChannelsLast}) {
- float dyValue = getDy(b, yR, yC, d2);
- float xValue = getX(b, xR, xC, d1);
- dotProd += (xValue * dyValue);
- } else {
- float dyValue = getDy(b, d2, yR, yC);
- float xValue = getX(b, d1, xR, xC);
- dotProd += (xValue * dyValue);
- }
-
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class Conv2DDerInputProgram {
- constructor(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- const padTop = filterHeight - 1 - convInfo.padInfo.top;
- const padLeft = filterWidth - 1 - convInfo.padInfo.left;
- const rowDim = isChannelsLast ? 1 : 2;
- const colDim = isChannelsLast ? 2 : 3;
- const channelDim = isChannelsLast ? 3 : 1;
- this.userCode = `
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords[0];
- int d1 = coords[${channelDim}];
-
- ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
- int dyRCorner = dyCorner.x;
- int dyCCorner = dyCorner.y;
-
- // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- int wRPerm = ${filterHeight} - 1 - wR;
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- int wCPerm = ${filterWidth} - 1 - wC;
-
- for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
-
- if (${isChannelsLast}) {
- float xValue = getDy(batch, idyR, idyC, d2);
- float wValue = getW(wRPerm, wCPerm, d1, d2);
- dotProd += xValue * wValue;
- } else {
- float xValue = getDy(batch, d2, idyR, idyC);
- float wValue = getW(wRPerm, wCPerm, d1, d2);
- dotProd += xValue * wValue;
- }
-
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class Conv3DDerFilterProgram {
- constructor(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const padFront = convInfo.padInfo.front;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- this.userCode = `
- void main() {
- ivec5 coords = getOutputCoords();
- int wF = coords.x;
- int wR = coords.y;
- int wC = coords.z;
- int d1 = coords.w;
- int d2 = coords.u;
-
- float dotProd = 0.0;
-
- for (int b = 0; b < ${convInfo.batchSize}; b++) {
- for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
- int xF = wF + yF * ${strideDepth} - ${padFront};
-
- if (xF < 0 || xF >= ${convInfo.inDepth}) {
- continue;
- }
-
- for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
- int xR = wR + yR * ${strideHeight} - ${padTop};
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
- int xC = wC + yC * ${strideWidth} - ${padLeft};
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- float dyValue = getDy(b, yF, yR, yC, d2);
- float xValue = getX(b, xF, xR, xC, d1);
- dotProd += (xValue * dyValue);
- }
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class Conv3DDerInputProgram {
- constructor(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const padFront = filterDepth - 1 - convInfo.padInfo.front;
- const padTop = filterHeight - 1 - convInfo.padInfo.top;
- const padLeft = filterWidth - 1 - convInfo.padInfo.left;
- this.userCode = `
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int d1 = coords.u;
-
-
- ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
- int dyFCorner = dyCorner.x;
- int dyRCorner = dyCorner.y;
- int dyCCorner = dyCorner.z;
-
- float dotProd = 0.0;
- for (int wF = 0; wF < ${filterDepth}; wF++) {
- float dyF = float(dyFCorner + wF) / ${strideDepth}.0;
-
- if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
- continue;
- }
- int idyF = int(dyF);
-
- int wFPerm = ${filterDepth} - 1 - wF;
-
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
- fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- int wRPerm = ${filterHeight} - 1 - wR;
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- int wCPerm = ${filterWidth} - 1 - wC;
-
- for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
- float xValue = getDy(batch, idyF, idyR, idyC, d2);
- float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
- dotProd += xValue * wValue;
- }
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DepthwiseConv2DDerFilterProgram {
- constructor(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const channelMul = convInfo.outChannels / convInfo.inChannels;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int wR = coords.x;
- int wC = coords.y;
- int d1 = coords.z;
- int dm = coords.w;
- int d2 = d1 * ${channelMul} + dm;
-
- float dotProd = 0.0;
-
- // TO DO: Vec4 over the batch size
- for (int b = 0; b < ${convInfo.batchSize}; b++) {
- for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
- int xR = wR + yR * ${strideHeight} - ${padTop};
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
- int xC = wC + yC * ${strideWidth} - ${padLeft};
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- float dyValue = getDy(b, yR, yC, d2);
- float xValue = getX(b, xR, xC, d1);
- dotProd += (xValue * dyValue);
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class DepthwiseConv2DDerInputProgram {
- constructor(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const padTop = filterHeight - 1 - convInfo.padInfo.top;
- const padLeft = filterWidth - 1 - convInfo.padInfo.left;
- const channelMul = convInfo.outChannels / convInfo.inChannels;
- this.userCode = `
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords[0];
- int d1 = coords[3];
- ivec2 dyCorner = coords.yz - pads;
- int dyRCorner = dyCorner.x;
- int dyCCorner = dyCorner.y;
-
- float dotProd = 0.0;
-
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- int wRPerm = ${filterHeight} - 1 - wR;
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- int wCPerm = ${filterWidth} - 1 - wC;
-
- // TO DO: Vec4 over the channelMul
- for (int dm = 0; dm < ${channelMul}; dm++) {
- int d2 = d1 * ${channelMul} + dm;
- float xValue = getDy(batch, idyR, idyC, d2);
- float wValue = getW(wRPerm, wCPerm, d1, dm);
- dotProd += xValue * wValue;
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class Conv2DProgram {
- constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false) {
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
- const inputDepthVec4Remainder = convInfo.inChannels % 4;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- const rowDim = isChannelsLast ? 1 : 2;
- const colDim = isChannelsLast ? 2 : 3;
- const channelDim = isChannelsLast ? 3 : 1;
- let activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivationWeights) {
- activationSnippet = `float activation(float a) {
- float b = getPreluActivationWeightsAtOutCoords();
- ${activation}
- }`;
- }
- else {
- activationSnippet = `
- float activation(float x) {
- ${activation}
- }
- `;
- }
- applyActivationSnippet = `result = activation(result);`;
- }
- const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivationWeights) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = `
- ${activationSnippet}
-
- const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords[0];
- int d2 = coords[${channelDim}];
-
- ivec2 xRCCorner =
- ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
- int xRCorner = xRCCorner.x;
- int xCCorner = xRCCorner.y;
-
- // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- int xR = xRCorner + wR * ${dilationHeight};
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- int xC = xCCorner + wC * ${dilationWidth};
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
- vec4 wValues = vec4(
- getW(wR, wC, d1, d2),
- getW(wR, wC, d1 + 1, d2),
- getW(wR, wC, d1 + 2, d2),
- getW(wR, wC, d1 + 3, d2)
- );
-
- if (${isChannelsLast}) {
- vec4 xValues = vec4(
- getX(batch, xR, xC, d1),
- getX(batch, xR, xC, d1 + 1),
- getX(batch, xR, xC, d1 + 2),
- getX(batch, xR, xC, d1 + 3)
- );
- dotProd += dot(xValues, wValues);
- } else {
- vec4 xValues = vec4(
- getX(batch, d1, xR, xC),
- getX(batch, d1 + 1, xR, xC),
- getX(batch, d1 + 2, xR, xC),
- getX(batch, d1 + 3, xR, xC)
- );
- dotProd += dot(xValues, wValues);
- }
- }
-
- if (${inputDepthVec4Remainder === 1}) {
-
- if (${isChannelsLast}) {
- dotProd +=
- getX(batch, xR, xC, ${inputDepthNearestVec4}) *
- getW(wR, wC, ${inputDepthNearestVec4}, d2);
- } else {
- dotProd +=
- getX(batch, ${inputDepthNearestVec4}, xR, xC) *
- getW(wR, wC, ${inputDepthNearestVec4}, d2);
- }
-
- } else if (${inputDepthVec4Remainder === 2}) {
- vec2 wValues = vec2(
- getW(wR, wC, ${inputDepthNearestVec4}, d2),
- getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
- );
-
- if (${isChannelsLast}) {
- vec2 xValues = vec2(
- getX(batch, xR, xC, ${inputDepthNearestVec4}),
- getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
- );
- dotProd += dot(xValues, wValues);
- } else {
- vec2 xValues = vec2(
- getX(batch, ${inputDepthNearestVec4}, xR, xC),
- getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
- );
- dotProd += dot(xValues, wValues);
- }
-
- } else if (${inputDepthVec4Remainder === 3}) {
- vec3 wValues = vec3(
- getW(wR, wC, ${inputDepthNearestVec4}, d2),
- getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
- getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
- );
-
- if (${isChannelsLast}) {
- vec3 xValues = vec3(
- getX(batch, xR, xC, ${inputDepthNearestVec4}),
- getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
- getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
- );
- dotProd += dot(xValues, wValues);
- } else {
- vec3 xValues = vec3(
- getX(batch, ${inputDepthNearestVec4}, xR, xC),
- getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
- getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
- );
- dotProd += dot(xValues, wValues);
- }
-
- }
- }
- }
-
- float result = dotProd;
- ${addBiasSnippet}
- ${applyActivationSnippet}
- setOutput(result);
- }
- `;
- }
- }
- class Conv3DProgram {
- constructor(convInfo) {
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- const padFront = convInfo.padInfo.front;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const filterDepth = convInfo.filterDepth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
- const inputDepthVec4Remainder = convInfo.inChannels % 4;
- this.userCode = `
- const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int d2 = coords.u;
-
- ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
- int xFCorner = xFRCCorner.x;
- int xRCorner = xFRCCorner.y;
- int xCCorner = xFRCCorner.z;
-
- // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get
- // y(yF, yR, yC, d2). ? = to be determined. : = across all
- // values in that axis.
- float dotProd = 0.0;
- for (int wF = 0; wF < ${filterDepth}; wF++) {
- int xF = xFCorner + wF * ${dilationDepth};
-
- if (xF < 0 || xF >= ${convInfo.inDepth}) {
- continue;
- }
-
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- int xR = xRCorner + wR * ${dilationHeight};
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- int xC = xCCorner + wC * ${dilationWidth};
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
- vec4 xValues = vec4(
- getX(batch, xF, xR, xC, d1),
- getX(batch, xF, xR, xC, d1 + 1),
- getX(batch, xF, xR, xC, d1 + 2),
- getX(batch, xF, xR, xC, d1 + 3)
- );
- vec4 wValues = vec4(
- getW(wF, wR, wC, d1, d2),
- getW(wF, wR, wC, d1 + 1, d2),
- getW(wF, wR, wC, d1 + 2, d2),
- getW(wF, wR, wC, d1 + 3, d2)
- );
-
- dotProd += dot(xValues, wValues);
- }
-
- if (${inputDepthVec4Remainder === 1}) {
- dotProd +=
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
- getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
- } else if (${inputDepthVec4Remainder === 2}) {
- vec2 xValues = vec2(
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
- );
- vec2 wValues = vec2(
- getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
- getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
- );
- dotProd += dot(xValues, wValues);
- } else if (${inputDepthVec4Remainder === 3}) {
- vec3 xValues = vec3(
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
- getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
- );
- vec3 wValues = vec3(
- getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
- getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
- getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
- );
- dotProd += dot(xValues, wValues);
- }
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DepthwiseConv2DProgram {
- constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false) {
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- const xNumRows = convInfo.inHeight;
- const xNumCols = convInfo.inWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const channelMul = convInfo.outChannels / convInfo.inChannels;
- let activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = `float activation(float a) {
- float b = getPreluActivationWeightsAtOutCoords();
- ${activation}
- }`;
- }
- else {
- activationSnippet = `
- float activation(float x) {
- ${activation}
- }
- `;
- }
- applyActivationSnippet = `result = activation(result);`;
- }
- const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = `
- ${activationSnippet}
-
- const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords.x;
- ivec2 xRCCorner = coords.yz * strides - pads;
- int d2 = coords.w;
- int d1 = d2 / ${channelMul};
- int q = d2 - d1 * ${channelMul};
-
- int xRCorner = xRCCorner.x;
- int xCCorner = xRCCorner.y;
-
- // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
- // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.
- for (int wR = 0; wR < ${filterHeight}; wR++) {
- int xR = xRCorner + wR * ${dilationHeight};
-
- if (xR < 0 || xR >= ${xNumRows}) {
- continue;
- }
-
- for (int wC = 0; wC < ${filterWidth}; wC++) {
- int xC = xCCorner + wC * ${dilationWidth};
-
- if (xC < 0 || xC >= ${xNumCols}) {
- continue;
- }
-
- float xVal = getX(batch, xR, xC, d1);
- float wVal = getW(wR, wC, d1, q);
- dotProd += xVal * wVal;
- }
- }
-
- float result = dotProd;
- ${addBiasSnippet}
- ${applyActivationSnippet}
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DepthwiseConvPacked2DProgram {
- constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false) {
- this.variableNames = ['x', 'W'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = convInfo.outShape;
- const xNumRows = convInfo.inHeight;
- const xNumCols = convInfo.inWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const filterHeight = convInfo.filterHeight;
- const filterWidth = convInfo.filterWidth;
- const texelsAcross = filterWidth;
- let mainLoop = `int xR; int xC; int xCOffset;`;
- for (let r = 0; r < filterHeight; r++) {
- for (let c = 0; c < filterWidth; c++) {
- mainLoop += `
- vec4 xTexelR${r}C${c * 2} = vec4(0.);
- vec4 wR${r}C${c} = vec4(0.);
- vec4 xR${r}C${c} = vec4(0.);`;
- }
- }
- /**
- * This vectorized implementation works by gathering the values needed for
- * each output channel's dot product into vec4's and then multiplying them
- * all together (this happens in the final double for-loop below). Most of
- * the main loop consists of constructing these vec4's with the minimum
- * number of texture2D calls, which means making use of all four returned
- * values from a texture2D call at once.
- */
- for (let r = 0; r < filterHeight; r++) {
- for (let texelC = 0; texelC < texelsAcross; texelC++) {
- const c = texelC * 2;
- mainLoop += `
- xR = xRCorner + ${r * dilationHeight};
- xC = xCCorner + ${c * dilationWidth};
- `;
- if (strideWidth === 1) {
- if (c < filterWidth) {
- // If padding is odd, the outer texels have to be composed.
- if (padLeft % 2 === 1) {
- // TODO: Ensure vec4 previous does not result in redundant sample,
- // and avoid setting xTexelRC's that exceed the boundary in the
- // first place rather than resetting them to vec4(0)).
- // To compute xCOffset:
- // - If padding is odd, we must add 1 to ensure we ask for an
- // even-numbered row.
- // - We subtract 2 to access the previous texel.
- mainLoop += `
- xCOffset = xC + 1;
- if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
-
- // Need to manually clear unused channels in case
- // we're reading from recycled texture.
- if(xCOffset + 1 >= ${xNumCols}) {
- xTexelR${r}C${c}.zw = vec2(0.);
- }
- } else {
- xTexelR${r}C${c} = vec4(0.);
- }
-
- xCOffset = xC + 1 - 2;
- if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${xNumCols}) {
- vec4 previous = getX(batch, xR, xCOffset, d1);
-
- // Need to manually clear unused channels in case
- // we're reading from recycled texture.
- if(xCOffset + 1 >= ${xNumCols}) {
- previous.zw = vec2(0.);
- }
-
- xR${r}C${c} = vec4(previous.zw, xTexelR${r}C${c}.xy);
- } else {
- xR${r}C${c} = vec4(0, 0, xTexelR${r}C${c}.xy);
- }
- `;
- }
- else {
- // Padding is even, so xRC corresponds to a single texel.
- mainLoop += `
- if(xR >= 0 && xR < ${xNumRows} && xC >= 0 && xC < ${xNumCols}) {
- xTexelR${r}C${c} = getX(batch, xR, xC, d1);
- } else {
- xTexelR${r}C${c} = vec4(0.);
- }
-
- xR${r}C${c} = xTexelR${r}C${c};
- `;
- }
- if (c + 1 < filterWidth) {
- // If dilation is even, the second entry should match the first
- // (either both are composed or both are single samples). But if
- // dilation is odd, then the second entry should be the opposite
- // of the first (if the first is composed, the second is a single
- // sample, and vice versa.)
- const nextTexelOffset = padLeft % 2 === 0 ?
- nearestLargerEven(dilationWidth) :
- dilationWidth;
- if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
- (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
- mainLoop += `
- xCOffset = xC + ${padLeft % 2} + ${nextTexelOffset};
-
- if(xR >= 0 && xR < ${xNumRows} &&
- xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);
- }
- `;
- // If dilation > 1 then the xRC's will not be able to share any
- // values, so each xRC will require two unique calls to getX.
- if (dilationWidth > 1) {
- mainLoop += `
- xCOffset -= 2;
- if(xR >= 0 && xR < ${xNumRows} &&
- xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
- } else {
- xTexelR${r}C${c} = vec4(0.);
- }
- `;
- }
- mainLoop += `
- xR${r}C${c + 1} = vec4(
- xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.xy);
- `;
- }
- else {
- mainLoop += `
- xCOffset = xC + ${nextTexelOffset};
-
- if(xR >= 0 && xR < ${xNumRows} &&
- xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);
- }
-
- xR${r}C${c + 1} = xTexelR${r}C${c + 2};
- `;
- }
- }
- }
- }
- else { // stride > 1
- if (c < filterWidth) {
- mainLoop += `
- if(xR >= 0 && xR < ${xNumRows}) {
- `;
- // Depending on whether padLeft is even or odd, we want either the
- // xy or zw channels from X texels for xR${r}C${c}. If padLeft is
- // even, xR${r}C${c + 1} is simply the zw channels of texels we've
- // already sampled. But if padLeft is odd, xR${r}C{$c + 1}.zw will
- // need to come from the xy channels of a new texel, hence the `vec4
- // final` initialized below.
- if (padLeft % 2 === 1) {
- mainLoop += `
- xCOffset = xC + 1 - ${strideWidth};
- if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
- } else {
- xTexelR${r}C${c} = vec4(0.);
- }
-
- if(xC + 1 >= 0 && xC + 1 < ${xNumCols}) {
- xTexelR${r}C${c + 2} = getX(batch, xR, xC + 1, d1);
- } else {
- xTexelR${r}C${c + 2} = vec4(0.);
- }
-
- xR${r}C${c} = vec4(
- xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.zw);
- `;
- if (c + 1 < filterWidth) {
- mainLoop += `
- vec4 final = vec4(0.);
- xCOffset = xC + 1 + ${strideWidth};
- if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
- final = getX(batch, xR, xCOffset, d1);
- }
- xR${r}C${c + 1} = vec4(xTexelR${r}C${c + 2}.xy, final.xy);
- `;
- }
- }
- else {
- mainLoop += `
- if(xC >= 0 && xC < ${xNumCols}) {
- xTexelR${r}C${c} = getX(batch, xR, xC, d1);
- } else {
- xTexelR${r}C${c} = vec4(0.);
- }
-
- xCOffset = xC + ${strideWidth};
- if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
- xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);
- } else {
- xTexelR${r}C${c + 2} = vec4(0.);
- }
-
- xR${r}C${c} = vec4(
- xTexelR${r}C${c}.xy, xTexelR${r}C${c + 2}.xy);
- `;
- if (c + 1 < filterWidth) {
- mainLoop += `
- xR${r}C${c + 1} = vec4(
- xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.zw);
- `;
- }
- }
- mainLoop += `}`;
- }
- }
- if (c < filterWidth) {
- mainLoop += `
- vec4 wTexelR${r}C${c} = getW(${r}, ${c}, d1, q);
- wR${r}C${c} = vec4(wTexelR${r}C${c}.xz, wTexelR${r}C${c}.xz);
- `;
- if (c + 1 < filterWidth) {
- mainLoop += `
- vec4 wTexelR${r}C${c + 1} = getW(${r}, ${c + 1}, d1, q);
- wR${r}C${c + 1} =
- vec4(wTexelR${r}C${c + 1}.xz, wTexelR${r}C${c + 1}.xz);`;
- }
- }
- }
- }
- for (let r = 0; r < filterHeight; r++) {
- for (let c = 0; c < filterWidth; c++) {
- mainLoop += `dotProd += xR${r}C${c} * wR${r}C${c};`;
- }
- }
- let activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = `vec4 activation(vec4 a) {
- vec4 b = getPreluActivationWeightsAtOutCoords();
- ${activation}
- }`;
- }
- else {
- activationSnippet = `vec4 activation(vec4 x) {
- ${activation}
- }`;
- }
- applyActivationSnippet = `result = activation(result);`;
- }
- const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = `
- ${activationSnippet}
-
- const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
-
- ivec4 coords = getOutputCoords();
- int batch = coords.x;
- ivec2 xRCCorner = coords.yz * strides - pads;
- int d2 = coords.w;
- int d1 = d2;
- int q = 0;
- int xRCorner = xRCCorner.x;
- int xCCorner = xRCCorner.y;
-
- vec4 dotProd = vec4(0.);
-
- ${mainLoop}
-
- vec4 result = dotProd;
- ${addBiasSnippet}
- ${applyActivationSnippet}
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class CropAndResizeProgram {
- constructor(imageShape, boxShape, cropSize, method, extrapolationValue) {
- this.variableNames = ['Image', 'Boxes', 'BoxInd'];
- this.outputShape = [];
- const [batch, imageHeight, imageWidth, depth] = imageShape;
- const [numBoxes,] = boxShape;
- const [cropHeight, cropWidth] = cropSize;
- this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
- const methodId = method === 'bilinear' ? 1 : 0;
- const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];
- const [heightRatio, heightScale, inY] = cropHeight > 1 ?
- [
- `${(imageHeight - 1) / (cropHeight - 1)}`,
- '(y2-y1) * height_ratio',
- `y1*${inputHeightFloat} + float(y)*(height_scale)`,
- ] :
- [
- '0.0',
- '0.0',
- `0.5 * (y1+y2) * ${inputHeightFloat}`,
- ];
- const [widthRatio, widthScale, inX] = cropWidth > 1 ?
- [
- `${(imageWidth - 1) / (cropWidth - 1)}`,
- '(x2-x1) * width_ratio',
- `x1*${inputWidthFloat} + float(x)*(width_scale)`,
- ] :
- [
- '0.0',
- '0.0',
- `0.5 * (x1+x2) * ${inputWidthFloat}`,
- ];
- // Reference implementation
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
- this.userCode = `
- const float height_ratio = float(${heightRatio});
- const float width_ratio = float(${widthRatio});
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int y = coords[1];
- int x = coords[2];
- int d = coords[3];
-
- // get box vals
- float y1 = getBoxes(b,0);
- float x1 = getBoxes(b,1);
- float y2 = getBoxes(b,2);
- float x2 = getBoxes(b,3);
-
- // get image in batch index
- int bInd = round(getBoxInd(b));
- if(bInd < 0 || bInd >= ${batch}) {
- return;
- }
-
- float height_scale = ${heightScale};
- float width_scale = ${widthScale};
-
- float in_y = ${inY};
- if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
- setOutput(float(${extrapolationValue}));
- return;
- }
- float in_x = ${inX};
- if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
- setOutput(float(${extrapolationValue}));
- return;
- }
-
- vec2 sourceFracIndexCR = vec2(in_x,in_y);
- if(${methodId} == 1) {
- // Compute the four integer indices.
- ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
- ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));
-
- float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
- float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
- float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
- float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);
-
- vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);
-
- float top = topLeft + (topRight - topLeft) * fracCR.x;
- float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
- float newValue = top + (bottom - top) * fracCR.y;
- setOutput(newValue);
- } else {
- // Compute the coordinators of nearest neighbor point.
- ivec2 sourceNearestCR = ivec2(floor(
- sourceFracIndexCR + vec2(0.5,0.5)));
- float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
- setOutput(newValue);
- }
- }
- `;
- }
- }
-
- class CumSumProgram {
- constructor(shape, exclusive, reverse) {
- this.variableNames = ['x'];
- this.outputShape = shape;
- const rank = shape.length;
- const val = exclusive ? '0.0' : `getX(${getCoords(rank, 'coords')})`;
- const length = shape[shape.length - 1];
- let condition = '';
- let idxString = '';
- // When exclusive is set, the cumsum op becomes roll op that copies the
- // value from the previous index based on the direction specified by the
- // reverse flag.
- if (exclusive) {
- condition = reverse ? `end != ${length - 1}` : 'end != 0';
- idxString = reverse ? 'end + 1' : 'end - 1';
- }
- else {
- condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
- idxString = (reverse ? 'end + pow2' : 'end - pow2');
- }
- this.userCode = `
- uniform float index;
- void main() {
- ${getCoordsDataType(rank)} coords = getOutputCoords();
- int end = ${getFinalCoord(rank, 'coords')};
- float val = ${val};
- int pow2 = int(pow(2.0, index));
- if (${condition}) {
- int idx = ${idxString};
- ${getFinalCoord(rank, 'coords')} = idx;
- val += getX(${getCoords(rank, 'coords')});
- }
- setOutput(val);
- }
- `;
- }
- getCustomSetupFunc(index) {
- return (gpgpu, webGLProgram) => {
- if (this.index == null) {
- this.index = gpgpu.getUniformLocation(webGLProgram, 'index');
- }
- gpgpu.gl.uniform1f(this.index, index);
- };
- }
- }
- function getCoords(rank, name) {
- if (rank === 1) {
- return `${name}`;
- }
- else if (rank === 2) {
- return `${name}.x, ${name}.y`;
- }
- else if (rank === 3) {
- return `${name}.x, ${name}.y, ${name}.z`;
- }
- else if (rank === 4) {
- return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;
- }
- else {
- throw Error(`Cumulative sum for rank ${rank} is not yet supported`);
- }
- }
- function getFinalCoord(rank, name) {
- if (rank === 1) {
- return `${name}`;
- }
- else if (rank === 2) {
- return `${name}.y`;
- }
- else if (rank === 3) {
- return `${name}.z`;
- }
- else if (rank === 4) {
- return `${name}.w`;
- }
- else {
- throw Error(`Cumulative sum for rank ${rank} is not yet supported`);
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DecodeMatrixProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- this.outPackingScheme = PackingScheme.DENSE;
- const texShape = getDenseTexShape(outputShape);
- const glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = `
- ivec3 outCoordsFromFlatIndex(int index) {
- ${getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
- return ivec3(r, c, d);
- }
-
- void main() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);
-
- vec4 result = vec4(0.);
-
- for (int i=0; i<4; i++) {
- int flatIndex = index + i;
- ivec3 rc = outCoordsFromFlatIndex(flatIndex);
- result[i] = getA(rc.x, rc.y, rc.z);
- }
-
- ${glsl.output} = result;
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DecodeMatrixPackedProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outPackingScheme = PackingScheme.DENSE;
- const texShape = getDenseTexShape(outputShape);
- const glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = `
- ivec3 outCoordsFromFlatIndex(int index) {
- ${getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
- return ivec3(r, c, d);
- }
-
- void main() {
- ivec2 resTexRC = ivec2(resultUV.yx *
- vec2(${texShape[0]}, ${texShape[1]}));
- int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);
-
- vec4 result = vec4(0.);
-
- for (int i=0; i<4; i++) {
- int flatIndex = index + i;
- ivec3 rc = outCoordsFromFlatIndex(flatIndex);
- result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
- }
-
- ${glsl.output} = result;
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DepthToSpaceProgram {
- constructor(outputShape, blockSize, dataFormat) {
- this.variableNames = ['x'];
- this.outputShape = [];
- this.outputShape = outputShape;
- this.blockSize = blockSize;
- this.dataFormat = dataFormat;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int h = ${this.getHeightCoordString()};
- int w = ${this.getWidthCoordString()};
- int d = ${this.getDepthCoordString()};
-
- int in_h = h / ${blockSize};
- int offset_h = imod(h, ${blockSize});
- int in_w = w / ${blockSize};
- int offset_w = imod(w, ${blockSize});
- int offset_d = (offset_h * ${blockSize} + offset_w) *
- ${this.getOutputDepthSize()};
- int in_d = d + offset_d;
-
- float result = ${this.getInputSamplingString()};
- setOutput(result);
- }
- `;
- }
- getHeightCoordString() {
- if (this.dataFormat === 'NHWC') {
- return `coords[1]`;
- }
- else {
- return `coords[2]`;
- }
- }
- getWidthCoordString() {
- if (this.dataFormat === 'NHWC') {
- return `coords[2]`;
- }
- else {
- return `coords[3]`;
- }
- }
- getDepthCoordString() {
- if (this.dataFormat === 'NHWC') {
- return `coords[3]`;
- }
- else {
- return `coords[1]`;
- }
- }
- getOutputDepthSize() {
- if (this.dataFormat === 'NHWC') {
- return this.outputShape[3];
- }
- else {
- return this.outputShape[1];
- }
- }
- getInputSamplingString() {
- if (this.dataFormat === 'NHWC') {
- return `getX(b, in_h, in_w, in_d)`;
- }
- else {
- return `getX(b, in_d, in_h, in_w)`;
- }
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class DiagProgram {
- constructor(size) {
- this.variableNames = ['X'];
- this.outputShape = [size, size];
- this.userCode = `
- void main() {
- ivec2 coords = getOutputCoords();
- float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
- setOutput(val);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class EncodeFloatProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.outTexUsage = TextureUsage.DOWNLOAD;
- const glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = `
- ${ENCODE_FLOAT_SNIPPET}
-
- void main() {
- float x = getAAtOutCoords();
- ${glsl.output} = encode_float(x);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class EncodeFloatPackedProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = false;
- this.outTexUsage = TextureUsage.DOWNLOAD;
- const glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = `
- ${ENCODE_FLOAT_SNIPPET}
-
- void main() {
- ivec3 coords = getOutputCoords();
- float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
- ${glsl.output} = encode_float(x);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class EncodeMatrixProgram {
- constructor(outputShape, texShape, inputIsUnsignedByte = false) {
- this.variableNames = ['A'];
- const glsl = getGlslDifferences();
- const [height, width] = texShape;
- this.outputShape = outputShape;
- let output = `result`;
- if (inputIsUnsignedByte) {
- output = `floor(result * 255. + 0.5)`;
- }
- this.userCode = `
- ${getFlatIndexFrom3D(outputShape)}
-
- void main() {
- ivec3 coords = getOutputCoords();
-
- int flatIndex = getFlatIndex(coords);
- int offset = imod(flatIndex, 4);
-
- flatIndex = idiv(flatIndex, 4, 1.);
-
- int r = flatIndex / ${width};
- int c = imod(flatIndex, ${width});
- vec2 uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);
- vec4 values = ${glsl.texture2D}(A, uv);
-
- float result;
-
- if(offset == 0) {
- result = values[0];
- } else if(offset == 1) {
- result = values[1];
- } else if(offset == 2) {
- result = values[2];
- } else {
- result = values[3];
- }
-
- ${glsl.output} = vec4(${output}, 0., 0., 0.);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /*
- This is how the shader encodes a tensor with shape = [2, 3, 5]
- (indices are [batch, row, col]).
-
- 000|001 002|003 004|xxx 020|021 022|023 024|xxx
- ------- ------- ------- ------- ------- -------
- 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
-
- 100|101 102|103 104|xxx 120|121 122|123 124|xxx
- ------- ------- ------- ------- ------- -------
- 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
-
- Single texels contain only values from the same batch, and from adjacent rows
- and columns.
- */
- class EncodeMatrixPackedProgram {
- constructor(outputShape, texShape, inputIsUnsignedByte = false) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- const glsl = getGlslDifferences();
- const [height, width] = texShape;
- this.outputShape = outputShape;
- let mainLoop = '';
- let output = 'result';
- if (inputIsUnsignedByte) {
- output = 'floor(result * 255. + 0.5)';
- }
- for (let row = 0; row <= 1; row++) {
- for (let col = 0; col <= 1; col++) {
- const channel = row * 2 + col;
- mainLoop += `
- localCoords = coords;
- if(localCoords[2] + ${col} < ${outputShape[2]}) {
- localCoords[2] += ${col};
- if(localCoords[1] + ${row} < ${outputShape[1]}) {
- localCoords[1] += ${row};
-
- flatIndex = getFlatIndex(localCoords);
- offset = imod(flatIndex, 4);
-
- flatIndex = idiv(flatIndex, 4, 1.);
-
- r = flatIndex / ${width};
- c = imod(flatIndex, ${width});
- uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);
- values = ${glsl.texture2D}(A, uv);
-
- if(offset == 0) {
- result[${channel}] = values[0];
- } else if(offset == 1) {
- result[${channel}] = values[1];
- } else if(offset == 2) {
- result[${channel}] = values[2];
- } else {
- result[${channel}] = values[3];
- }
- }
- }
- `;
- }
- }
- this.userCode = `
- ${getFlatIndexFrom3D(outputShape)}
-
- void main() {
- ivec3 coords = getOutputCoords();
-
- vec4 result = vec4(0.);
- int flatIndex, r, c, offset;
- ivec3 localCoords;
- vec2 uv;
- vec4 values;
-
- ${mainLoop}
-
- ${glsl.output} = ${output};
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class FillProgram {
- constructor(shape, value) {
- this.outputShape = [];
- this.variableNames = ['x'];
- this.outputShape = shape;
- this.userCode = `
- uniform float value;
- void main() {
- // Input can be obtained from uniform value.
- setOutput(value);
- }
- `;
- }
- getCustomSetupFunc(value) {
- return (gpgpu, webGLProgram) => {
- if (this.valueLoc == null) {
- this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
- }
- gpgpu.gl.uniform1f(this.valueLoc, value);
- };
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class GatherProgram {
- constructor(aShape, indicesLength, axis) {
- this.variableNames = ['A', 'indices'];
- const outputShape = aShape.slice();
- outputShape[axis] = indicesLength;
- this.outputShape = outputShape;
- this.rank = outputShape.length;
- const dtype = getCoordsDataType(this.rank);
- const sourceCoords = getSourceCoords$1(aShape, axis);
- this.userCode = `
- void main() {
- ${dtype} resRC = getOutputCoords();
- setOutput(getA(${sourceCoords}));
- }
- `;
- }
- }
- function getSourceCoords$1(aShape, axis) {
- const rank = aShape.length;
- if (rank > 4) {
- throw Error(`Gather for rank ${rank} is not yet supported`);
- }
- if (rank === 1) {
- return `int(getIndices(resRC))`;
- }
- const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
- const sourceCoords = [];
- for (let i = 0; i < aShape.length; i++) {
- if (i === axis) {
- sourceCoords.push(`int(getIndices(${currentCoords[i]}))`);
- }
- else {
- sourceCoords.push(`${currentCoords[i]}`);
- }
- }
- return sourceCoords.join();
- }
-
- class GatherNDProgram {
- constructor(sliceDim, strides, shape) {
- this.sliceDim = sliceDim;
- this.strides = strides;
- this.variableNames = ['x', 'indices'];
- this.outputShape = shape;
- const stridesType = getCoordsDataType(strides.length);
- const dtype = getCoordsDataType(shape.length);
- const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
- this.userCode = `
- ${stridesType} strides = ${stridesType}(${this.strides});
- void main() {
- ${dtype} coords = getOutputCoords();
- int flattenIndex = 0;
- for (int j = 0; j < ${this.sliceDim}; j++) {
- int index = round(getIndices(coords[0], j));
- flattenIndex += index * ${strideString};
- }
- setOutput(getX(flattenIndex, coords[1]));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function createVertexShader$1(gl) {
- const glsl = getGlslDifferences();
- const vertexShaderSource = `${glsl.version}
- precision highp float;
- ${glsl.attribute} vec3 clipSpacePos;
- ${glsl.attribute} vec2 uv;
- ${glsl.varyingVs} vec2 resultUV;
-
- void main() {
- gl_Position = vec4(clipSpacePos, 1);
- resultUV = uv;
- }`;
- return createVertexShader(gl, vertexShaderSource);
- }
- function createVertexBuffer(gl) {
- // [x y z u v] * [upper-left, lower-left, upper-right, lower-right]
- const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
- return createStaticVertexBuffer(gl, vertexArray);
- }
- function createIndexBuffer(gl) {
- // OpenGL (and WebGL) have "CCW == front" winding
- const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
- return createStaticIndexBuffer(gl, triangleVertexIndices);
- }
- function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
- validateTextureSize(width, height);
- const texture = createTexture(gl);
- const tex2d = gl.TEXTURE_2D;
- callAndCheck(gl, () => gl.bindTexture(tex2d, texture));
- callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));
- callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));
- callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));
- callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));
- callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
- return texture;
- }
- function getInternalFormatForFloat32MatrixTexture(textureConfig) {
- return textureConfig.internalFormatFloat;
- }
- function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
- const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
- return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
- }
- function getInternalFormatForFloat16MatrixTexture(textureConfig) {
- return textureConfig.internalFormatHalfFloat;
- }
- function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
- const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
- return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
- }
- function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
- return textureConfig.downloadTextureFormat;
- }
- function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
- const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
- return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
- }
- function getInternalFormatForPackedMatrixTexture(textureConfig) {
- return textureConfig.internalFormatPackedFloat;
- }
- function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
- const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
- return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
- }
- function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
- return textureConfig.internalFormatPackedHalfFloat;
- }
- function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
- const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
- return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
- }
- function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
- const posOffset = 0; // x is the first buffer element
- const uvOffset = 3 * 4; // uv comes after [x y z]
- const stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float.
- callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
- const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
- return success &&
- bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
- }
- function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
- let dataForUpload, texelDataType, internalFormat;
- if (data instanceof Uint8Array) {
- dataForUpload = new Uint8Array(width * height * 4);
- texelDataType = gl.UNSIGNED_BYTE;
- internalFormat = gl.RGBA;
- }
- else {
- dataForUpload = new Float32Array(width * height * 4);
- texelDataType = gl.FLOAT;
- internalFormat = textureConfig.internalFormatPackedFloat;
- }
- dataForUpload.set(data);
- callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
- }
- function uploadPixelDataToTexture(gl, texture, pixels) {
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
- if (pixels.data instanceof Uint8Array) {
- callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
- }
- else {
- callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
- }
- callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
- }
- function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
- // Create and bind the buffer.
- const buffer = gl2.createBuffer();
- callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
- // Initialize the buffer to the size of the texture in bytes.
- const bytesPerFloat = 4;
- const valuesPerTexel = 4;
- const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
- callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));
- // Enqueue a command on the GPU command queue to copy of texture into the
- // buffer.
- callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
- callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
- return buffer;
- }
- function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
- const gl2 = gl;
- const downloadTarget = new Float32Array(size);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
- gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
- return downloadTarget;
- }
- function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
- const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
- const numChannels = 4;
- const downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
- callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget));
- // By wrapping the buffer in a Float32Array, we use native browser IEEE 754
- // decoding of the 4 bytes that back each 32 bit float.
- return new Float32Array(downloadTarget.buffer);
- }
- function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
- const gl2 = gl;
- const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
- gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
- return downloadTarget;
- }
- function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
- const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
- callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));
- return packedRGBA;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class GPGPUContext {
- constructor(gl) {
- this.outputTexture = null;
- this.program = null;
- this.disposed = false;
- this.vertexAttrsAreBound = false;
- this.itemsToPoll = [];
- const glVersion = env().getNumber('WEBGL_VERSION');
- if (gl != null) {
- this.gl = gl;
- setWebGLContext(glVersion, gl);
- }
- else {
- this.gl = getWebGLContext(glVersion);
- }
- // WebGL 2.0 enables texture floats without an extension.
- let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
- const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
- if (env().getNumber('WEBGL_VERSION') === 1) {
- const TEXTURE_FLOAT = 'OES_texture_float';
- const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
- this.textureFloatExtension =
- getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
- if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
- this.textureHalfFloatExtension =
- getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
- }
- else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
- throw new Error('GL context does not support half float textures, yet the ' +
- 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
- }
- this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
- if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
- this.colorBufferHalfFloatExtension =
- getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
- }
- else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
- throw new Error('GL context does not support color renderable half floats, yet ' +
- 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
- }
- }
- else {
- COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
- if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
- this.colorBufferFloatExtension =
- this.gl.getExtension(COLOR_BUFFER_FLOAT);
- }
- else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
- this.colorBufferHalfFloatExtension =
- this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
- }
- else {
- throw new Error('GL context does not support color renderable floats');
- }
- }
- this.vertexBuffer = createVertexBuffer(this.gl);
- this.indexBuffer = createIndexBuffer(this.gl);
- this.framebuffer = createFramebuffer(this.gl);
- this.textureConfig =
- getTextureConfig(this.gl, this.textureHalfFloatExtension);
- }
- get debug() {
- return env().getBool('DEBUG');
- }
- dispose() {
- if (this.disposed) {
- return;
- }
- if (this.program != null) {
- console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
- ' This is probably a resource leak, delete the program with ' +
- 'GPGPUContext.deleteProgram before disposing.');
- }
- if (this.outputTexture != null) {
- console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
- 'texture. This is probably a resource leak, delete the output ' +
- 'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
- 'disposing.');
- }
- const gl = this.gl;
- callAndCheck(gl, () => gl.finish());
- callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
- callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
- callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
- callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
- callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
- this.disposed = true;
- }
- createFloat32MatrixTexture(rows, columns) {
- this.throwIfDisposed();
- return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
- }
- createFloat16MatrixTexture(rows, columns) {
- this.throwIfDisposed();
- return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
- }
- createUnsignedBytesMatrixTexture(rows, columns) {
- this.throwIfDisposed();
- return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
- }
- uploadPixelDataToTexture(texture, pixels) {
- this.throwIfDisposed();
- uploadPixelDataToTexture(this.gl, texture, pixels);
- }
- uploadDenseMatrixToTexture(texture, width, height, data) {
- this.throwIfDisposed();
- uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
- }
- createFloat16PackedMatrixTexture(rows, columns) {
- this.throwIfDisposed();
- return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
- }
- createPackedMatrixTexture(rows, columns) {
- this.throwIfDisposed();
- return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
- }
- deleteMatrixTexture(texture) {
- this.throwIfDisposed();
- if (this.outputTexture === texture) {
- unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
- this.outputTexture = null;
- }
- callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
- }
- downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
- return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
- }
- downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
- return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
- }
- downloadFloat32MatrixFromBuffer(buffer, size) {
- return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
- }
- createBufferFromTexture(texture, rows, columns) {
- this.bindTextureToFrameBuffer(texture);
- const result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
- this.unbindTextureToFrameBuffer();
- return result;
- }
- createAndWaitForFence() {
- const fenceContext = this.createFence(this.gl);
- return this.pollFence(fenceContext);
- }
- createFence(gl) {
- let query;
- let isFencePassed;
- if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
- const gl2 = gl;
- const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
- gl.flush();
- isFencePassed = () => {
- const status = gl2.clientWaitSync(sync, 0, 0);
- return status === gl2.ALREADY_SIGNALED ||
- status === gl2.CONDITION_SATISFIED;
- };
- query = sync;
- }
- else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
- query = this.beginQuery();
- this.endQuery();
- isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
- }
- else {
- // If we have no way to fence, return true immediately. This will fire in
- // WebGL 1.0 when there is no disjoint query timer. In this case, because
- // the fence passes immediately, we'll immediately ask for a download of
- // the texture, which will cause the UI thread to hang.
- isFencePassed = () => true;
- }
- return { query, isFencePassed };
- }
- downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
- return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
- }
- createProgram(fragmentShaderSource) {
- this.throwIfDisposed();
- const gl = this.gl;
- const fragmentShader = createFragmentShader(gl, fragmentShaderSource);
- const vertexShader = createVertexShader$1(gl);
- const program = createProgram(gl);
- callAndCheck(gl, () => gl.attachShader(program, vertexShader));
- callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
- linkProgram(gl, program);
- if (this.debug) {
- validateProgram(gl, program);
- }
- if (!this.vertexAttrsAreBound) {
- this.setProgram(program);
- this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.program, this.vertexBuffer);
- }
- return program;
- }
- deleteProgram(program) {
- this.throwIfDisposed();
- if (program === this.program) {
- this.program = null;
- }
- if (program != null) {
- callAndCheck(this.gl, () => this.gl.deleteProgram(program));
- }
- }
- setProgram(program) {
- this.throwIfDisposed();
- this.program = program;
- if ((this.program != null) && this.debug) {
- validateProgram(this.gl, this.program);
- }
- callAndCheck(this.gl, () => this.gl.useProgram(program));
- }
- getUniformLocation(program, uniformName, shouldThrow = true) {
- this.throwIfDisposed();
- if (shouldThrow) {
- return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
- }
- else {
- return getProgramUniformLocation(this.gl, program, uniformName);
- }
- }
- getAttributeLocation(program, attribute) {
- this.throwIfDisposed();
- return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
- }
- getUniformLocationNoThrow(program, uniformName) {
- this.throwIfDisposed();
- return this.gl.getUniformLocation(program, uniformName);
- }
- setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
- this.throwIfDisposed();
- this.throwIfNoProgram();
- bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
- }
- setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
- this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
- }
- setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
- this.throwIfDisposed();
- const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
- this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
- }
- setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
- this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
- }
- setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
- throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
- }
- debugValidate() {
- if (this.program != null) {
- validateProgram(this.gl, this.program);
- }
- validateFramebuffer(this.gl);
- }
- executeProgram() {
- this.throwIfDisposed();
- this.throwIfNoProgram();
- const gl = this.gl;
- if (this.debug) {
- this.debugValidate();
- }
- callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
- }
- blockUntilAllProgramsCompleted() {
- this.throwIfDisposed();
- callAndCheck(this.gl, () => this.gl.finish());
- }
- getQueryTimerExtension() {
- if (this.disjointQueryTimerExtension == null) {
- this.disjointQueryTimerExtension =
- getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
- 'EXT_disjoint_timer_query_webgl2' :
- 'EXT_disjoint_timer_query');
- }
- return this.disjointQueryTimerExtension;
- }
- getQueryTimerExtensionWebGL2() {
- return this.getQueryTimerExtension();
- }
- getQueryTimerExtensionWebGL1() {
- return this.getQueryTimerExtension();
- }
- beginQuery() {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
- const gl2 = this.gl;
- const ext = this.getQueryTimerExtensionWebGL2();
- const query = gl2.createQuery();
- gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
- return query;
- }
- const ext = this.getQueryTimerExtensionWebGL1();
- const query = ext.createQueryEXT();
- ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
- return query;
- }
- endQuery() {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
- const gl2 = this.gl;
- const ext = this.getQueryTimerExtensionWebGL2();
- gl2.endQuery(ext.TIME_ELAPSED_EXT);
- return;
- }
- const ext = this.getQueryTimerExtensionWebGL1();
- ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
- }
- async waitForQueryAndGetTime(query) {
- await repeatedTry(() => this.disposed || // while testing contexts are created / disposed
- // in rapid succession, so without this check we
- // may poll for the query timer indefinitely
- this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
- return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
- }
- getQueryTime(query, queryTimerVersion) {
- if (queryTimerVersion === 0) {
- return null;
- }
- if (queryTimerVersion === 2) {
- const gl2 = this.gl;
- const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
- // Return milliseconds.
- return timeElapsedNanos / 1000000;
- }
- else {
- const ext = this.getQueryTimerExtensionWebGL1();
- const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
- // Return milliseconds.
- return timeElapsedNanos / 1000000;
- }
- }
- isQueryAvailable(query, queryTimerVersion) {
- if (queryTimerVersion === 0) {
- return true;
- }
- if (queryTimerVersion === 2) {
- const gl2 = this.gl;
- const ext = this.getQueryTimerExtensionWebGL2();
- const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
- if (this.disjoint == null) {
- this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
- }
- return available && !this.disjoint;
- }
- else {
- const ext = this.getQueryTimerExtensionWebGL1();
- const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
- if (this.disjoint == null) {
- this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
- }
- return available && !this.disjoint;
- }
- }
- pollFence(fenceContext) {
- return new Promise(resolve => {
- this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
- });
- }
- pollItems() {
- // Find the last query that has finished.
- const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
- for (let i = 0; i <= index; ++i) {
- const { resolveFn } = this.itemsToPoll[i];
- resolveFn();
- }
- this.itemsToPoll = this.itemsToPoll.slice(index + 1);
- }
- addItemToPoll(isDoneFn, resolveFn) {
- this.itemsToPoll.push({ isDoneFn, resolveFn });
- if (this.itemsToPoll.length > 1) {
- // We already have a running loop that polls.
- return;
- }
- // Start a new loop that polls.
- repeatedTry(() => {
- this.pollItems();
- // End the loop if no more items to poll.
- return this.itemsToPoll.length === 0;
- });
- }
- bindTextureToFrameBuffer(texture) {
- this.throwIfDisposed();
- bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(this.gl);
- }
- }
- unbindTextureToFrameBuffer() {
- if (this.outputTexture != null) {
- bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(this.gl);
- }
- }
- else {
- unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
- }
- }
- downloadMatrixDriver(texture, downloadAndDecode) {
- this.bindTextureToFrameBuffer(texture);
- const result = downloadAndDecode();
- this.unbindTextureToFrameBuffer();
- return result;
- }
- setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
- this.throwIfDisposed();
- const gl = this.gl;
- bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(gl);
- }
- this.outputTexture = outputMatrixTextureMaybePacked;
- callAndCheck(gl, () => gl.viewport(0, 0, width, height));
- callAndCheck(gl, () => gl.scissor(0, 0, width, height));
- }
- setOutputMatrixWriteRegionDriver(x, y, width, height) {
- this.throwIfDisposed();
- callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
- }
- throwIfDisposed() {
- if (this.disposed) {
- throw new Error('Attempted to use disposed GPGPUContext.');
- }
- }
- throwIfNoProgram() {
- if (this.program == null) {
- throw new Error('No GPU program is currently set.');
- }
- }
- }
- /**
- * Finds the index of the last true element using linear search.
- * Note: We can't do binary search because Chrome expects us to explicitly
- * test all fences before download:
- * https://github.com/tensorflow/tfjs/issues/1145
- */
- function linearSearchLastTrue(arr) {
- let i = 0;
- for (; i < arr.length; ++i) {
- const isDone = arr[i]();
- if (!isDone) {
- break;
- }
- }
- return i - 1;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function compileProgram(gpgpu, program, inputs, output) {
- const userCode = program.userCode;
- const inputInfos = inputs.map((input, i) => {
- const shapeInfo = {
- logicalShape: input.shape,
- texShape: input.isUniform ? null : input.texData.texShape,
- isUniform: input.isUniform,
- isPacked: input.isUniform ? false : input.texData.isPacked,
- flatOffset: null
- };
- if (input.texData != null && input.texData.slice != null &&
- input.texData.slice.flatOffset > 0) {
- shapeInfo.flatOffset = input.texData.slice.flatOffset;
- }
- return { name: program.variableNames[i], shapeInfo };
- });
- const inShapeInfos = inputInfos.map(x => x.shapeInfo);
- const outShapeInfo = {
- logicalShape: output.shape,
- texShape: output.texData.texShape,
- isUniform: false,
- isPacked: output.texData.isPacked,
- flatOffset: null
- };
- const source = makeShader(inputInfos, outShapeInfo, userCode, program.packedInputs);
- const webGLProgram = gpgpu.createProgram(source);
- // Add special uniforms (NAN, INFINITY)
- let infLoc = null;
- const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
- if (env().getNumber('WEBGL_VERSION') === 1) {
- infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
- }
- // Add user-defined uniforms
- const uniformLocations = {};
- for (let i = 0; i < program.variableNames.length; i++) {
- const varName = program.variableNames[i];
- const shouldThrow = false;
- uniformLocations[varName] =
- gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);
- uniformLocations[`offset${varName}`] =
- gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow);
- }
- return {
- program,
- source,
- webGLProgram,
- uniformLocations,
- inShapeInfos,
- outShapeInfo,
- infLoc,
- nanLoc,
- };
- }
- function validateBinaryAndProgram(shapeInfos, inputs) {
- if (shapeInfos.length !== inputs.length) {
- throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` +
- `was executed with ${inputs.length} inputs`);
- }
- shapeInfos.forEach((s, i) => {
- const shapeA = s.logicalShape;
- const input = inputs[i];
- const shapeB = input.shape;
- if (!arraysEqual(shapeA, shapeB)) {
- throw Error(`Binary was compiled with different shapes than ` +
- `the current args. Shapes ${shapeA} and ${shapeB} must match`);
- }
- // The input is uploaded as uniform.
- if (s.isUniform && input.isUniform) {
- return;
- }
- const texShapeA = s.texShape;
- const texShapeB = input.isUniform ? null : input.texData.texShape;
- if (!arraysEqual(texShapeA, texShapeB)) {
- throw Error(`Binary was compiled with different texture shapes than the` +
- ` current args. Shape ${texShapeA} and ${texShapeB} must match`);
- }
- });
- }
- function runProgram(gpgpu, binary, inputs, output, customSetup) {
- validateBinaryAndProgram(binary.inShapeInfos, inputs);
- validateBinaryAndProgram([binary.outShapeInfo], [output]);
- const outTex = output.texData.texture;
- const outTexShape = output.texData.texShape;
- if (output.texData.isPacked) {
- gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
- }
- else {
- gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
- }
- gpgpu.setProgram(binary.webGLProgram);
- // Set special uniforms (NAN, INFINITY)
- if (env().getNumber('WEBGL_VERSION') === 1) {
- if (binary.infLoc !== null) {
- gpgpu.gl.uniform1f(binary.infLoc, Infinity);
- }
- }
- if (binary.nanLoc !== null) {
- gpgpu.gl.uniform1f(binary.nanLoc, NaN);
- }
- // Set user-defined inputs
- inputs.forEach((input, i) => {
- const varName = binary.program.variableNames[i];
- const varLoc = binary.uniformLocations[varName];
- const varOffsetLoc = binary.uniformLocations[`offset${varName}`];
- if (varLoc == null) {
- // The compiler inferred that this variable is not used in this shader.
- return;
- }
- if (input.isUniform) {
- // Upload the values of the tensor as uniform.
- if (sizeFromShape(input.shape) < 2) {
- gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
- }
- else {
- let vals = input.uniformValues;
- if (!(vals instanceof Float32Array)) {
- vals = new Float32Array(vals);
- }
- gpgpu.gl.uniform1fv(varLoc, vals);
- }
- return;
- }
- // If the input was sliced, upload the flat offset index.
- if (input.texData.slice != null && varOffsetLoc != null) {
- gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
- }
- gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);
- });
- if (customSetup != null) {
- customSetup(gpgpu, binary.webGLProgram);
- }
- gpgpu.executeProgram();
- }
- function makeShaderKey(program, inputs, output) {
- let keyInputs = '';
- inputs.concat(output).forEach(x => {
- const hasOffset = x.texData != null && x.texData.slice != null &&
- x.texData.slice.flatOffset > 0;
- const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
- keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
- });
- const keyUserCode = program.userCode;
- let key = program.constructor.name;
- // Fast string concat. See https://jsperf.com/string-concatenation/14.
- key += '_' + keyInputs + '_' + keyUserCode;
- return key;
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class Im2ColPackedProgram {
- constructor(outputShape, inputShape, convInfo) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- const { filterWidth, inChannels, strideWidth, strideHeight, padInfo, outWidth, dilationWidth, dilationHeight, dataFormat } = convInfo;
- const { left, top } = padInfo;
- const itemsPerBlockRow = inChannels * filterWidth;
- const glsl = getGlslDifferences();
- const isChannelsLast = dataFormat === 'channelsLast';
- const rowDim = isChannelsLast ? 0 : 1;
- const colDim = isChannelsLast ? 1 : 2;
- let unrolled = ``;
- for (let row = 0; row <= 1; row++) {
- for (let col = 0; col <= 1; col++) {
- unrolled += `
- blockIndex = rc.y + ${col};
- pos = rc.x + ${row};
-
- if(blockIndex < ${outputShape[1]} && pos < ${outputShape[0]}) {
- offsetY = int(blockIndex / (${outWidth})) * ${strideHeight} - ${top};
- d0 = offsetY + ${dilationHeight} * (pos / ${itemsPerBlockRow});
-
- if(d0 < ${inputShape[rowDim]} && d0 >= 0) {
-
- offsetX = int(mod(float(blockIndex), ${outWidth}.) * ${strideWidth}. - ${left}.);
- d1 = offsetX + ${dilationWidth} * (int(mod(float(pos), ${itemsPerBlockRow}.) / ${inChannels}.));
-
- if(d1 < ${inputShape[colDim]} && d1 >= 0) {
-
- ch = int(mod(float(pos), ${inChannels}.));
-
- if (${isChannelsLast}) {
- innerDims = vec2(d1, ch);
- result[${row * 2 + col}] = getChannel(
- getA(d0, int(innerDims.x),
- int(innerDims.y)), innerDims);
- } else {
- innerDims = vec2(d0, d1);
- result[${row * 2 + col}] = getChannel(
- getA(ch, int(innerDims.x),
- int(innerDims.y)), innerDims);
- }
- }
- }
- }
- `;
- }
- }
- this.userCode = `
- void main() {
- ivec2 rc = getOutputCoords();
-
- vec4 result = vec4(0);
-
- int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
- vec2 innerDims;
-
- ${unrolled}
-
- ${glsl.output} = result;
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class LRNProgram {
- constructor(xShape, radius, bias, alpha, beta) {
- this.variableNames = ['x'];
- this.outputShape = [];
- const rad = radius;
- const maxD = xShape[3] - 1;
- this.outputShape = xShape;
- // optimize pow(bias + alpha * sum, -beta)
- // src: https://github.com/tensorflow/tensorflow/..
- // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
- // tensorflow/core/kernels/mkl_lrn_op.cc#L320
- let powOperator;
- const basis = `float(${bias}) + float(${alpha}) * sum`;
- if (beta === 0.5) {
- powOperator = `inversesqrt(${basis})`;
- }
- else if (beta === 1.0) {
- powOperator = `1.0/(${basis})`;
- }
- else {
- powOperator = `exp(log(${basis}) * float(-${beta}));`;
- }
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int r = coords[1];
- int c = coords[2];
- int d = coords[3];
- float x = getX(b, r, c, d);
- float sum = 0.0;
- for (int j = -${rad}; j <= ${rad}; j++) {
- int idx = d + j;
- if (idx >= 0 && idx <= ${maxD}) {
- float z = getX(b, r, c, idx);
- sum += z * z;
- }
- }
- float val = x * ${powOperator};
- setOutput(val);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class LRNGradProgram {
- constructor(inputShape, depthRadius, bias, alpha, beta) {
- this.variableNames = ['inputImage', 'outputImage', 'dy'];
- this.outputShape = [];
- this.outputShape = inputShape;
- this.depth = inputShape[3];
- this.depthRadius = depthRadius;
- this.bias = bias;
- this.alpha = alpha;
- this.beta = beta;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int r = coords[1];
- int c = coords[2];
-
- float result = 0.0;
- for (int d = 0; d < ${this.depth}; ++d) {
- int depthBegin = int(max(0.0, float(d - ${depthRadius})));
- int depthEnd = int(min(float(${this.depth}),
- float(d + ${depthRadius} + 1)));
-
- const int MIN_DEPTH_BEGIN = 0;
- const int MAX_DEPTH_END = ${this.depth};
-
- float norm = 0.0;
- for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
- if (k < depthBegin){
- continue;
- }
- else if (k >= depthBegin && k < depthEnd) {
- norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
- }
- else {
- break;
- }
- }
-
- norm = float(${alpha}) * norm + float(${bias});
-
- for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
- if (k < depthBegin){
- continue;
- }
- else if (k >= depthBegin && k < depthEnd){
- float dyi = -2.0 * float(${alpha})
- * float(${beta})
- * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)
- / norm;
- if (k == d) {
- dyi += pow(norm, -1.0 * ${beta});
- }
- if (k == coords[3]) {
- dyi *= getDy(b, r, c, d);
- result += dyi;
- }
- }
- else {
- break;
- }
- }
- }
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class LRNPackedProgram {
- constructor(xShape, radius, bias, alpha, beta) {
- this.variableNames = ['x'];
- this.outputShape = [];
- this.packedInputs = true;
- this.packedOutput = true;
- const rad = radius;
- const maxD = xShape[3] - 1;
- this.outputShape = xShape;
- // optimize pow(bias + alpha * sum, -beta)
- // src: https://github.com/tensorflow/tensorflow/..
- // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
- // tensorflow/core/kernels/mkl_lrn_op.cc#L320
- let powOperator;
- const basis = `float(${bias}) + float(${alpha}) * sum`;
- if (beta === 0.5) {
- powOperator = `inversesqrt(${basis})`;
- }
- else if (beta === 1.0) {
- powOperator = `1.0/(${basis})`;
- }
- else {
- powOperator = `exp(log(${basis}) * float(-${beta}));`;
- }
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords.x;
- int r = coords.y;
- int c = coords.z;
- int d = coords.w;
-
- bool hasNextCol = d < ${this.outputShape[3]};
- bool hasNextRow = c < ${this.outputShape[2]};
-
- vec4 sum = vec4(0.);
- vec4 xFragAtOutputCoords = getX(b, r, c, d);
-
- vec4 xAtOutputCoords = vec4(
- getChannel(xFragAtOutputCoords, vec2(c, d)),
- hasNextCol ?
- getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
- hasNextRow ?
- getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
- (hasNextRow && hasNextCol) ?
- getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
- );
-
- int firstChannel = d - ${rad};
- vec2 cache = vec2(0.);
- if(firstChannel >= 0){
- vec4 firstChannelFrag = getX(b, r, c, firstChannel);
- cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
- if(hasNextRow){
- cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
- }
- }
-
- ivec2 depth = ivec2(d, d + 1);
- for (int j = - ${rad}; j <= ${rad}; j++) {
- ivec2 idx = depth + j;
- bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
- bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));
-
- bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
- bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;
-
- if(depthInRange || depthPlusOneInRange){
- vec4 z = vec4(0.);
- vec4 xFragAtCurrentDepth;
- z.xz = cache.xy;
- if(depthPlusOneInRange && hasNextCol){
- xFragAtCurrentDepth = idx.y != d ?
- getX(b, r, c, idx.y) : xFragAtOutputCoords;
- z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
- if(hasNextRow){
- z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
- }
- }
- cache.xy = z.yw;
- sum += z * z;
- }
- }
- vec4 result = xAtOutputCoords * ${powOperator};
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class MaxPool2DBackpropProgram {
- constructor(convInfo) {
- this.variableNames = ['dy', 'maxPos'];
- this.outputShape = convInfo.inShape;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
- this.userCode = `
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
-
- ivec2 dyRCCorner = coords.yz - pads;
- int dyRCorner = dyRCCorner.x;
- int dyCCorner = dyRCCorner.y;
-
- // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- float dyValue = getDy(b, idyR, idyC, d);
- int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));
-
- // Get the current value, check it against the value from the
- // position matrix.
- int curPosValue = wR * ${effectiveFilterWidth} + wC;
- float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
-
- dotProd += dyValue * mask;
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
- class MaxPool3DBackpropProgram {
- constructor(convInfo) {
- this.variableNames = ['dy', 'maxPos'];
- this.outputShape = convInfo.inShape;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
- this.userCode = `
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int ch = coords.u;
-
- ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
- int dyDCorner = dyCorner.x;
- int dyRCorner = dyCorner.y;
- int dyCCorner = dyCorner.z;
-
- // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get
- // dx(xD, xR, xC, ch).
- // ? = to be determined. : = across all values in that axis.
- float dotProd = 0.0;
-
- for (int wD = 0; wD < ${effectiveFilterDepth};
- wD += ${dilationDepth}) {
- float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
-
- if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
- continue;
- }
- int idyD = int(dyD);
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
-
- if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
- fract(dyR) > 0.0) {
- continue;
- }
- int idyR = int(dyR);
-
- for (int wC = 0; wC < ${effectiveFilterWidth};
- wC += ${dilationWidth}) {
- float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
-
- if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
- fract(dyC) > 0.0) {
- continue;
- }
- int idyC = int(dyC);
-
- float dyValue = getDy(batch, idyD, idyR, idyC, ch);
- int maxPosValue = ${lastIndex} -
- int(getMaxPos(batch, idyD, idyR, idyC, ch));
-
- // Get the current value, check it against the value from the
- // position matrix.
- int curPosValue =
- wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
- wR * ${effectiveFilterWidth} + wC;
- float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
-
- dotProd += dyValue * mask;
- }
- }
- }
- setOutput(dotProd);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class MatMulPackedProgram {
- constructor(aShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false) {
- this.variableNames = ['matrixA', 'matrixB'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- const sharedDim = transposeA ? aShape[1] : aShape[2];
- const sharedDimensionPacked = Math.ceil(sharedDim / 2);
- const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
- const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
- const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
- const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
- let activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = `vec4 activation(vec4 a) {
- vec4 b = getPreluActivationWeightsAtOutCoords();
- ${activation}
- }`;
- }
- else {
- activationSnippet = `vec4 activation(vec4 x) {
- ${activation}
- }`;
- }
- applyActivationSnippet = `result = activation(result);`;
- }
- const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = `
- ${activationSnippet}
-
- const float sharedDimension = ${sharedDimensionPacked}.0;
-
- vec4 dot2x2ARowBCol(ivec3 rc) {
- vec4 result = vec4(0);
- for (int i = 0; i < ${sharedDimensionPacked}; i++) {
- vec4 a = getMatrixA(rc.x, ${aSample});
- vec4 b = getMatrixB(rc.x, ${bSample});
-
- // These swizzled products need to be separately added.
- // See: https://github.com/tensorflow/tfjs/issues/1735
- result += (${aSwizzle[0]} * ${bSwizzle[0]});
- result += (${aSwizzle[1]} * ${bSwizzle[1]});
- }
- return result;
- }
-
- void main() {
- ivec3 rc = getOutputCoords();
- vec4 result = dot2x2ARowBCol(rc);
-
- ${addBiasSnippet}
-
- ${applyActivationSnippet}
-
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class MultinomialProgram {
- constructor(batchSize, numOutcomes, numSamples) {
- this.variableNames = ['probs'];
- this.outputShape = [batchSize, numSamples];
- this.userCode = `
- uniform float seed;
-
- void main() {
- ivec2 coords = getOutputCoords();
- int batch = coords[0];
-
- float r = random(seed);
- float cdf = 0.0;
-
- for (int i = 0; i < ${numOutcomes - 1}; i++) {
- cdf += getProbs(batch, i);
-
- if (r < cdf) {
- setOutput(float(i));
- return;
- }
- }
-
- // If no other event happened, last event happened.
- setOutput(float(${numOutcomes - 1}));
- }
- `;
- }
- getCustomSetupFunc(seed) {
- return (gpgpu, webGLProgram) => {
- if (this.seedLoc == null) {
- this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');
- }
- gpgpu.gl.uniform1f(this.seedLoc, seed);
- };
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class OneHotProgram {
- constructor(numIndices, depth, onValue, offValue) {
- this.variableNames = ['indices'];
- this.outputShape = [numIndices, depth];
- this.userCode = `
- void main() {
- ivec2 coords = getOutputCoords();
- int index = round(getIndices(coords.x));
- setOutput(mix(float(${offValue}), float(${onValue}),
- float(index == coords.y)));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class PackProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- // Only input / output 3D tensors.
- this.outputShape = outputShape;
- const rank = outputShape.length;
- if (rank === 0) {
- this.userCode = `
- void main() {
- setOutput(vec4(getA(), 0., 0., 0.));
- }
- `;
- }
- else {
- const channels = getChannels('rc', rank);
- const dtype = getCoordsDataType(rank);
- const outOfBoundsCondition = getOutOfBoundsCondition(rank, outputShape, channels);
- const setup = getSetup(rank, outputShape[outputShape.length - 1], outputShape[outputShape.length - 2], channels);
- const output = getOutput(outputShape, channels);
- this.userCode = `
- void main() {
- ${dtype} rc = getOutputCoords();
-
- if(${outOfBoundsCondition}) {
- setOutput(vec4(0));
- } else {
- ${setup}
-
- setOutput(vec4(${output}));
- }
- }
- `;
- }
- }
- }
- function getSourceCoordsArr(rank, dims) {
- const coords = [];
- for (let row = 0; row <= 1; row++) {
- for (let col = 0; col <= 1; col++) {
- let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
- for (let d = 2; d < rank; d++) {
- coord = `${dims[dims.length - 1 - d]},` + coord;
- }
- coords.push(coord);
- }
- }
- return coords;
- }
- function getOutOfBoundsCondition(rank, shape, dims) {
- if (rank === 1) {
- return `rc > ${shape[0]}`;
- }
- let cond = '';
- for (let i = rank - 2; i < rank; i++) {
- cond += `${dims[i]} >= ${shape[i]}`;
- if (i < rank - 1) {
- cond += '||';
- }
- }
- return cond;
- }
- function getSetup(rank, cols, rows, dims) {
- if (rank === 1) {
- return '';
- }
- const innerDims = dims.slice(-2);
- return `
- int r = ${innerDims[0]};
- int c = ${innerDims[1]};
- int rp1 = r + 1;
- int cp1 = c + 1;
-
- bool cEdge = cp1 >= ${cols};
- bool rEdge = rp1 >= ${rows};
- `;
- }
- function getOutput(shape, dims) {
- const rank = shape.length;
- const sourceCoords = getSourceCoordsArr(rank, dims);
- if (rank === 1) {
- return `getA(rc),
- rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),
- 0, 0`;
- }
- return `getA(${sourceCoords[0]}),
- cEdge ? 0. : getA(${sourceCoords[1]}),
- rEdge ? 0. : getA(${sourceCoords[2]}),
- rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class PadProgram {
- constructor(xShape, paddings, constantValue) {
- this.variableNames = ['x'];
- this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
- const rank = xShape.length;
- const type = getCoordsDataType(rank);
- const start = paddings.map(p => p[0]).join(',');
- const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
- const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
- if (rank === 1) {
- this.userCode = `
- int start = ${start};
- int end = ${end};
-
- void main() {
- int outC = getOutputCoords();
- if (outC < start || outC >= end) {
- setOutput(float(${constantValue}));
- } else {
- setOutput(getX(outC - start));
- }
- }
- `;
- return;
- }
- this.userCode = `
- ${type} start = ${type}(${start});
- ${type} end = ${type}(${end});
-
- void main() {
- ${type} outC = getOutputCoords();
- if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
- setOutput(float(${constantValue}));
- } else {
- ${type} coords = outC - start;
- setOutput(getX(${unpackedCoords}));
- }
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class PadPackedProgram {
- constructor(xShape, paddings, constantValue) {
- this.variableNames = ['x'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
- const rank = xShape.length;
- const dtype = getCoordsDataType(rank);
- const start = paddings.map(p => p[0]).join(',');
- const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
- const coords = getChannels('rc', rank);
- const source = getChannels('source', rank);
- const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
- const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
- const componentSetup = [
- `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;
- if(${cLimit}) {
- `,
- rank === 1 ? '' : `}
- rc = outputLoc;
- ${coords[rank - 2]} += 1;
- if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
- rank === 1 ? '' : ` ${coords[rank - 1]} += 1;
- if(${cLimit}) {`
- ];
- const paddingArea = rank === 1 ?
- 'rc < start || rc >= end' :
- 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
- let mainLoop = '';
- for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
- mainLoop += `
- ${componentSetup[i]}
- if (${paddingArea}) {
- result[${i}] = float(${constantValue});
- } else {
- ${dtype} source = rc - start;
- result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
- }
- `;
- }
- mainLoop += (rank === 1 ? `} ` : `}}`);
- this.userCode = `
- const ${dtype} start = ${dtype}(${start});
- const ${dtype} end = ${dtype}(${end});
-
- void main() {
- ${dtype} outputLoc = getOutputCoords();
- vec4 result = vec4(0.);
- ${mainLoop}
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class Pool2DProgram {
- constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
- this.variableNames = ['x'];
- if (poolType === 'avg' && computePositions) {
- throw new Error('Cannot compute positions for average pool.');
- }
- const filterWidth = convInfo.filterWidth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- this.outputShape = convInfo.outShape;
- const isAvgPool = poolType === 'avg';
- const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
- const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
- let initializationValue = '0.0';
- if (!isAvgPool) {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- }
- if (computePositions) {
- const compareOp = '>=';
- this.userCode = `
- const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords[0];
- int d = coords[3];
-
- ivec2 xRCCorner = coords.yz * strides - pads;
- int xRCorner = xRCCorner.x;
- int xCCorner = xRCCorner.y;
-
- // max/min x(?, ?, d) to get y(yR, yC, d).
- // ? = to be determined
- float minMaxValue = 0.0;
- float minMaxValueFound = 0.0;
- int minMaxPosition = 0;
- float avgValue = 0.0;
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- int xR = xRCorner + wR;
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${effectiveFilterWidth};
- wC += ${dilationWidth}) {
- int xC = xCCorner + wC;
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- float value = getX(batch, xR, xC, d);
-
- // If a min / max value has already been found, use it. If not,
- // use the current value.
- float currMinMaxValue = mix(
- value, minMaxValue, minMaxValueFound);
- if (value ${compareOp} currMinMaxValue) {
- minMaxValue = value;
- minMaxValueFound = 1.0;
- minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
- flattenPositionStr) :
- `wR * ${effectiveFilterWidth} + wC`};
- }
- }
- }
- setOutput(float(minMaxPosition));
- }
- `;
- return;
- }
- const compareOp = 'max';
- let returnValue = `${poolType}(${poolType}(${poolType}(` +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (poolType === 'avg') {
- returnValue = `avgValue / count`;
- }
- const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
- const filterWidthVec4Remainder = filterWidth % 4;
- const updateSnippet = `
- if (${isAvgPool}) {
- avgValue += dot(values, ones);
- } else {
- minMaxValue = ${compareOp}(values, minMaxValue);
- }
- `;
- this.userCode = `
- const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
- const ivec2 pads = ivec2(${padTop}, ${padLeft});
- const float initializationValue = ${initializationValue};
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
-
- float count = 0.0;
-
- float getValue(int batch, int xR, int xC, int d) {
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- return initializationValue;
- }
- count += 1.0;
- return getX(batch, xR, xC, d);
- }
-
- void main() {
- ivec4 coords = getOutputCoords();
- int batch = coords[0];
- int d = coords[3];
-
- ivec2 xRCCorner = coords.yz * strides - pads;
- int xRCorner = xRCCorner.x;
- int xCCorner = xRCCorner.y;
-
- // max/min x(?, ?, d) to get y(yR, yC, d).
- // ? = to be determined
- vec4 minMaxValue = vec4(${initializationValue});
- float avgValue = 0.0;
- count = 0.0;
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- int xR = xRCorner + wR;
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
- int xC = xCCorner + wC * ${dilationWidth};
-
- vec4 values = vec4(
- getValue(batch, xR, xC, d),
- getValue(batch, xR, xC + ${dilationWidth}, d),
- getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
- getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
- );
-
- ${updateSnippet}
- }
-
- int xC = xCCorner + ${filterWidthNearestVec4};
- if (${filterWidthVec4Remainder === 1}) {
- vec4 values = vec4(
- getValue(batch, xR, xC, d),
- initializationValue,
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${filterWidthVec4Remainder === 2}) {
- vec4 values = vec4(
- getValue(batch, xR, xC, d),
- getValue(batch, xR, xC + ${dilationWidth}, d),
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${filterWidthVec4Remainder === 3}) {
- vec4 values = vec4(
- getValue(batch, xR, xC, d),
- getValue(batch, xR, xC + ${dilationWidth}, d),
- getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
- initializationValue
- );
-
- ${updateSnippet}
- }
- }
- setOutput(${returnValue});
- }
- `;
- }
- }
- class Pool3DProgram {
- constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
- this.variableNames = ['x'];
- if (poolType === 'avg' && computePositions) {
- throw new Error('Cannot compute positions for average pool.');
- }
- const filterWidth = convInfo.filterWidth;
- const strideDepth = convInfo.strideDepth;
- const strideHeight = convInfo.strideHeight;
- const strideWidth = convInfo.strideWidth;
- const dilationDepth = convInfo.dilationDepth;
- const dilationHeight = convInfo.dilationHeight;
- const dilationWidth = convInfo.dilationWidth;
- const effectiveFilterDepth = convInfo.effectiveFilterDepth;
- const effectiveFilterHeight = convInfo.effectiveFilterHeight;
- const effectiveFilterWidth = convInfo.effectiveFilterWidth;
- const padFront = convInfo.padInfo.front;
- const padTop = convInfo.padInfo.top;
- const padLeft = convInfo.padInfo.left;
- this.outputShape = convInfo.outShape;
- const isAvgPool = poolType === 'avg';
- let initializationValue = '0.0';
- if (!isAvgPool) {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- }
- if (computePositions) {
- const compareOp = '>=';
- this.userCode = `
- const ivec3 strides =
- ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int ch = coords.u;
-
- ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
- int xDCorner = xCorner.x;
- int xRCorner = xCorner.y;
- int xCCorner = xCorner.z;
-
- // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
- // ? = to be determined
- float minMaxValue = 0.0;
- float minMaxValueFound = 0.0;
- int minMaxPosition = 0;
-
- for (int wD = 0; wD < ${effectiveFilterDepth};
- wD += ${dilationDepth}) {
- int xD = xDCorner + wD;
-
- if (xD < 0 || xD >= ${convInfo.inDepth}) {
- continue;
- }
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- int xR = xRCorner + wR;
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${effectiveFilterWidth};
- wC += ${dilationWidth}) {
- int xC = xCCorner + wC;
-
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- continue;
- }
-
- float value = getX(batch, xD, xR, xC, ch);
-
- // If a min / max value has already been found, use it. If not,
- // use the current value.
- float currMinMaxValue = mix(
- value, minMaxValue, minMaxValueFound);
- if (value ${compareOp} currMinMaxValue) {
- minMaxValue = value;
- minMaxValueFound = 1.0;
- minMaxPosition = ${flattenPositions ?
- (includeBatchInIndex ?
- `(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` :
- `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :
- `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
- wR * ${effectiveFilterWidth} + wC`};
- }
- }
- }
- }
- setOutput(float(minMaxPosition));
- }
- `;
- return;
- }
- const compareOp = 'max';
- let returnValue = `${poolType}(${poolType}(${poolType}(` +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (poolType === 'avg') {
- returnValue = `avgValue / count`;
- }
- const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
- const filterWidthVec4Remainder = filterWidth % 4;
- const updateSnippet = `
- if (${isAvgPool}) {
- avgValue += dot(values, ones);
- } else {
- minMaxValue = ${compareOp}(values, minMaxValue);
- }
- `;
- this.userCode = `
- const ivec3 strides =
- ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
- const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
- const float initializationValue = ${initializationValue};
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
-
- float count = 0.0;
-
- float getValue(int batch, int xD, int xR, int xC, int ch) {
- if (xC < 0 || xC >= ${convInfo.inWidth}) {
- return initializationValue;
- }
- count += 1.0;
- return getX(batch, xD, xR, xC, ch);
- }
-
- void main() {
- ivec5 coords = getOutputCoords();
- int batch = coords.x;
- int ch = coords.u;
-
- ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
- int xDCorner = xCorner.x;
- int xRCorner = xCorner.y;
- int xCCorner = xCorner.z;
-
- // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).
- // ? = to be determined
- vec4 minMaxValue = vec4(${initializationValue});
- float avgValue = 0.0;
- count = 0.0;
-
- for (int wD = 0; wD < ${effectiveFilterDepth};
- wD += ${dilationDepth}) {
- int xD = xDCorner + wD;
-
- if (xD < 0 || xD >= ${convInfo.inDepth}) {
- continue;
- }
-
- for (int wR = 0; wR < ${effectiveFilterHeight};
- wR += ${dilationHeight}) {
- int xR = xRCorner + wR;
-
- if (xR < 0 || xR >= ${convInfo.inHeight}) {
- continue;
- }
-
- for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
- int xC = xCCorner + wC * ${dilationWidth};
-
- vec4 values = vec4(
- getValue(batch, xD, xR, xC, ch),
- getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
- getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
- getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
- );
-
- ${updateSnippet}
- }
-
- int xC = xCCorner + ${filterWidthNearestVec4};
- if (${filterWidthVec4Remainder === 1}) {
- vec4 values = vec4(
- getValue(batch, xD, xR, xC, ch),
- initializationValue,
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${filterWidthVec4Remainder === 2}) {
- vec4 values = vec4(
- getValue(batch, xD, xR, xC, ch),
- getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${filterWidthVec4Remainder === 3}) {
- vec4 values = vec4(
- getValue(batch, xD, xR, xC, ch),
- getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
- getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
- initializationValue
- );
-
- ${updateSnippet}
- }
- }
- setOutput(${returnValue});
- }
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ReduceProgram {
- constructor(reduceInfo, reduceType) {
- this.variableNames = ['x'];
- const { windowSize, batchSize, inSize, outSize } = reduceInfo;
- this.outputShape = [batchSize, outSize];
- let initializationValue = '0.0';
- let compareOp = ``;
- if (reduceType === 'prod') {
- initializationValue = '1.0';
- }
- else if (reduceType === 'min') {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '1.0 / 1e-20';
- compareOp = `min`;
- }
- else if (reduceType === 'max') {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- compareOp = `max`;
- }
- let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (reduceType === 'sum') {
- returnValue = `sumValue`;
- }
- else if (reduceType === 'prod') {
- returnValue = `prodValue`;
- }
- else if (reduceType === 'all') {
- returnValue = `allValue`;
- }
- else if (reduceType === 'any') {
- returnValue = `anyValue`;
- }
- const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
- const windowSizeVec4Remainder = windowSize % 4;
- let updateSnippet = `
- if (${reduceType === 'sum'}) {
- sumValue += dot(values, ones);
- } else if (${reduceType === 'prod'}) {
- vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
- prodValue *= tmp[0] * tmp[1];
- } else {
- minMaxValue = ${compareOp}(values, minMaxValue);
- }
- `;
- let vecType = `vec4`;
- if (reduceType === 'all') {
- initializationValue = '1.0';
- updateSnippet = `
- bool reducedAllValue = all(values);
- float floatedReducedAllValue = float(reducedAllValue);
- allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
- `;
- vecType = `bvec4`;
- }
- else if (reduceType === 'any') {
- initializationValue = '0.0';
- updateSnippet = `
- bool reducedAnyValue = any(values);
- float floatedReducedAnyValue = float(reducedAnyValue);
- anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
- `;
- vecType = `bvec4`;
- }
- let checkOutOfBounds = '';
- if (inSize % windowSize > 0) {
- checkOutOfBounds = `
- if (inIdx < 0 || inIdx >= ${inSize}) {
- return initializationValue;
- }
- `;
- }
- this.userCode = `
- const float initializationValue = ${initializationValue};
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
-
- float getValue(int batch, int inIdx) {
- ${checkOutOfBounds}
- return getX(batch, inIdx);
- }
-
- void main() {
- ivec2 coords = getOutputCoords();
- int batch = coords[0];
- int outIdx = coords[1];
- int inOffset = outIdx * ${windowSize};
-
- vec4 minMaxValue = vec4(${initializationValue});
- float prodValue = 1.0;
- float sumValue = 0.0;
- float allValue = 1.0;
- float anyValue = 0.0;
-
- for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
- int inIdx = inOffset + i;
- ${vecType} values = ${vecType}(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1),
- getValue(batch, inIdx + 2),
- getValue(batch, inIdx + 3)
- );
-
- ${updateSnippet}
- }
-
- int inIdx = inOffset + ${windowSizeNearestVec4};
- if (${windowSizeVec4Remainder === 1}) {
- ${vecType} values = ${vecType}(
- getValue(batch, inIdx),
- initializationValue,
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${windowSizeVec4Remainder === 2}) {
- ${vecType} values = ${vecType}(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1),
- initializationValue,
- initializationValue
- );
-
- ${updateSnippet}
- } else if (${windowSizeVec4Remainder === 3}) {
- ${vecType} values = ${vecType}(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1),
- getValue(batch, inIdx + 2),
- initializationValue
- );
-
- ${updateSnippet}
- }
- setOutput(${returnValue});
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ReshapePackedProgram {
- constructor(outputShape, inputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- let mainLoop = ``;
- for (let i = 0; i < 4; i++) {
- let thisRC = `thisRC = rc;`;
- if (i % 2 === 1) {
- thisRC += `thisRC.z += 1;`;
- }
- if (i > 1) {
- thisRC += `thisRC.y += 1;`;
- }
- mainLoop += `
- ${thisRC}
- ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
- int flatIndex = getFlatIndex(thisRC);
-
- ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
- vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
-
- result[${i}] =
- getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
- ${i > 0 ? '}' : ''}
- `;
- }
- this.userCode = `
- ${getReshapedInputCoords(inputShape)}
- ${getFlatIndexFrom3D(outputShape)}
-
- void main() {
- ivec3 rc = getOutputCoords();
-
- vec4 result = vec4(0.);
-
- ivec3 thisRC;
- int rows = ${outputShape[1]};
- int cols = ${outputShape[2]};
-
- ${mainLoop}
-
- setOutput(result);
- }
- `;
- }
- }
- function getReshapedInputCoords(shape) {
- const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
- return `
- ivec3 inputCoordsFromReshapedOutCoords(int index) {
- ${coordsFromIndexSnippet}
- return ivec3(r, c, d);
- }
- `;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ResizeBilinearBackpropProgram {
- constructor(dy, x, alignCorners) {
- this.variableNames = ['dy'];
- this.outputShape = [];
- this.outputShape = x.shape;
- const [, xHeight, xWidth,] = x.shape;
- const [, yHeight, yWidth] = dy.shape;
- // In the backwards pass, we want to find the pixels that were generated for
- // each pixel in the input image the forward pass and add the corresponding
- // coefficient from dy to the gradient (with some interpolation).
- const effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- const effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- const heightScale = effectiveXSize[0] / effectiveYSize[0];
- const widthScale = effectiveXSize[1] / effectiveYSize[1];
- const invHeightScale = 1 / heightScale;
- const invWidthScale = 1 / widthScale;
- // This defines the size of the window of values around a particular
- // index in dy that we want to search for contributions to dx.
- const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
- const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
- int r = coords[1];
- int c = coords[2];
-
- float accumulator = 0.0;
-
- const float heightScale = float(${heightScale});
- const float widthScale = float(${widthScale});
-
- const float invHeightScale = float(${invHeightScale});
- const float invWidthScale = float(${invWidthScale});
-
- const int winHeight = int(${winHeight});
- const int winWidth = int(${winWidth});
-
- // Compute bounds for where in dy we will look
- float startRLerp = floor(float(r) * invHeightScale);
- int startDyR = int(startRLerp - float(winHeight / 2));
-
- float startCLerp = floor(float(c) * invWidthScale);
- int startDyC = int(startCLerp - float(winWidth / 2));
-
- // Loop over dy
- for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
- int dyR = dyROffset + startDyR;
-
- // Guard against the window exceeding the bounds of dy
- if (dyR < 0 || dyR >= ${yHeight}) {
- continue;
- }
-
- for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
- int dyC = dyCOffset + startDyC;
-
- // Guard against the window exceeding the bounds of dy
- if (dyC < 0 || dyC >= ${yWidth}) {
- continue;
- }
-
- float dxR = float(dyR) * heightScale;
- int topDxRIndex = int(floor(dxR));
- int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
- float dxRLerp = dxR - float(topDxRIndex);
- float inverseDxRLerp = 1.0 - dxRLerp;
-
- float dxC = float(dyC) * widthScale;
- int leftDxCIndex = int(floor(dxC));
- int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
- float dxCLerp = dxC - float(leftDxCIndex);
- float inverseDxCLerp = 1.0 - dxCLerp;
-
- if (r == topDxRIndex && c == leftDxCIndex) {
- // topLeft
- accumulator +=
- getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
- }
-
- if (r == topDxRIndex && c == rightDxCIndex) {
- // topRight
- accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
- }
-
- if (r == bottomDxRIndex && c == leftDxCIndex) {
- // bottomLeft
- accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
- }
-
- if (r == bottomDxRIndex && c == rightDxCIndex) {
- // bottomRight
- accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
- }
- }
- }
- // End loop over dy
-
- setOutput(accumulator);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ResizeBilinearProgram {
- constructor(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.outputShape = [];
- const [batch, oldHeight, oldWidth, depth] = inputShape;
- this.outputShape = [batch, newHeight, newWidth, depth];
- const effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- const effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- this.userCode = `
- const vec2 effectiveInputOverOutputRatioRC = vec2(
- ${effectiveInSize[0] / effectiveOutSize[0]},
- ${effectiveInSize[1] / effectiveOutSize[1]});
- const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
-
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
- ivec2 yRC = coords.yz;
-
- // Fractional source index.
- vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;
-
- // Compute the four integer indices.
- ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);
- ivec2 sourceCeilRC = ivec2(
- min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
-
- float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
- float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
- float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
- float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);
-
- vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
-
- float top = topLeft + (topRight - topLeft) * fracRC.y;
- float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
- float newValue = top + (bottom - top) * fracRC.x;
-
- setOutput(newValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ResizeBilinearPackedProgram {
- constructor(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = [];
- const [batch, oldHeight, oldWidth, depth] = inputShape;
- this.outputShape = [batch, newHeight, newWidth, depth];
- const effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- const effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- this.userCode = `
- const vec3 effectiveInputOverOutputRatioRC = vec3(
- ${effectiveInSize[0] / effectiveOutSize[0]},
- ${effectiveInSize[1] / effectiveOutSize[1]},
- ${effectiveInSize[1] / effectiveOutSize[1]});
- const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
- ${oldWidth}.0);
-
- float getAValue(int b, int r, int c, int d) {
- return getChannel(getA(b, r, c, d), vec2(c, d));
- }
-
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
- // Calculate values for next column in yRC.z.
- ivec3 yRC = coords.yzz + ivec3(0, 0, 1);
-
- // Fractional source index.
- vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;
-
- // Compute the four integer indices.
- ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);
- ivec3 sourceCeilRC = ivec3(
- min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
-
- // Should we calculate next column and row elements in 2x2 packed cell.
- bool hasNextCol = d < ${depth - 1};
- bool hasNextRow = coords.z < ${newWidth - 1};
-
- // In parallel, construct four corners for all four components in
- // packed 2x2 cell.
- vec4 topLeft = vec4(
- getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
- hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
- : 0.0,
- hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
- : 0.0,
- (hasNextRow && hasNextCol) ?
- getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);
-
- vec4 bottomLeft = vec4(
- getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
- hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
- : 0.0,
- hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
- : 0.0,
- (hasNextRow && hasNextCol) ?
- getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);
-
- vec4 topRight = vec4(
- getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
- hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
- : 0.0,
- hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
- : 0.0,
- (hasNextRow && hasNextCol) ?
- getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);
-
- vec4 bottomRight = vec4(
- getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
- hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
- : 0.0,
- hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
- : 0.0,
- (hasNextRow && hasNextCol) ?
- getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);
-
- vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);
-
- vec4 top = mix(topLeft, topRight, fracRC.yyzz);
- vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
- vec4 newValue = mix(top, bottom, fracRC.x);
-
- setOutput(newValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ResizeNearestNeigborBackpropProgram {
- constructor(dy, x, alignCorners) {
- this.variableNames = ['dy'];
- this.outputShape = [];
- this.outputShape = x.shape;
- const [, xHeight, xWidth,] = x.shape;
- const [, yHeight, yWidth] = dy.shape;
- // In the backwards pass, we want to find the pixels that were generated for
- // each pixel in the input image the forward pass and add the corresponding
- // coefficient from dy to the gradient (with some interpolation).
- const effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- const effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- const heightScale = effectiveXSize[0] / effectiveYSize[0];
- const widthScale = effectiveXSize[1] / effectiveYSize[1];
- const invHeightScale = 1 / heightScale;
- const invWidthScale = 1 / widthScale;
- // This defines the size of the window of values around a particular
- // index in dy that we want to search for contributions to dx.
- const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
- const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
- int r = coords[1];
- int c = coords[2];
-
- float accumulator = 0.0;
-
- const float heightScale = float(${heightScale});
- const float widthScale = float(${widthScale});
-
- const float invHeightScale = float(${invHeightScale});
- const float invWidthScale = float(${invWidthScale});
-
- const int winHeight = int(${winHeight});
- const int winWidth = int(${winWidth});
-
- // Compute bounds for where in dy we will look
- float startRLerp = floor(float(r) * invHeightScale);
- int startDyR = int(floor(startRLerp - float(winHeight / 2)));
-
- float startCLerp = floor(float(c) * invWidthScale);
- int startDyC = int(floor(startCLerp - float(winWidth / 2)));
-
- // Loop over dy
- for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
- int dyR = dyROffset + startDyR;
-
- // Guard against the window exceeding the bounds of dy
- if (dyR < 0 || dyR >= ${yHeight}) {
- continue;
- }
-
- for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
- int dyC = dyCOffset + startDyC;
-
- // Guard against the window exceeding the bounds of dy
- if (dyC < 0 || dyC >= ${yWidth}) {
- continue;
- }
-
- float sourceFracRow =
- float(${effectiveXSize[0]}) *
- (float(dyR) / float(${effectiveYSize[0]}));
-
- float sourceFracCol =
- float(${effectiveXSize[1]}) *
- (float(dyC) / float(${effectiveYSize[1]}));
-
- int sourceNearestRow = int(min(
- float(int(${xHeight}) - 1),
- ${alignCorners} ? float(round(sourceFracRow)) :
- float(floor(sourceFracRow))));
-
- int sourceNearestCol = int(min(
- float(int(${xWidth}) - 1),
- ${alignCorners} ? float(round(sourceFracCol)) :
- float(floor(sourceFracCol))));
-
- if (r == sourceNearestRow && c == sourceNearestCol) {
- accumulator += getDy(b, dyR, dyC, d);
- }
- }
- }
- // End loop over dy
-
- setOutput(accumulator);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ResizeNearestNeighborProgram {
- constructor(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.outputShape = [];
- const [batch, oldHeight, oldWidth, depth] = inputShape;
- this.outputShape = [batch, newHeight, newWidth, depth];
- const effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- const effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- // When align corners is false, we rounds the value with floor.
- const roundBase = alignCorners ? '0.5' : '0.0';
- this.userCode = `
- const vec2 effectiveInputOverOutputRatioRC = vec2(
- ${effectiveInSize[0] / effectiveOutSize[0]},
- ${effectiveInSize[1] / effectiveOutSize[1]});
- const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
-
- void main() {
- ivec4 coords = getOutputCoords();
- int b = coords[0];
- int d = coords[3];
- ivec2 yRC = coords.yz;
-
- // Fractional source index.
- vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;
-
- // Compute the coordinators of nearest neighbor point.
- ivec2 sourceNearestRC = ivec2(
- min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
-
- float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);
-
- setOutput(newValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ReverseProgram {
- constructor(xShape, axis) {
- this.variableNames = ['x'];
- const rank = xShape.length;
- if (rank > 4) {
- throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
- }
- this.outputShape = xShape;
- if (rank === 1) {
- this.userCode = `
- void main() {
- int coord = getOutputCoords();
- setOutput(getX(${xShape[0]} - coord - 1));
- }
- `;
- return;
- }
- const getInCoord = (i) => {
- if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
- return `${xShape[i]} - coords[${i}] - 1`;
- }
- return `coords[${i}]`;
- };
- const inCoords = xShape.map((_, i) => getInCoord(i)).join(',');
- const type = getCoordsDataType(rank);
- this.userCode = `
- void main() {
- ${type} coords = getOutputCoords();
- setOutput(getX(${inCoords}));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ReversePackedProgram {
- constructor(xShape, axis) {
- this.variableNames = ['x'];
- this.packedInputs = true;
- this.packedOutput = true;
- const rank = xShape.length;
- if (rank > 4) {
- throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
- }
- this.outputShape = xShape;
- const channels = getChannels('rc', rank);
- const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;
- const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;
- const type = getCoordsDataType(rank);
- if (rank === 1) {
- this.userCode = `
- void main(){
- int rc = getOutputCoords();
- vec4 result = vec4(0.);
- result.r = getChannel(getX(${xShape[0]} - rc - 1),
- ${xShape[0]} - rc - 1);
- if(${nextColumn}){
- result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
- ${xShape[0]} - (rc + 1) - 1);
- }
- setOutput(result);
- }
- `;
- }
- else {
- this.userCode = `
- void main() {
- ${type} rc = getOutputCoords();
- vec4 result = vec4(0.);
- result.r = ${getR(channels.slice())};
- if(${nextColumn}){
- result.g = ${getG(channels.slice())};
- }
- if(${nextRow}) {
- result.b = ${getB(channels.slice())};
- if(${nextColumn}) {
- result.a = ${getA(channels.slice())};
- }
- }
- setOutput(result);
- }
- `;
- }
- function getR(channels) {
- return getChannel(channels);
- }
- function getG(channels) {
- channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
- return getChannel(channels);
- }
- function getB(channels) {
- channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
- return getChannel(channels);
- }
- function getA(channels) {
- channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
- channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
- return getChannel(channels);
- }
- function getChannel(channels) {
- const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));
- const inCoords = inCoordsArray.join(',');
- const innerDims = inCoordsArray.slice(-2).join(',');
- return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;
- }
- function getInCoord(i, channels1) {
- if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
- return `${xShape[i]} - ${channels1[i]} - 1`;
- }
- else {
- return `${channels1[i]}`;
- }
- }
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ScatterProgram {
- constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true) {
- this.variableNames = ['updates', 'indices', 'defaultValue'];
- this.outputShape = shape;
- const stridesType = getCoordsDataType(strides.length);
- const dtype = getCoordsDataType(shape.length);
- let indicesString = '';
- if (indicesRank === 1) {
- indicesString = 'i';
- }
- else if (indicesRank === 2) {
- indicesString = 'i, j';
- }
- const indicesSnippet = `getIndices(${indicesString})`;
- let updatesString = '';
- if (updatesRank === 1) {
- updatesString = 'i';
- }
- else if (updatesRank === 2) {
- updatesString = 'i, coords[1]';
- }
- const updatesSnippet = `getUpdates(${updatesString})`;
- const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
- this.userCode = `
- ${stridesType} strides = ${stridesType}(${strides});
-
- void main() {
- ${dtype} coords = getOutputCoords();
- float sum = 0.0;
- bool found = false;
- for (int i = 0; i < ${updateSize}; i++) {
- int flattenedIndex = 0;
- for (int j = 0; j < ${sliceDim}; j++) {
- int index = round(${indicesSnippet});
- flattenedIndex += index * ${strideString};
- }
- if (flattenedIndex == coords[0]) {
- sum += ${updatesSnippet};
- found = true;
- }
- }
- setOutput(mix(getDefaultValue(), sum, float(found)));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * GPGPU program implementing a windowed segment reduction (used by the
 * segment ops, e.g. unsortedSegmentSum). Each output element accumulates
 * the input values of one window whose segment id matches that element's
 * segment.
 *
 * NOTE(review): `segOpType` is accepted but never read in this block — the
 * generated shader always sums via a vec4 dot product; confirm against
 * callers whether other reduction types were intended.
 */
class SegmentOpProgram {
    constructor(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        const windowSize = segOpInfo.windowSize;
        const batchSize = segOpInfo.batchSize;
        const inSize = segOpInfo.inSize;
        const numSegments = segOpInfo.numSegments;
        // One output element per (window, segment) pair.
        const outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        const initializationValue = '0.0';
        const returnValue = `sumValue`;
        // The window is processed four elements at a time with vec4 dot
        // products; the (< 4 element) remainder is handled by the branches
        // after the main loop.
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        const updateSnippet = `
        sumValue += dot(values, segFilter);
      `;
        // If inSize does not divide evenly into windows, the last window can
        // read past the end of the input, so guard the value reads.
        let checkValueOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      `;
        }
        // Out-of-range segment-id reads return -1.0, which never matches a
        // valid segment index and is therefore filtered out by segFilter.
        let checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkSegmentIdOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return -1.0;
        }
      `;
        }
        this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          int inIdxSeg = int(getSegmentIdAtIndex(inIdx));

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
    }
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * GPGPU program implementing `select(c, a, b)`: for each output element,
 * takes the value from `a` where the condition tensor `c` is >= 1, and from
 * `b` otherwise. `c` may have a lower rank than `a`/`b`, in which case only
 * its leading coordinates are used.
 */
class SelectProgram {
    constructor(cRank, shape, rank) {
        this.variableNames = ['c', 'a', 'b'];
        this.outputShape = shape;
        if (rank > 4) {
            throw Error(`Where for rank ${rank} is not yet supported`);
        }
        let cCoords;
        let abCoords;
        if (rank === 1) {
            // Rank-1 coordinates are a plain scalar, not a vector.
            cCoords = `resRC`;
            abCoords = `resRC`;
        }
        else {
            // Address a/b with all output coordinates; address c with at most
            // its own rank's worth of leading coordinates.
            const channels = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
            abCoords = channels.slice(0, shape.length).join();
            cCoords = channels.slice(0, Math.min(cRank, shape.length)).join();
        }
        const dtype = getCoordsDataType(rank);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        float cVal = getC(${cCoords});
        if (cVal >= 1.0) {
          setOutput(getA(${abCoords}));
        } else {
          setOutput(getB(${abCoords}));
        }
      }
    `;
    }
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * GPGPU program that copies a contiguous slice out of a source tensor. The
 * slice offsets are supplied at run time through the `start` uniform, so one
 * compiled program is reusable for different slice positions.
 */
class SliceProgram {
    constructor(destSize) {
        this.variableNames = ['source'];
        this.outputShape = destSize;
        this.rank = destSize.length;
        const dtype = getCoordsDataType(this.rank);
        const uniformPart = `uniform int start[${this.rank}];`;
        const sourceCoords = getCoords$1(this.rank);
        let body;
        // Each output coordinate i maps to source coordinate
        // `start[i] + coords[i]`.
        // NOTE(review): for rank 1, `dtype` is a scalar yet this emits
        // `sourceLoc.x = start[0] + coords.x;` — presumably rank-1 slices
        // never reach this program; confirm against the backend's slice path.
        const coordSum = destSize.map((_, i) => {
            return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;
        });
        body = `
        ${dtype} sourceLoc;
        ${dtype} coords = getOutputCoords();
        ${coordSum.join('\n')}
      `;
        this.userCode = `
      ${uniformPart}
      void main() {
        ${body}
        setOutput(getSource(${sourceCoords}));
      }
    `;
    }
    /**
     * Returns a setup function that binds the per-dimension slice offsets to
     * the `start` uniform before the program runs.
     * Throws when `start.length` does not equal the program's rank.
     */
    getCustomSetupFunc(start) {
        if (start.length !== this.rank) {
            throw Error(`The rank (${this.rank}) of the program must match the ` +
                `length of start (${start.length})`);
        }
        return (gpgpu, webGLProgram) => {
            // The uniform location is looked up once and cached on the program.
            if (this.startLoc == null) {
                this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');
                if (this.startLoc == null) {
                    // This means the compiler has optimized and realized it doesn't need
                    // the uniform.
                    return;
                }
            }
            gpgpu.gl.uniform1iv(this.startLoc, start);
        };
    }
}
// Channel letters used to address the components of a GLSL coordinate vector.
const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
/**
 * Builds the GLSL argument list that reads the source location for a slice of
 * the given rank, e.g. rank 3 -> 'sourceLoc.x,sourceLoc.y,sourceLoc.z'.
 * Rank 1 uses the bare scalar `sourceLoc`; ranks above 6 are rejected.
 */
function getCoords$1(rank) {
    if (rank === 1) {
        return 'sourceLoc';
    }
    if (rank > 6) {
        throw Error(`Slicing for rank ${rank} is not yet supported`);
    }
    const parts = [];
    for (let i = 0; i < rank; i++) {
        parts.push(`sourceLoc.${coords[i]}`);
    }
    return parts.join(',');
}
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Packed (2x2-texel) variant of the slice program. Each shader invocation
 * fills one vec4 output texel, so it must gather up to four neighboring
 * source values, guarding against reading past the slice bounds along the
 * last two dimensions.
 */
class SlicePackedProgram {
    constructor(destSize) {
        this.variableNames = ['source'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = destSize;
        this.rank = destSize.length;
        const dtype = getCoordsDataType(this.rank);
        const coords = getChannels('coords', this.rank);
        const sourceLoc = getChannels('sourceLoc', this.rank);
        // The inner two dims select the channel within a packed source texel.
        const innerDims = this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;
        const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;
        // Fills result.x/.y: the current column and, if still inside the
        // slice, the next column. The ++/-- pairs temporarily advance and
        // then restore the shared sourceLoc coordinate.
        const upperRow = `
      result.x = ${getChannel};
      if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
        ++${sourceLoc[this.rank - 1]};
        result.y = ${getChannel};
        --${sourceLoc[this.rank - 1]};
      }
    `;
        // Fills result.z/.w from the next row (second-to-last dim); rank 1
        // has no second row.
        const lowerRow = this.rank === 1 ? '' : `
      --${coords[this.rank - 1]};
      if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {
        ++${sourceLoc[this.rank - 2]};
        result.z = ${getChannel};
        if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
          ++${sourceLoc[this.rank - 1]};
          result.w = ${getChannel};
        }
      }
    `;
        // Up to rank 4 the coords type supports vector addition with the
        // start offsets; beyond that each component is offset individually.
        const sourceLocSetup = this.rank <= 4 ?
            `sourceLoc = coords +
            ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :
            destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)
                .join('\n');
        this.userCode = `
      uniform int start[${this.rank}];
      void main() {
        ${dtype} coords = getOutputCoords();
        ${dtype} sourceLoc;
        ${sourceLocSetup}
        vec4 result = vec4(0.);
        ${upperRow}
        ${lowerRow}
        setOutput(result);
      }
    `;
    }
    /**
     * Returns a setup function that binds the per-dimension slice offsets to
     * the `start` uniform; throws when `start.length` differs from the rank.
     */
    getCustomSetupFunc(start) {
        if (start.length !== this.rank) {
            throw Error(`The rank (${this.rank}) of the program must match the ` +
                `length of start (${start.length})`);
        }
        return (gpgpu, webGLProgram) => {
            // Look up and cache the uniform location on first use.
            if (this.startLoc == null) {
                this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');
                if (this.startLoc == null) {
                    // This means the compiler has optimized and realized it doesn't need
                    // the uniform.
                    return;
                }
            }
            gpgpu.gl.uniform1iv(this.startLoc, start);
        };
    }
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * GPGPU program for strided slice: output coordinate i reads input
 * coordinate `coords[i] * strides[i] + begin[i]`. `begin` and `strides` are
 * baked into the shader as constants.
 *
 * Cleanup vs. the previous revision (generated GLSL is unchanged):
 * - removed the duplicate `inputDtype`/`dtype` variables (both were
 *   `getCoordsDataType(size.length)`),
 * - removed the unreachable `size.length === 1` branch inside the rank > 1
 *   path (that case is already handled by the outer `rank === 1` branch),
 * - replaced the hand-maintained `outputAxis` counter with the map index.
 */
class StridedSliceProgram {
    constructor(begin, strides, size) {
        this.variableNames = ['x'];
        this.outputShape = size;
        const rank = size.length;
        const dtype = getCoordsDataType(rank);
        let newCoords = '';
        if (rank === 1) {
            // Rank-1 coords are scalars, so no component indexing is needed.
            newCoords = 'coords * strides + begin';
        }
        else {
            newCoords =
                size.map((_, i) => `coords[${i}] * strides[${i}] + begin[${i}]`)
                    .join(',');
        }
        this.userCode = `
      ${dtype} begin = ${dtype}(${begin});
      ${dtype} strides = ${dtype}(${strides});

      void main() {
        ${dtype} coords = getOutputCoords();
        setOutput(getX(${newCoords}));
      }
    `;
    }
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Pools WebGL textures by (shape, physical type, packed) key so they can be
 * reused instead of reallocated. Tracks counts and byte totals for logging
 * and for the WEBGL_DELETE_TEXTURE_THRESHOLD eviction policy.
 */
class TextureManager {
    constructor(gpgpu) {
        this.gpgpu = gpgpu;
        this.numUsedTextures = 0;
        this.numFreeTextures = 0;
        this._numBytesAllocated = 0;
        this._numBytesFree = 0; // How many bytes that have been allocated
        // are available for reuse.
        this.freeTextures = {};
        this.logEnabled = false;
        this.usedTextures = {};
    }
    /**
     * Returns a texture for the given shape/usage, reusing a pooled free
     * texture of the same key when one exists, otherwise creating a new one
     * of the appropriate physical type via the gpgpu context.
     */
    acquireTexture(shapeRC, usage, isPacked) {
        const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
        const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
        if (!(shapeKey in this.freeTextures)) {
            this.freeTextures[shapeKey] = [];
        }
        if (!(shapeKey in this.usedTextures)) {
            this.usedTextures[shapeKey] = [];
        }
        const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
        if (this.freeTextures[shapeKey].length > 0) {
            // Reuse path: move a texture from the free pool to the used list.
            this.numFreeTextures--;
            this.numUsedTextures++;
            this._numBytesFree -= texBytes;
            this.log();
            const newTexture = this.freeTextures[shapeKey].shift();
            this.usedTextures[shapeKey].push(newTexture);
            return newTexture;
        }
        // Allocation path: create a texture matching the physical type.
        let newTexture;
        if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
            newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
            newTexture =
                this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
            newTexture =
                this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
            newTexture =
                this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
            newTexture =
                this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        this.usedTextures[shapeKey].push(newTexture);
        this.numUsedTextures++;
        this._numBytesAllocated += texBytes;
        this.log();
        return newTexture;
    }
    /**
     * Returns a texture to the pool, or deletes it outright when the total
     * allocated bytes exceed WEBGL_DELETE_TEXTURE_THRESHOLD (a threshold of
     * -1 disables deletion). Throws if the texture was never handed out by
     * this manager.
     */
    releaseTexture(texture, shape, logicalTexType, isPacked) {
        if (this.freeTextures == null) {
            // Already disposed.
            return;
        }
        const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
        const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
        if (!(shapeKey in this.freeTextures)) {
            this.freeTextures[shapeKey] = [];
        }
        const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
        const deleteTexThreshold = env().get('WEBGL_DELETE_TEXTURE_THRESHOLD');
        if (deleteTexThreshold !== -1 &&
            this._numBytesAllocated > deleteTexThreshold) {
            this.gpgpu.deleteMatrixTexture(texture);
            this._numBytesAllocated -= texBytes;
        }
        else {
            this.freeTextures[shapeKey].push(texture);
            this.numFreeTextures++;
            this._numBytesFree += texBytes;
        }
        this.numUsedTextures--;
        const texList = this.usedTextures[shapeKey];
        const texIndex = texList.indexOf(texture);
        if (texIndex < 0) {
            throw new Error('Cannot release a texture that was never provided by this ' +
                'texture manager');
        }
        texList.splice(texIndex, 1);
        this.log();
    }
    // Prints pool counts and byte totals; no-op unless logEnabled is set.
    log() {
        if (!this.logEnabled) {
            return;
        }
        const total = this.numFreeTextures + this.numUsedTextures;
        console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
        const freeRatio = this._numBytesFree / this._numBytesAllocated;
        console.log(`Bytes allocated: ${this._numBytesAllocated}`);
        console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
    }
    get numBytesAllocated() {
        return this._numBytesAllocated;
    }
    get numBytesFree() {
        return this._numBytesFree;
    }
    getNumUsedTextures() {
        return this.numUsedTextures;
    }
    getNumFreeTextures() {
        return this.numFreeTextures;
    }
    /**
     * Deletes every tracked texture (both pools) and marks the manager as
     * disposed; subsequent dispose/release calls become no-ops.
     */
    dispose() {
        if (this.freeTextures == null) {
            // Already disposed.
            return;
        }
        for (const texShape in this.freeTextures) {
            this.freeTextures[texShape].forEach(tex => {
                this.gpgpu.deleteMatrixTexture(tex);
            });
        }
        for (const texShape in this.usedTextures) {
            this.usedTextures[texShape].forEach(tex => {
                this.gpgpu.deleteMatrixTexture(tex);
            });
        }
        this.freeTextures = null;
        this.usedTextures = null;
        this.numUsedTextures = 0;
        this.numFreeTextures = 0;
        this._numBytesAllocated = 0;
        this._numBytesFree = 0;
    }
}
/**
 * Returns the number of bytes per texel for a WebGL internal format.
 * Formats are compared in the same order as before so behavior is unchanged
 * even if enum values were to collide. Throws for unrecognized formats.
 */
function numBytesForInternalFormat(gl, internalFormat) {
    // tslint:disable-next-line:no-any
    const glany = gl;
    const formatBytes = [
        [glany.R32F, 4],
        [glany.R16F, 2],
        [glany.RGBA32F, 16],
        [gl.RGBA, 16],
        [glany.RGBA16F, 8],
    ];
    for (const [format, bytes] of formatBytes) {
        if (internalFormat === format) {
            return bytes;
        }
    }
    throw new Error(`Unknown internal format ${internalFormat}`);
}
/**
 * Computes the GPU memory footprint, in bytes, of a texture holding the
 * given logical [rows, columns] shape.
 *
 * It is not possible to infer packed status from the texture type because
 * depending on the textureConfig, different texture types may resolve to the
 * same internal format (e.g. in WebGL1, the internal format for
 * UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked`
 * explicitly.
 */
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
    const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
    const [texWidth, texHeight] = isPacked ?
        getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]) :
        getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
    const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
    return texWidth * texHeight * bytesPerElement;
}
/**
 * Maps a physical texture type to the WebGL internal format declared for it
 * in the given textureConfig. Throws for unknown types.
 */
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
    if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
        return getInternalFormatForPackedMatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
        return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
        return getInternalFormatForFloat32MatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
        return getInternalFormatForFloat16MatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
        return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
    }
    throw new Error(`Unknown physical texture type ${physicalTexType}`);
}
/**
 * Chooses the physical texture type used for render targets: float32 when
 * WEBGL_RENDER_FLOAT32_ENABLED is set, float16 otherwise; packed 2x2 layout
 * when `isPacked` is true.
 */
function getPhysicalTextureForRendering(isPacked) {
    const renderFloat32 = env().getBool('WEBGL_RENDER_FLOAT32_ENABLED');
    if (renderFloat32) {
        return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT32 :
            PhysicalTextureType.UNPACKED_FLOAT32;
    }
    return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT16 :
        PhysicalTextureType.UNPACKED_FLOAT16;
}
/**
 * Maps a logical texture usage to a physical texture type. A null/undefined
 * usage is treated like RENDER. Throws for unknown usages.
 */
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
    // Uploads always go through a packed float32 texture.
    if (logicalTexType === TextureUsage.UPLOAD) {
        return PhysicalTextureType.PACKED_2X2_FLOAT32;
    }
    // Render targets (the default) depend on float32-render support.
    if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
        return getPhysicalTextureForRendering(isPacked);
    }
    // Downloads and pixel sources use byte-encoded textures.
    if (logicalTexType === TextureUsage.DOWNLOAD ||
        logicalTexType === TextureUsage.PIXELS) {
        return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
    }
    throw new Error(`Unknown logical texture type ${logicalTexType}`);
}
/**
 * Builds the texture-pool cache key from rows, columns, physical texture
 * type and packed-ness, e.g. `2_3_0_true`.
 */
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
    const [rows, columns] = shapeRowsCol;
    return [rows, columns, physicalTexType, isPacked].join('_');
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * GPGPU program implementing `tile`: the output shape is the input shape
 * multiplied element-wise by `reps`, and each output coordinate reads the
 * input at `coord mod aShape[i]`.
 */
class TileProgram {
    constructor(aShape, reps) {
        this.variableNames = ['A'];
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[i] * reps[i];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const sourceCoords = getSourceCoords$2(aShape);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        setOutput(getA(${sourceCoords}));
      }
    `;
    }
}
/**
 * Builds the GLSL argument list that wraps each output coordinate back into
 * the source tensor for `tile`, e.g. [2, 3] ->
 * 'imod(resRC.x, 2),imod(resRC.y, 3)'. Supports up to rank 5.
 */
function getSourceCoords$2(aShape) {
    const rank = aShape.length;
    if (rank > 5) {
        throw Error(`Tile for rank ${rank} is not yet supported`);
    }
    if (rank === 1) {
        // Rank-1 coordinates are a plain scalar.
        return `imod(resRC, ${aShape[0]})`;
    }
    const channels = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
    return aShape.map((dim, i) => `imod(${channels[i]}, ${dim})`).join();
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Generic element-wise unary-op program: wraps the provided GLSL snippet
 * (the body of `unaryOperation`) and applies it to every element of A.
 */
class UnaryOpProgram {
    constructor(aShape, opSnippet) {
        this.variableNames = ['A'];
        this.outputShape = aShape;
        this.userCode = `
      float unaryOperation(float x) {
        ${opSnippet}
      }

      void main() {
        float x = getAAtOutCoords();
        float y = unaryOperation(x);

        setOutput(y);
      }
    `;
    }
}
// GLSL bodies for `unaryOperation(float x)` in UnaryOpProgram. Snippets
// prefixed with CHECK_NAN_SNIPPET$2 propagate NaN inputs unchanged.
const CHECK_NAN_SNIPPET$2 = `if (isnan(x)) return x;`;
const LINEAR = `return x;`;
const ABS = `return abs(x);`;
const RELU = CHECK_NAN_SNIPPET$2 + `
  return (x < 0.0) ? 0.0 : x;
`;
const RELU6 = CHECK_NAN_SNIPPET$2 + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
const ELU$1 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
const SELU = `
  // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
  // see: https://arxiv.org/abs/1706.02515
  float scaleAlpha = ${SELU_SCALEALPHA};
  float scale = ${SELU_SCALE};
  return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
// Heaviside step with a configurable value at/below zero (returned for
// x <= 0); NaN inputs pass through via the NaN check.
function STEP(alpha = 0.0) {
    return CHECK_NAN_SNIPPET$2 + `
    return x > 0.0 ? 1.0 : float(${alpha});
  `;
}
const NEG = `return -x;`;
const CEIL = `return ceil(x);`;
const FLOOR = `return floor(x);`;
// sign() with NaN mapped to 0 rather than propagated.
const SIGN = `
  if (isnan(x)) { return 0.0; }
  return sign(x);
`;
const IS_NAN = `return float(isnan(x));`;
const IS_INF = `return float(isinf(x));`;
const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
const ROUND = `
  // OpenGL ES does not support round function.
  // The algorithm is based on banker's rounding.
  float base = floor(x);
  if ((x - base) < 0.5) {
    return floor(x);
  } else if ((x - base) > 0.5) {
    return ceil(x);
  } else {
    if (mod(base, 2.0) == 0.0) {
      return base;
    } else {
      return base + 1.0;
    }
  }
`;
const EXP = `return exp(x);`;
const EXPM1 = `return exp(x) - 1.0;`;
// log of a negative is forced to NAN explicitly (GLSL behavior for
// log(negative) is undefined).
const LOG = `if (x < 0.0) return NAN;
  return log(x);`;
const LOG1P = `return log(1.0 + x);`;
const SQRT = `return sqrt(x);`;
const RSQRT = `return inversesqrt(x);`;
const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;
/**
 * mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
 *
 * epsilon is the difference between 1.0 and the next representable
 * float. For a single precision 32 bit float this should be 2^-23, see:
 * https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
 *
 * too_large = (x > -threshold) is value above which exp(x) may overflow
 * but softplus(x) == x is within machine epsilon
 *
 * too_small = (x < threshold) is value below which exp(x) may underflow,
 * but softplus(x) == exp(x) is within machine epsilon.
 */
const SOFTPLUS = `
  float epsilon = 1.1920928955078125e-7;
  float threshold = log(epsilon) + 2.0;

  bool too_large = x > -threshold;
  bool too_small = x < threshold;

  float result;
  float exp_x = exp(x);

  if (too_large){
    result = x;
  }
  else if (too_small){
    result = exp_x;
  }
  else{
    result = log(exp_x + 1.0);
  }
  return result;
`;
// Inverse trig: out-of-domain inputs return NAN explicitly.
const ASIN = CHECK_NAN_SNIPPET$2 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return asin(x);
`;
const ACOS = CHECK_NAN_SNIPPET$2 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return acos(x);
`;
const ATAN = CHECK_NAN_SNIPPET$2 + `
  return atan(x);
`;
// Hyperbolic functions expressed through exp().
const SINH = `
  float e2x = exp(x);
  return (e2x - 1.0 / e2x) / 2.0;
`;
const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;
// Numerically-stable tanh: uses exp(-2|x|) so the exponential never
// overflows, then restores the sign.
const TANH = `
  float e2x = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;
const ASINH = CHECK_NAN_SNIPPET$2 + `return log(x + sqrt(x * x + 1.0));`;
const ACOSH = CHECK_NAN_SNIPPET$2 + `
  if (x < 1.0) return NAN;
  return log(x + sqrt(x * x - 1.0));`;
const ATANH = CHECK_NAN_SNIPPET$2 + `
  if ((x < -1.0) || (x > 1.0)) return NAN;
  return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
const ERF = `
  // Error function is calculated approximately with elementary function.
  // See "Handbook of Mathematical Functions with Formulas,
  // Graphs, and Mathematical Tables", Abramowitz and Stegun.
  float p = ${ERF_P};
  float a1 = ${ERF_A1};
  float a2 = ${ERF_A2};
  float a3 = ${ERF_A3};
  float a4 = ${ERF_A4};
  float a5 = ${ERF_A5};

  float sign = sign(x);
  x = abs(x);
  float t = 1.0 / (1.0 + p * x);
  return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
const RECIPROCAL = `return 1.0 / x;`;
const LOGICAL_NOT = `return float(!(x >= 1.0));`;
const CLONE = 'return x;';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
// GLSL bodies for `unaryOperation(vec4 x)` in UnaryOpPackedProgram — the
// packed (vec4-per-texel) counterparts of the unpacked snippets above.
const LINEAR$1 = `return x;`;
// log(): negative lanes are forced to NAN per-channel (GLSL log(negative)
// is undefined).
const LOG$1 = `
  vec4 result = log(x);
  vec4 isNaN = vec4(lessThan(x, vec4(0.0)));
  result.r = isNaN.r == 1.0 ? NAN : result.r;
  result.g = isNaN.g == 1.0 ? NAN : result.g;
  result.b = isNaN.b == 1.0 ? NAN : result.b;
  result.a = isNaN.a == 1.0 ? NAN : result.a;

  return result;
`;
// relu(): zero out negative lanes, but pass NaN lanes through unchanged.
const RELU$1 = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// relu6(): clamp to [0, 6] per lane, passing NaN lanes through unchanged.
const RELU6$1 = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// elu(): applied channel-by-channel.
const ELU$2 = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
/**
 * Packed variant of the element-wise unary-op program: the snippet is the
 * body of `unaryOperation(vec4 x)` and operates on four values per texel.
 */
class UnaryOpPackedProgram {
    constructor(aShape, opSnippet) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = aShape;
        this.userCode = `
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }

      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;
    }
}
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
/**
 * Converts a packed (vec4-per-texel) texture to an unpacked one: each
 * invocation selects the single channel of the packed source texel that
 * corresponds to its output coordinate.
 */
class UnpackProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = false;
        this.outputShape = outputShape;
        const rank = outputShape.length;
        const channels = getChannels('rc', rank);
        const dtype = getCoordsDataType(rank);
        const sourceCoords = getSourceCoords(rank, channels);
        // The last two dims pick the channel within the packed 2x2 texel.
        const innerDims = channels.slice(-2);
        const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
        this.userCode = `
      void main() {
        ${dtype} rc = getOutputCoords();
        vec4 packedInput = getA(${sourceCoords});

        setOutput(getChannel(packedInput, ${coords}));
      }
    `;
    }
}
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
// Aliases into shared backend utilities and CPU implementations reused by
// the WebGL backend.
const { segment_util: segment_util$1 } = backend_util;
const split$5 = split$1;
const tile$4 = tile$1;
const topkImpl$2 = topkImpl;
const whereImpl$2 = whereImpl;
// Epsilon values used by the backend depending on float texture precision.
const EPSILON_FLOAT32$1 = 1e-7;
const EPSILON_FLOAT16$1 = 1e-4;
// Per-WebGL-version caches of compiled shader binaries, created lazily.
const binaryCaches = {};
/**
 * Returns the shader-binary cache for the given WebGL version, creating an
 * empty one on first access. Repeated calls with the same version return the
 * same object.
 */
function getBinaryCache(webGLVersion) {
    if (!(webGLVersion in binaryCaches)) {
        binaryCaches[webGLVersion] = {};
    }
    return binaryCaches[webGLVersion];
}
/**
 * Returns the GLSL snippet implementing the given fused-activation name,
 * choosing the packed (vec4) or unpacked (float) variant. Throws for
 * activations the WebGL backend does not implement.
 */
function mapActivationToShaderProgram(activation, packed = false) {
    switch (activation) {
        case 'linear':
            return packed ? LINEAR$1 : LINEAR;
        case 'relu':
            return packed ? RELU$1 : RELU;
        case 'elu':
            return packed ? ELU$2 : ELU$1;
        case 'relu6':
            return packed ? RELU6$1 : RELU6;
        case 'prelu':
            return packed ? PRELU$1 : PRELU;
        default:
            throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`);
    }
}
// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
const CPU_HANDOFF_SIZE_THRESHOLD = 128;
// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
const BEFORE_PAGING_CONSTANT = 600;
// Returns the GPU-memory warning threshold in MB, scaled by screen area and
// device pixel ratio; falls back to 1 GB when no screen is available
// (e.g. Node).
function numMBBeforeWarning() {
    if (env().global.screen == null) {
        return 1024; // 1 GB.
    }
    return (env().global.screen.height * env().global.screen.width *
        window.devicePixelRatio) *
        BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
const MATMUL_SHARED_DIM_THRESHOLD = 1000;
- class MathBackendWebGL extends KernelBackend {
- constructor(gpgpu) {
- super();
- // Maps data ids that have a pending read operation, to list of subscribers.
- this.pendingRead = new WeakMap();
- // List of data ids that are scheduled for disposal, but are waiting on a
- // pending read operation.
- this.pendingDisposal = new WeakSet();
- // Used to count the number of 'shallow' sliced tensors that point to the
- // same data id.
- this.dataRefCount = new WeakMap();
- this.numBytesInGPU = 0;
- // Accumulated time spent (including blocking) in uploading data to webgl.
- this.uploadWaitMs = 0;
- // Accumulated time spent (including blocking in downloading data from webgl.
- this.downloadWaitMs = 0;
- this.warnedAboutMemory = false;
- this.warnedAboutCPUBackend = false;
- this.pendingDeletes = 0;
- this.disposed = false;
- if (!env().getBool('HAS_WEBGL')) {
- throw new Error('WebGL is not supported on this device');
- }
- if (gpgpu == null) {
- const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
- this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
- this.gpgpu = new GPGPUContext(gl);
- this.canvas = gl.canvas;
- this.gpgpuCreatedLocally = true;
- }
- else {
- this.gpgpu = gpgpu;
- this.binaryCache = {};
- this.gpgpuCreatedLocally = false;
- this.canvas = gpgpu.gl.canvas;
- }
- this.textureManager = new TextureManager(this.gpgpu);
- this.numMBBeforeWarning = numMBBeforeWarning();
- this.texData = new DataStorage(this, engine());
- }
- numDataIds() {
- return this.texData.numDataIds() +
- (this.cpuBackend ? this.cpuBackend.numDataIds() : 0) -
- this.pendingDeletes;
- }
- write(values, shape, dtype) {
- if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') ||
- env().getBool('DEBUG')) {
- this.checkNumericalProblems(values);
- }
- if (dtype === 'complex64' && values != null) {
- throw new Error(`Cannot write to a complex64 dtype. ` +
- `Please use tf.complex(real, imag).`);
- }
- const dataId = {};
- this.texData.set(dataId, {
- shape,
- dtype,
- values,
- usage: TextureUsage.UPLOAD,
- refCount: 1,
- complexParentRefCount: 0
- });
- return dataId;
- }
- /** Increase refCount of a `TextureData`. */
- incRef(dataId) {
- const texData = this.texData.get(dataId);
- texData.refCount++;
- }
- /** Decrease refCount of a `TextureData`. */
- decRef(dataId) {
- if (this.texData.has(dataId)) {
- const texData = this.texData.get(dataId);
- texData.refCount--;
- }
- }
- move(dataId, values, shape, dtype) {
- if (env().getBool('DEBUG')) {
- this.checkNumericalProblems(values);
- }
- if (dtype === 'complex64') {
- throw new Error(`Cannot write to a complex64 dtype. ` +
- `Please use tf.complex(real, imag).`);
- }
- this.texData.set(dataId, {
- shape,
- dtype,
- values,
- usage: TextureUsage.UPLOAD,
- refCount: 1,
- complexParentRefCount: 0
- });
- }
- disposeIntermediateTensorInfo(tensorInfo) {
- const dataId = tensorInfo.dataId;
- if (this.texData.has(dataId)) {
- const textureData = this.texData.get(dataId);
- textureData.refCount--;
- if (textureData.refCount < 1) {
- this.disposeData(dataId);
- }
- }
- }
- readSync(dataId) {
- const texData = this.texData.get(dataId);
- const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData;
- // The presence of `slice` indicates this tensor is a shallow slice of a
- // different tensor, and is using that original tensor's texture. Run
- // `clone` in order to copy that texture and read from it.
- if (slice != null) {
- let program;
- if (isPacked) {
- program = new UnaryOpPackedProgram(shape, CLONE);
- }
- else {
- program = new UnaryOpProgram(shape, CLONE);
- }
- const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
- const data = this.readSync(res.dataId);
- this.disposeIntermediateTensorInfo(res);
- return data;
- }
- if (values != null) {
- return this.convertAndCacheOnCPU(dataId);
- }
- if (dtype === 'string') {
- return values;
- }
- const shouldTimeProgram = this.activeTimers != null;
- let start;
- if (shouldTimeProgram) {
- start = now();
- }
- let result;
- if (dtype === 'complex64') {
- const realValues = this.readSync(complexTensorInfos.real.dataId);
- const imagValues = this.readSync(complexTensorInfos.imag.dataId);
- result = mergeRealAndImagArrays(realValues, imagValues);
- }
- else {
- result = this.getValuesFromTexture(dataId);
- }
- if (shouldTimeProgram) {
- this.downloadWaitMs += now() - start;
- }
- return this.convertAndCacheOnCPU(dataId, result);
- }
- async read(dataId) {
- if (this.pendingRead.has(dataId)) {
- const subscribers = this.pendingRead.get(dataId);
- return new Promise(resolve => subscribers.push(resolve));
- }
- const texData = this.texData.get(dataId);
- const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData;
- // The presence of `slice` indicates this tensor is a shallow slice of a
- // different tensor, and is using that original tensor's texture. Run
- // `clone` in order to copy that texture and read from it.
- if (slice != null) {
- let program;
- if (isPacked) {
- program = new UnaryOpPackedProgram(shape, CLONE);
- }
- else {
- program = new UnaryOpProgram(shape, CLONE);
- }
- const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
- const data = this.read(res.dataId);
- this.disposeIntermediateTensorInfo(res);
- return data;
- }
- if (values != null) {
- return this.convertAndCacheOnCPU(dataId);
- }
- if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
- env().getNumber('WEBGL_VERSION') === 2) {
- throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
- `WEBGL_VERSION=2 not yet supported.`);
- }
- let buffer = null;
- let tmpDownloadTarget;
- if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
- // Possibly copy the texture into a buffer before inserting a fence.
- tmpDownloadTarget = this.decode(dataId);
- const tmpData = this.texData.get(tmpDownloadTarget.dataId);
- buffer = this.gpgpu.createBufferFromTexture(tmpData.texture, ...getDenseTexShape(shape));
- }
- this.pendingRead.set(dataId, []);
- if (dtype !== 'complex64') {
- // Create a fence and wait for it to resolve.
- await this.gpgpu.createAndWaitForFence();
- }
- // Download the values from the GPU.
- let vals;
- if (dtype === 'complex64') {
- const ps = await Promise.all([
- this.read(complexTensorInfos.real.dataId),
- this.read(complexTensorInfos.imag.dataId)
- ]);
- const realValues = ps[0];
- const imagValues = ps[1];
- vals = mergeRealAndImagArrays(realValues, imagValues);
- }
- else if (buffer == null) {
- vals = this.getValuesFromTexture(dataId);
- }
- else {
- const size = sizeFromShape(shape);
- vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
- }
- if (tmpDownloadTarget != null) {
- this.disposeIntermediateTensorInfo(tmpDownloadTarget);
- }
- const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
- const subscribers = this.pendingRead.get(dataId);
- this.pendingRead.delete(dataId);
- // Notify all pending reads.
- subscribers.forEach(resolve => resolve(dTypeVals));
- if (this.pendingDisposal.has(dataId)) {
- this.pendingDisposal.delete(dataId);
- this.disposeData(dataId);
- this.pendingDeletes--;
- }
- return dTypeVals;
- }
- checkNumericalProblems(values) {
- if (values == null) {
- return;
- }
- for (let i = 0; i < values.length; i++) {
- const num = values[i];
- if (!canBeRepresented(num)) {
- if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
- throw Error(`The value ${num} cannot be represented with your ` +
- `current settings. Consider enabling float32 rendering: ` +
- `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
- }
- throw Error(`The value ${num} cannot be represented on this device.`);
- }
- }
- }
- getValuesFromTexture(dataId) {
- const { shape, dtype, isPacked } = this.texData.get(dataId);
- const size = sizeFromShape(shape);
- if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
- const tmpTarget = this.decode(dataId);
- const tmpData = this.texData.get(tmpTarget.dataId);
- const vals = this.gpgpu
- .downloadMatrixFromPackedTexture(tmpData.texture, ...getDenseTexShape(shape))
- .subarray(0, size);
- this.disposeIntermediateTensorInfo(tmpTarget);
- return vals;
- }
- const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
- const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
- const program = shouldUsePackedProgram ?
- new EncodeFloatPackedProgram(outputShape) :
- new EncodeFloatProgram(outputShape);
- const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
- const tmpData = this.texData.get(output.dataId);
- const vals = this.gpgpu
- .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture, tmpData.texShape[0], tmpData.texShape[1])
- .subarray(0, size);
- this.disposeIntermediateTensorInfo(output);
- return vals;
- }
- async time(f) {
- const oldActiveTimers = this.activeTimers;
- const newActiveTimers = [];
- let outerMostTime = false;
- if (this.programTimersStack == null) {
- this.programTimersStack = newActiveTimers;
- outerMostTime = true;
- }
- else {
- this.activeTimers.push(newActiveTimers);
- }
- this.activeTimers = newActiveTimers;
- f();
- // needing to split these up because util.flatten only accepts certain types
- const flattenedActiveTimerQueries = flatten(this.activeTimers.map((d) => d.query))
- .filter(d => d != null);
- const flattenedActiveTimerNames = flatten(this.activeTimers.map((d) => d.name))
- .filter(d => d != null);
- this.activeTimers = oldActiveTimers;
- if (outerMostTime) {
- this.programTimersStack = null;
- }
- const res = {
- uploadWaitMs: this.uploadWaitMs,
- downloadWaitMs: this.downloadWaitMs,
- kernelMs: null,
- wallMs: null // will be filled by the engine
- };
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
- const kernelMs = await Promise.all(flattenedActiveTimerQueries);
- res['kernelMs'] = sum(kernelMs);
- res['getExtraProfileInfo'] = () => kernelMs.map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
- .map(d => `${d.name}: ${d.ms}`)
- .join(', ');
- }
- else {
- res['kernelMs'] = {
- error: 'WebGL query timers are not supported in this environment.'
- };
- }
- this.uploadWaitMs = 0;
- this.downloadWaitMs = 0;
- return res;
- }
- memory() {
- return {
- unreliable: false,
- numBytesInGPU: this.numBytesInGPU,
- numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
- numBytesInGPUFree: this.textureManager.numBytesFree
- };
- }
- startTimer() {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
- return this.gpgpu.beginQuery();
- }
- return { startMs: now(), endMs: null };
- }
- endTimer(query) {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
- this.gpgpu.endQuery();
- return query;
- }
- query.endMs = now();
- return query;
- }
- async getQueryTime(query) {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
- return this.gpgpu.waitForQueryAndGetTime(query);
- }
- const timerQuery = query;
- return timerQuery.endMs - timerQuery.startMs;
- }
- disposeData(dataId) {
- if (this.pendingDisposal.has(dataId)) {
- return;
- }
- if (this.pendingRead.has(dataId)) {
- this.pendingDisposal.add(dataId);
- this.pendingDeletes++;
- return;
- }
- // No-op if already disposed.
- if (!this.texData.has(dataId)) {
- return;
- }
- // Trying to dispose a textureData that has a 'kept' refCount, e.g. trying
- // to dispose a tensor whose data bucket is shared with a complex tensor. In
- // this case we are removing a reference to the textureData, but we
- // shouldn't actually dispose the texture.
- if (this.texData.get(dataId).complexParentRefCount > 0) {
- this.texData.get(dataId).refCount--;
- return;
- }
- this.releaseGPUData(dataId);
- const { complexTensorInfos } = this.texData.get(dataId);
- if (complexTensorInfos != null) {
- this.texData.get(complexTensorInfos.real.dataId).complexParentRefCount--;
- this.disposeIntermediateTensorInfo(complexTensorInfos.real);
- this.texData.get(complexTensorInfos.imag.dataId).complexParentRefCount--;
- this.disposeIntermediateTensorInfo(complexTensorInfos.imag);
- }
- this.texData.delete(dataId);
- }
- releaseGPUData(dataId) {
- const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId);
- const key = slice && slice.origDataId || dataId;
- const refCount = this.dataRefCount.get(key);
- if (refCount > 1) {
- this.dataRefCount.set(key, refCount - 1);
- }
- else {
- this.dataRefCount.delete(key);
- if (texture != null) {
- this.numBytesInGPU -= this.computeBytes(texShape, dtype);
- this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
- }
- }
- const texData = this.texData.get(dataId);
- texData.texture = null;
- texData.texShape = null;
- texData.isPacked = false;
- texData.slice = null;
- }
- getTexture(dataId) {
- this.uploadToGPU(dataId);
- return this.texData.get(dataId).texture;
- }
- /**
- * Returns internal information for the specific data bucket. Used in unit
- * tests.
- */
- getDataInfo(dataId) {
- return this.texData.get(dataId);
- }
- getCPUBackend() {
- if (!env().getBool('WEBGL_CPU_FORWARD')) {
- return null;
- }
- if (this.cpuBackend == null) {
- this.cpuBackend = engine().findBackend('cpu');
- }
- return this.cpuBackend;
- }
- /*
- Tests whether all the inputs to an op are small and on the CPU. This heuristic
- determines when it would be faster to execute a kernel on the CPU. WebGL
- kernels opt into running this check and forwarding when appropriate.
- TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
- sustainable strategy for optimizing backend execution of ops.
- */
- shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) {
- const cpuBackend = this.getCPUBackend();
- if (!this.warnedAboutCPUBackend && cpuBackend == null) {
- console.warn('Your application contains ops that are small enough to be ' +
- 'executed on the CPU backend, however the CPU backend cannot ' +
- 'be found. Consider importing the CPU backend ' +
- '(@tensorflow/tfjs-backend-cpu) for better performance.');
- this.warnedAboutCPUBackend = true;
- }
- return cpuBackend != null &&
- inputs.every(input => this.texData.get(input.dataId).texture == null &&
- sizeFromShape(input.shape) < sizeThreshold);
- }
- getGPGPUContext() {
- return this.gpgpu;
- }
- slice(x, begin, size) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = sliceImplCPU(this.texData.get(x.dataId).values, begin, size, x.shape, x.dtype);
- return this.makeOutput(size, x.dtype, outValues);
- }
- // Short-circuit computation if the slice is zero-sized.
- if (sizeFromShape(size) === 0) {
- return tensor([], size, x.dtype);
- }
- const { isPacked } = this.texData.get(x.dataId);
- const isContinous = isSliceContinous(x.shape, begin, size);
- if (isPacked || !isContinous) {
- const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
- new SlicePackedProgram(size) :
- new SliceProgram(size);
- const customSetup = program.getCustomSetupFunc(begin);
- return this.compileAndRun(program, [x], null, customSetup);
- }
- this.uploadToGPU(x.dataId);
- return this.shallowSlice(x, begin, size);
- }
- shallowSlice(x, begin, size) {
- const xTexData = this.texData.get(x.dataId);
- const t = this.makeOutput(size, x.dtype);
- const newTexData = this.texData.get(t.dataId);
- // Copy texture data from the original tensor.
- Object.assign(newTexData, xTexData);
- newTexData.shape = size;
- newTexData.dtype = x.dtype;
- let flatOffset = computeFlatOffset(begin, x.strides);
- if (xTexData.slice) {
- // We are slicing an already sliced tensor, so we have to accumulate
- // the offset.
- flatOffset += xTexData.slice.flatOffset;
- }
- newTexData.slice = {
- flatOffset,
- // Point to the original dataId, which is used to do ref counting.
- origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
- };
- // Increase the ref count for that data bucket.
- const refCount = this.dataRefCount.get(newTexData.slice.origDataId) || 1;
- this.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
- return t;
- }
- stridedSlice(x, begin, end, strides) {
- const cpuRes = this.tryRunOnCpuOrThrow([x], () => this.cpuBackend.stridedSlice(x, begin, end, strides));
- if (cpuRes) {
- return cpuRes;
- }
- const outShape = computeOutShape(begin, end, strides);
- if (outShape.some(axis => axis === 0)) {
- return tensor([], outShape);
- }
- const program = new StridedSliceProgram(begin, strides, outShape);
- return this.compileAndRun(program, [x]);
- }
- reverse(x, axis) {
- const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
- new ReversePackedProgram(x.shape, axis) :
- new ReverseProgram(x.shape, axis);
- return this.compileAndRun(program, [x]);
- }
- neg(x) {
- const cpuRes = this.tryRunOnCpuOrThrow([x], () => this.cpuBackend.neg(x));
- if (cpuRes) {
- return cpuRes;
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, NEG, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, NEG);
- return this.compileAndRun(program, [x]);
- }
- batchMatMul(a, b, transposeA, transposeB) {
- const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
- const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
- const sharedDim = transposeA ? a.shape[1] : a.shape[2];
- const [batch, ,] = a.shape;
- // Since the matrices are vectors, it is faster to call mul().sum()
- // because sum() is O(sqrt(N)) due to divide-and-conquer.
- if ((outerShapeA === 1 || outerShapeB === 1) &&
- sharedDim > MATMUL_SHARED_DIM_THRESHOLD) {
- if (transposeA) {
- a = transpose(a, [0, 2, 1]);
- }
- if (transposeB) {
- b = transpose(b, [0, 2, 1]);
- }
- const a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1);
- const axis = outerShapeB === 1 ? 2 : 1;
- const b3D = outerShapeB === 1 ? b.as3D(batch, 1, sharedDim) : b;
- // TODO(annxingyuan): Call multiply directly as part of batchMatMul
- // modularization.
- const product = mul(a3D, b3D);
- return product.sum(axis, true /* keepDims */);
- }
- const dtype = upcastType(a.dtype, b.dtype);
- const program = new MatMulPackedProgram(a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB);
- return this.compileAndRun(program, [a, b], dtype);
- }
- fusedBatchMatMul({ a, b, transposeA, transposeB, bias, activation, preluActivationWeights }) {
- const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
- const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
- const [batch, ,] = a.shape;
- const dtype = upcastType(a.dtype, b.dtype);
- const hasBias = bias != null;
- const hasPreluActivationWeights = preluActivationWeights != null;
- const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
- const program = new MatMulPackedProgram(a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
- const inputs = [a, b];
- if (bias) {
- inputs.push(bias);
- }
- if (preluActivationWeights) {
- inputs.push(preluActivationWeights);
- }
- return this.compileAndRun(program, inputs, dtype);
- }
- localResponseNormalization4D(x, radius, bias, alpha, beta) {
- const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
- new LRNPackedProgram(x.shape, radius, bias, alpha, beta) :
- new LRNProgram(x.shape, radius, bias, alpha, beta);
- return this.compileAndRun(program, [x]);
- }
- LRNGrad(dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
- const program = new LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
- return this.compileAndRun(program, [inputImage, outputImage, dy]);
- }
- tile(x, reps) {
- if (x.dtype === 'string') {
- const data = this.readSync(x.dataId);
- const decodedData = data.map(d => decodeString(d));
- const buf = buffer(x.shape, x.dtype, decodedData);
- return tile$4(buf, reps);
- }
- const program = new TileProgram(x.shape, reps);
- return this.compileAndRun(program, [x]);
- }
- pad(x, paddings, constantValue) {
- const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
- new PadPackedProgram(x.shape, paddings, constantValue) :
- new PadProgram(x.shape, paddings, constantValue);
- return this.compileAndRun(program, [x]);
- }
- gather(x, indices, axis) {
- const cpuRes = this.tryRunOnCpuOrThrow([x, indices], () => this.cpuBackend.gather(x, indices, axis));
- if (cpuRes) {
- return cpuRes;
- }
- const program = new GatherProgram(x.shape, indices.size, axis);
- return this.compileAndRun(program, [x, indices]);
- }
- batchToSpaceND(x, blockShape, crops) {
- assert(x.rank <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
- 'implemented yet');
- const prod = blockShape.reduce((a, b) => a * b);
- const reshaped = getReshaped(x.shape, blockShape, prod);
- const permuted = getPermuted(reshaped.length, blockShape.length);
- const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
- const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
- const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
- return transpose(x.reshape(reshaped), permuted)
- .reshape(reshapedPermuted)
- .slice(sliceBeginCoords, sliceSize);
- }
- spaceToBatchND(x, blockShape, paddings) {
- assert(x.rank <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
- 'implemented yet');
- const prod = blockShape.reduce((a, b) => a * b);
- const completePaddings = [[0, 0]];
- completePaddings.push(...paddings);
- for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
- completePaddings.push([0, 0]);
- }
- const paddedX = x.pad(completePaddings);
- const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
- const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
- const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
- const paddedXT = transpose(paddedX.reshape(reshapedPaddedShape), permutedReshapedPaddedPermutation);
- return reshape(paddedXT, flattenShape);
- }
- reduce(x, reduceType, dtype) {
- const batchSize = x.shape[0];
- const inSize = x.shape[1];
- const windowSize = computeOptimalWindowSize(inSize);
- const outSize = Math.ceil(inSize / windowSize);
- const reduceInfo = { windowSize, inSize, batchSize, outSize };
- const program = new ReduceProgram(reduceInfo, reduceType);
- const output = this.compileAndRun(program, [x], dtype);
- // No need to run another GPGPU program.
- if (output.shape[1] === 1) {
- return output;
- }
- return this.reduce(output, reduceType, dtype);
- }
- argReduce(x, reduceType, bestIndicesA = null) {
- let batchSize = x.shape[0];
- let inSize = x.shape[1];
- if (bestIndicesA != null) {
- batchSize = bestIndicesA.shape[0];
- inSize = bestIndicesA.shape[1];
- }
- const windowSize = computeOptimalWindowSize(inSize);
- const reduceInfo = {
- windowSize,
- inSize,
- batchSize,
- outSize: Math.ceil(inSize / windowSize)
- };
- const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
- const inputs = [x];
- if (bestIndicesA != null) {
- inputs.push(bestIndicesA);
- }
- const output = this.compileAndRun(program, inputs, 'int32');
- // No need to run another GPGPU program.
- if (output.shape[1] === 1) {
- return output;
- }
- return this.argReduce(x, reduceType, output);
- }
- argReducePacked(x, reduceType, bestIndicesA = null) {
- const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
- const inSize = inShape[inShape.length - 1];
- const windowSize = computeOptimalWindowSize(inSize);
- const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
- const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
- const output = this.compileAndRun(program, inputs, 'int32');
- if (output.rank === x.rank) {
- return this.argReducePacked(x, reduceType, output);
- }
- return output;
- }
- sum(x, axes) {
- assertAxesAreInnerMostDims('sum', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- const outputDType = sumOutType(x.dtype);
- return this.reduce(a2D, 'sum', outputDType).reshape(outShape);
- }
- prod(x, axes) {
- const cpuRes = this.tryRunOnCpuOrThrow([x], () => this.cpuBackend.prod(x, axes));
- if (cpuRes) {
- return cpuRes;
- }
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- const outputDType = sumOutType(x.dtype);
- return this.reduce(a2D, 'prod', outputDType).reshape(outShape);
- }
- unsortedSegmentSum(x, segmentIds, numSegments) {
- let axis = 0;
- const permutation = getAxesPermutation([axis], x.rank);
- let permutedX = x;
- if (permutation != null) {
- permutedX = transpose(x, permutation);
- axis = getInnerMostAxes(1, x.rank)[0];
- }
- const outShape = segment_util$1.computeOutShape(permutedX.shape, axis, numSegments);
- const inSize = sizeFromShape([permutedX.shape[axis]]);
- const a2D = permutedX.as2D(-1, inSize);
- const outputDType = sumOutType(x.dtype);
- let result = this.segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)
- .reshape(outShape);
- if (permutation != null) {
- result =
- transpose(result, getUndoAxesPermutation(permutation));
- }
- return result;
- }
- segOpCompute(x, segOpType, segmentIds, dtype, numSegments) {
- const batchSize = x.shape[0];
- const inSize = x.shape[1];
- const windowSize = segment_util$1.segOpComputeOptimalWindowSize(inSize, numSegments);
- const segOpInfo = { windowSize, inSize, batchSize, numSegments };
- const program = new SegmentOpProgram(segOpInfo, segOpType);
- const output = this.compileAndRun(program, [x, segmentIds], dtype);
- // No need to run another GPGPU program.
- if (output.shape[1] === numSegments) {
- return output;
- }
- segmentIds = range(0, numSegments).tile([inSize / windowSize]);
- return this.segOpCompute(output, segOpType, segmentIds, dtype, numSegments);
- }
- argMinMaxReduce(x, axis, reduceType) {
- const axes = [axis];
- assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.rank);
- if (!env().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) {
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- return this.argReduce(a2D, reduceType).reshape(outShape);
- }
- return this.argReducePacked(x, reduceType);
- }
- argMin(x, axis) {
- return this.argMinMaxReduce(x, axis, 'min');
- }
- argMax(x, axis) {
- return this.argMinMaxReduce(x, axis, 'max');
- }
- cumsum(x, axis, exclusive, reverse) {
- if (axis !== x.rank - 1) {
- throw new Error(`WebGL cumsum shader expects an inner-most axis=${x.rank - 1} ` +
- `but got axis=${axis}`);
- }
- const size = x.shape[axis];
- let result = x;
- // Use cumsum parallel algorithm, ref:
- // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
- for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
- const program = new CumSumProgram(x.shape, false, reverse);
- const customSetup = program.getCustomSetupFunc(i);
- const prevResult = result;
- result = this.compileAndRun(program, [result], result.dtype, customSetup);
- prevResult.dispose();
- }
- // For exclusive cumsum, shift the end result in the direction of sum and
- // add 0 to the front index.
- if (exclusive) {
- const program = new CumSumProgram(x.shape, exclusive, reverse);
- const prevResult = result;
- result = this.compileAndRun(program, [result]);
- prevResult.dispose();
- }
- return result;
- }
- equal(a, b) {
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, EQUAL$1, 'bool');
- }
- const program = new BinaryOpProgram(EQUAL, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- less(a, b) {
- const cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.less(a, b));
- if (cpuRes) {
- return cpuRes;
- }
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, LESS$1, 'bool');
- }
- const program = new BinaryOpProgram(LESS, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- lessEqual(a, b) {
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, LESS_EQUAL$1, 'bool');
- }
- const program = new BinaryOpProgram(LESS_EQUAL, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- greater(a, b) {
- const cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.greater(a, b));
- if (cpuRes) {
- return cpuRes;
- }
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, GREATER$1, 'bool');
- }
- const program = new BinaryOpProgram(GREATER, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- greaterEqual(a, b) {
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, GREATER_EQUAL$1, 'bool');
- }
- const program = new BinaryOpProgram(GREATER_EQUAL, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- logicalNot(x) {
- const program = new UnaryOpProgram(x.shape, LOGICAL_NOT);
- return this.compileAndRun(program, [x]);
- }
- logicalAnd(a, b) {
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, LOGICAL_AND$1, 'bool');
- }
- const program = new BinaryOpProgram(LOGICAL_AND, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- logicalOr(a, b) {
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, LOGICAL_OR$1, 'bool');
- }
- const program = new BinaryOpProgram(LOGICAL_OR, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], 'bool');
- }
- select(condition, a, b) {
- const program = new SelectProgram(condition.rank, a.shape, a.rank);
- return this.compileAndRun(program, [condition, a, b], upcastType(a.dtype, b.dtype));
- }
- where(condition) {
- warn('tf.where() in webgl locks the UI thread. ' +
- 'Call tf.whereAsync() instead');
- const condVals = condition.dataSync();
- return whereImpl$2(condition.shape, condVals);
- }
- topk(x, k, sorted) {
- const xVals = x.dataSync();
- return topkImpl$2(xVals, x.shape, x.dtype, k, sorted);
- }
- min(x, axes) {
- assertAxesAreInnerMostDims('min', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape);
- }
- minimum(a, b) {
- const cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.minimum(a, b));
- if (cpuRes) {
- return cpuRes;
- }
- const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
- new BinaryOpPackedProgram(MIN$1, a.shape, b.shape) :
- new BinaryOpProgram(MIN, a.shape, b.shape);
- return this.compileAndRun(program, [a, b]);
- }
- mod(a, b) {
- const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
- new BinaryOpPackedProgram(MOD$1, a.shape, b.shape) :
- new BinaryOpProgram(MOD, a.shape, b.shape);
- return this.compileAndRun(program, [a, b]);
- }
- maximum(a, b) {
- const cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.maximum(a, b));
- if (cpuRes) {
- return cpuRes;
- }
- const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
- new BinaryOpPackedProgram(MAX$1, a.shape, b.shape) :
- new BinaryOpProgram(MAX, a.shape, b.shape);
- return this.compileAndRun(program, [a, b]);
- }
- all(x, axes) {
- assertAxesAreInnerMostDims('all', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- return this.reduce(a2D, 'all', a2D.dtype).reshape(outShape);
- }
- any(x, axes) {
- assertAxesAreInnerMostDims('any', axes, x.rank);
- const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
- const inSize = sizeFromShape(reduceShape);
- const a2D = x.as2D(-1, inSize);
- return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape);
- }
- floorDiv(a, b) {
- const op = INT_DIV;
- const outputDtype = 'int32';
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- return this.packedBinaryOp(a, b, INT_DIV$1, outputDtype);
- }
- const program = new BinaryOpProgram(op, a.shape, b.shape);
- return this.compileAndRun(program, [a, b], outputDtype);
- }
- packedUnaryOp(x, op, dtype) {
- const program = new UnaryOpPackedProgram(x.shape, op);
- return this.compileAndRun(program, [x], dtype);
- }
- packedBinaryOp(a, b, op, dtype, checkOutOfBounds = false) {
- const program = new BinaryOpPackedProgram(op, a.shape, b.shape, checkOutOfBounds);
- return this.compileAndRun(program, [a, b], dtype);
- }
- // Returns a TensorInfo with the complex shape and the dataId of the
- // underlying part. We need to do this because a reshaped complex tensor is
- // not reflected in its parts.
- makeComplexComponentTensorInfo(complexTensor, complexPart) {
- return {
- dataId: complexPart.dataId,
- dtype: complexPart.dtype,
- shape: complexTensor.shape
- };
- }
- addN(tensors) {
- if (tensors.length === 1) {
- return tensors[0];
- }
- // Limit the number of uploaded textures for optimization.
- if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) {
- const midIndex = Math.floor(tensors.length / 2);
- const leftSide = this.addN(tensors.slice(0, midIndex));
- const rightSide = this.addN(tensors.slice(midIndex));
- return this.addN([leftSide, rightSide]);
- }
- const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
- const shapes = tensors.map(t => t.shape);
- // We can make sure shapes are identical in op level.
- const usePackedOp = env().getBool('WEBGL_PACK');
- const program = usePackedOp ?
- new AddNPackedProgram(tensors[0].shape, shapes) :
- new AddNProgram(tensors[0].shape, shapes);
- return this.compileAndRun(program, tensors, dtype);
- }
- pow(a, b) {
- const usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
- const program = usePackedOp ?
- new BinaryOpPackedProgram(POW$1, a.shape, b.shape) :
- new BinaryOpProgram(POW, a.shape, b.shape);
- const dtype = upcastType(a.dtype, b.dtype);
- return this.compileAndRun(program, [a, b], dtype);
- }
- ceil(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = ceilImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, CEIL, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, CEIL);
- return this.compileAndRun(program, [x]);
- }
- floor(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = floorImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, FLOOR, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, FLOOR);
- return this.compileAndRun(program, [x]);
- }
- sign(x) {
- const program = new UnaryOpProgram(x.shape, SIGN);
- return this.compileAndRun(program, [x]);
- }
- isNaN(x) {
- const program = new UnaryOpProgram(x.shape, IS_NAN);
- return this.compileAndRun(program, [x], 'bool');
- }
- isInf(x) {
- const program = new UnaryOpProgram(x.shape, IS_INF);
- return this.compileAndRun(program, [x], 'bool');
- }
- isFinite(x) {
- const program = new UnaryOpProgram(x.shape, IS_FINITE);
- return this.compileAndRun(program, [x], 'bool');
- }
- round(x) {
- const program = new UnaryOpProgram(x.shape, ROUND);
- return this.compileAndRun(program, [x]);
- }
- exp(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = expImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, EXP, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, EXP);
- return this.compileAndRun(program, [x]);
- }
- expm1(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = expm1ImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, EXPM1, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, EXPM1);
- return this.compileAndRun(program, [x]);
- }
- softmax(logits, dim) {
- const axes = parseAxisParam([dim], logits.shape);
- // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
- // modularization.
- const maxLogit = max(logits, axes);
- const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
- // TODO(annxingyuan): Call sub directly as part of softmax kernel
- // modularization.
- const a = sub(logits, maxLogit.reshape(expandedShape));
- const b = this.exp(a);
- const sumExp = this.sum(b, axes).reshape(expandedShape);
- // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
- // modularization.
- return div(b, sumExp);
- }
- log(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = logImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, LOG$1, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, LOG);
- return this.compileAndRun(program, [x]);
- }
- log1p(x) {
- const program = new UnaryOpProgram(x.shape, LOG1P);
- return this.compileAndRun(program, [x]);
- }
- sqrt(x) {
- const program = new UnaryOpProgram(x.shape, SQRT);
- return this.compileAndRun(program, [x]);
- }
- rsqrt(x) {
- if (this.shouldExecuteOnCPU([x])) {
- const outValues = rsqrtImplCPU(this.texData.get(x.dataId).values, x.dtype);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- const program = new UnaryOpProgram(x.shape, RSQRT);
- return this.compileAndRun(program, [x]);
- }
- reciprocal(x) {
- const program = new UnaryOpProgram(x.shape, RECIPROCAL);
- return this.compileAndRun(program, [x]);
- }
- relu(x) {
- let program;
- if (env().getBool('WEBGL_PACK')) {
- program = new UnaryOpPackedProgram(x.shape, RELU$1);
- }
- else {
- program = new UnaryOpProgram(x.shape, RELU);
- }
- return this.compileAndRun(program, [x]);
- }
- relu6(x) {
- let program;
- if (env().getBool('WEBGL_PACK')) {
- program = new UnaryOpPackedProgram(x.shape, RELU6$1);
- }
- else {
- program = new UnaryOpProgram(x.shape, RELU6);
- }
- return this.compileAndRun(program, [x]);
- }
- prelu(x, alpha) {
- const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
- new BinaryOpPackedProgram(PRELU$1, x.shape, alpha.shape) :
- new BinaryOpProgram(PRELU, x.shape, alpha.shape);
- return this.compileAndRun(program, [x, alpha]);
- }
- elu(x) {
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, ELU$2, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, ELU$1);
- return this.compileAndRun(program, [x]);
- }
- eluDer(dy, y) {
- const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
- new BinaryOpPackedProgram(ELU_DER$1, dy.shape, y.shape) :
- new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
- return this.compileAndRun(program, [dy, y]);
- }
- selu(x) {
- const program = new UnaryOpProgram(x.shape, SELU);
- return this.compileAndRun(program, [x]);
- }
- clip(x, min, max) {
- let program;
- if (env().getBool('WEBGL_PACK_CLIP')) {
- program = new ClipPackedProgram(x.shape);
- }
- else {
- program = new ClipProgram(x.shape);
- }
- const customSetup = program.getCustomSetupFunc(min, max);
- return this.compileAndRun(program, [x], null, customSetup);
- }
- abs(x) {
- // TODO: handle cases when x is complex.
- if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
- const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
- return this.makeOutput(x.shape, x.dtype, outValues);
- }
- if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
- return this.packedUnaryOp(x, ABS, x.dtype);
- }
- const program = new UnaryOpProgram(x.shape, ABS);
- return this.compileAndRun(program, [x]);
- }
- complexAbs(x) {
- const xData = this.texData.get(x.dataId);
- const program = new ComplexAbsProgram(x.shape);
- const inputs = [
- this.makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real),
- this.makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag),
- ];
- return this.compileAndRun(program, inputs);
- }
- sigmoid(x) {
- const program = new UnaryOpProgram(x.shape, SIGMOID);
- return this.compileAndRun(program, [x]);
- }
- softplus(x) {
- const program = new UnaryOpProgram(x.shape, SOFTPLUS);
- return this.compileAndRun(program, [x]);
- }
- asin(x) {
- const program = new UnaryOpProgram(x.shape, ASIN);
- return this.compileAndRun(program, [x]);
- }
- acos(x) {
- const program = new UnaryOpProgram(x.shape, ACOS);
- return this.compileAndRun(program, [x]);
- }
- atan(x) {
- const program = new UnaryOpProgram(x.shape, ATAN);
- return this.compileAndRun(program, [x]);
- }
- sinh(x) {
- const program = new UnaryOpProgram(x.shape, SINH);
- return this.compileAndRun(program, [x]);
- }
- cosh(x) {
- const program = new UnaryOpProgram(x.shape, COSH);
- return this.compileAndRun(program, [x]);
- }
- tanh(x) {
- const program = new UnaryOpProgram(x.shape, TANH);
- return this.compileAndRun(program, [x]);
- }
- asinh(x) {
- const program = new UnaryOpProgram(x.shape, ASINH);
- return this.compileAndRun(program, [x]);
- }
- acosh(x) {
- const program = new UnaryOpProgram(x.shape, ACOSH);
- return this.compileAndRun(program, [x]);
- }
- atanh(x) {
- const program = new UnaryOpProgram(x.shape, ATANH);
- return this.compileAndRun(program, [x]);
- }
- erf(x) {
- const program = new UnaryOpProgram(x.shape, ERF);
- return this.compileAndRun(program, [x]);
- }
- step(x, alpha) {
- const program = new UnaryOpProgram(x.shape, STEP(alpha));
- return this.compileAndRun(program, [x]);
- }
- conv2dByMatMul(x, filter, convInfo, bias, activation, preluActivationWeights) {
- // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
- // result from 2D to 4D.
- const xShape = x.shape;
- const xTexData = this.texData.get(x.dataId);
- const sharedMatMulDim = convInfo.inChannels;
- const outerShapeX = xShape[0] * xShape[1] * xShape[2];
- const outerShapeFilter = convInfo.outChannels;
- const isChannelsLast = convInfo.dataFormat === 'channelsLast';
- const transposeA = false;
- const transposeB = false;
- // TODO: Once reduction ops are packed, batchMatMul will always be packed
- // and we can remove this condition.
- const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
- sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
- const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked;
- if (batchMatMulWillBeUnpacked || !env().getBool('WEBGL_LAZILY_UNPACK') ||
- !env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ||
- !reshapeWillBeExpensive) {
- const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] :
- xShape[0] * xShape[2] * xShape[3];
- const xReshaped = reshape(x, [1, targetShape, convInfo.inChannels]);
- const filterReshaped = reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
- const result = this.fusedBatchMatMul({
- a: xReshaped,
- b: filterReshaped,
- transposeA,
- transposeB,
- bias,
- activation,
- preluActivationWeights
- });
- return reshape(result, convInfo.outShape);
- }
- // Following optimization is specific to packed |x| with odd row count
- // (For example, in channelLast mode, 'row count' refers to x.shape[2]):
- // we avoid expensive packed 2x2 reshape by padding row count to next,
- // even number. When x.shape[2] is odd, the result of packed batchMatMul is
- // the same (has the same texture layout and and values in the texture) as
- // it is for even x.shape[2] + 1. We make the odd-rows tensor to look like
- // even-rows tensor before the operation and, after the batchMatMul,
- // fix the even-rows result to have odd number of rows.
- const targetShape = isChannelsLast ?
- xShape[0] * xShape[1] * (xShape[2] + 1) :
- xShape[0] * xShape[2] * (xShape[3] + 1);
- const xReshaped = {
- dataId: x.dataId,
- shape: [1, targetShape, convInfo.inChannels],
- dtype: x.dtype
- };
- // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
- // Decrementing row count, after batchMatMul->...->compileProgram leads to
- // invalid row count within the reference in GPGPUBinary.inShapeInfos.
- // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
- // in compileProgram method, but that would affect compilation of all
- // programs - instead, provide a copy here, with even row count, before
- // calling batchMatMul->...->compileProgram and after that, the original
- // xTexData.shape is restored.
- const originalXTexDataShape = xTexData.shape;
- xTexData.shape = xTexData.shape.slice();
- xTexData.shape[xTexData.shape.length - 2]++;
- assert(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
- const filterReshaped = reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
- const pointwiseConv = this.fusedBatchMatMul({
- a: xReshaped,
- b: filterReshaped,
- transposeA,
- transposeB,
- bias,
- activation,
- preluActivationWeights
- });
- const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
- assert(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
- // Restore the input shape to original.
- xTexData.shape = originalXTexDataShape;
- // Set the output shape - there is no need for expensive reshape as data
- // layout is already correct.
- pointwiseConvTexData.shape = convInfo.outShape;
- return engine().makeTensorFromDataId(pointwiseConv.dataId, convInfo.outShape, pointwiseConv.dtype);
- }
- conv2dWithIm2Row(x, filter, convInfo, bias, activation, preluActivationWeights) {
- // Rearranges conv2d input so each block to be convolved over forms the
- // column of a new matrix with shape [filterWidth * filterHeight *
- // inChannels, outHeight * outWidth]. The filter is also rearranged so each
- // output channel forms a row of a new matrix with shape [outChannels,
- // filterWidth * filterHeight * inChannels]. The convolution is then
- // computed by multiplying these matrices and reshaping the result.
- const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
- const isChannelsLast = dataFormat === 'channelsLast';
- const sharedDim = filterWidth * filterHeight * inChannels;
- const numCols = outHeight * outWidth;
- const x2ColShape = [sharedDim, numCols];
- const transposeA = true;
- const transposeB = false;
- const xSqueezed = x.squeeze([0]);
- const w2Row = filter.reshape([1, sharedDim, -1]);
- const im2ColProgram = new Im2ColPackedProgram(x2ColShape, xSqueezed.shape, convInfo);
- const im2Col = this.compileAndRun(im2ColProgram, [xSqueezed]).reshape([
- 1, x2ColShape[0], x2ColShape[1]
- ]);
- const hasBias = bias != null;
- const hasPreluActivationWeights = preluActivationWeights != null;
- const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
- const matmulProgram = new MatMulPackedProgram(im2Col.shape, [1, numCols, convInfo.outChannels], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
- const inputs = [im2Col, w2Row];
- if (bias) {
- inputs.push(bias);
- }
- if (hasPreluActivationWeights) {
- inputs.push(preluActivationWeights);
- }
- const product = this.compileAndRun(matmulProgram, inputs);
- if (isChannelsLast) {
- return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
- }
- else {
- return product.reshape([1, convInfo.outChannels, outHeight, outWidth]);
- }
- }
- fusedConv2d({ input, filter, convInfo, bias, activation, preluActivationWeights }) {
- if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
- convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
- convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
- (convInfo.padInfo.type === 'SAME' ||
- convInfo.padInfo.type === 'VALID')) {
- return this.conv2dByMatMul(input, filter, convInfo, bias, activation, preluActivationWeights);
- }
- if (env().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) {
- return this.conv2dWithIm2Row(input, filter, convInfo, bias, activation, preluActivationWeights);
- }
- const hasBias = bias != null;
- const hasPreluActivationWeights = preluActivationWeights != null;
- const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
- const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
- const inputs = [input, filter];
- if (bias) {
- inputs.push(bias);
- }
- if (preluActivationWeights) {
- inputs.push(preluActivationWeights);
- }
- return this.compileAndRun(program, inputs);
- }
- conv2d(x, filter, convInfo) {
- if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
- convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
- convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
- (convInfo.padInfo.type === 'SAME' ||
- convInfo.padInfo.type === 'VALID')) {
- return this.conv2dByMatMul(x, filter, convInfo);
- }
- if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
- return this.conv2dWithIm2Row(x, filter, convInfo);
- }
- const program = new Conv2DProgram(convInfo);
- return this.compileAndRun(program, [x, filter]);
- }
- conv2dDerInput(dy, filter, convInfo) {
- const program = new Conv2DDerInputProgram(convInfo);
- return this.compileAndRun(program, [dy, filter]);
- }
- conv2dDerFilter(x, dy, convInfo) {
- const program = new Conv2DDerFilterProgram(convInfo);
- return this.compileAndRun(program, [x, dy]);
- }
- fusedDepthwiseConv2D({ input, filter, convInfo, bias, activation, preluActivationWeights }) {
- const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
- convInfo.strideWidth <= 2 &&
- convInfo.outChannels / convInfo.inChannels === 1;
- const fusedActivation = activation ?
- mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
- null;
- const inputs = [input, filter];
- const hasBias = bias != null;
- const hasPreluActivationWeights = preluActivationWeights != null;
- if (hasBias) {
- inputs.push(bias);
- }
- if (hasPreluActivationWeights) {
- inputs.push(preluActivationWeights);
- }
- let program;
- if (shouldPackDepthwiseConv) {
- program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
- return this.compileAndRun(program, inputs);
- }
- program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
- return this.compileAndRun(program, inputs);
- }
- depthwiseConv2D(x, filter, convInfo) {
- let program;
- if (env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
- convInfo.strideWidth <= 2 &&
- convInfo.outChannels / convInfo.inChannels === 1) {
- program = new DepthwiseConvPacked2DProgram(convInfo);
- return this.compileAndRun(program, [x, filter]);
- }
- program = new DepthwiseConv2DProgram(convInfo);
- return this.compileAndRun(program, [x, filter]);
- }
- depthwiseConv2DDerInput(dy, filter, convInfo) {
- const program = new DepthwiseConv2DDerInputProgram(convInfo);
- return this.compileAndRun(program, [dy, filter]);
- }
- depthwiseConv2DDerFilter(x, dy, convInfo) {
- const program = new DepthwiseConv2DDerFilterProgram(convInfo);
- return this.compileAndRun(program, [x, dy]);
- }
- conv3d(x, filter, convInfo) {
- const program = new Conv3DProgram(convInfo);
- return this.compileAndRun(program, [x, filter]);
- }
- conv3dDerInput(dy, filter, convInfo) {
- const program = new Conv3DDerInputProgram(convInfo);
- return this.compileAndRun(program, [dy, filter]);
- }
- conv3dDerFilter(x, dy, convInfo) {
- const program = new Conv3DDerFilterProgram(convInfo);
- return this.compileAndRun(program, [x, dy]);
- }
- unstack(x, axis) {
- const num = x.shape[axis];
- const outShape = new Array(x.rank - 1);
- let outIndex = 0;
- for (let i = 0; i < x.rank; i++) {
- if (i !== axis) {
- outShape[outIndex++] = x.shape[i];
- }
- }
- const begin = new Array(x.rank).fill(0);
- const size = x.shape.slice();
- size[axis] = 1;
- const res = new Array(num);
- for (let i = 0; i < res.length; i++) {
- begin[axis] = i;
- res[i] = this.slice(x, begin, size).reshape(outShape);
- }
- return res;
- }
- avgPool3d(x, convInfo) {
- const program = new Pool3DProgram(convInfo, 'avg', false);
- return this.compileAndRun(program, [x], 'float32');
- }
- avgPool3dBackprop(dy, x, convInfo) {
- const avgPool3dBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
- return this.compileAndRun(avgPool3dBackpropProgram, [dy], x.dtype);
- }
- maxPool3d(x, convInfo) {
- const program = new Pool3DProgram(convInfo, 'max', false);
- return this.compileAndRun(program, [x], 'float32');
- }
- maxPool3dBackprop(dy, x, y, convInfo) {
- const getPositions = true;
- const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', getPositions);
- const maxPool3dPositions = this.compileAndRun(maxPool3dPositionsProgram, [x]);
- const maxPool3dBackPropProgram = new MaxPool3DBackpropProgram(convInfo);
- const result = this.compileAndRun(maxPool3dBackPropProgram, [dy, maxPool3dPositions], x.dtype);
- maxPool3dPositions.dispose();
- return result;
- }
- resizeBilinear(x, newHeight, newWidth, alignCorners) {
- const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
- new ResizeBilinearPackedProgram(x.shape, newHeight, newWidth, alignCorners) :
- new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);
- return this.compileAndRun(program, [x], 'float32');
- }
- resizeBilinearBackprop(dy, x, alignCorners) {
- const program = new ResizeBilinearBackpropProgram(dy, x, alignCorners);
- return this.compileAndRun(program, [dy]);
- }
- resizeNearestNeighbor(x, newHeight, newWidth, alignCorners) {
- const program = new ResizeNearestNeighborProgram(x.shape, newHeight, newWidth, alignCorners);
- return this.compileAndRun(program, [x]);
- }
- resizeNearestNeighborBackprop(dy, x, alignCorners) {
- const program = new ResizeNearestNeigborBackpropProgram(dy, x, alignCorners);
- return this.compileAndRun(program, [dy]);
- }
- multinomial(logits, normalized, numSamples, seed) {
- const probs = normalized ? logits : softmax(logits);
- const batchSize = probs.shape[0];
- const numOutcomes = probs.shape[1];
- const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
- const customSetup = program.getCustomSetupFunc(seed);
- return this.compileAndRun(program, [probs], 'int32', customSetup);
- }
- oneHot(indices, depth, onValue, offValue) {
- const program = new OneHotProgram(indices.size, depth, onValue, offValue);
- return this.compileAndRun(program, [indices]);
- }
- diag(x) {
- const program = new DiagProgram(x.size);
- return this.compileAndRun(program, [x]);
- }
- cropAndResize(image, boxes, boxIndex, cropSize, method, extrapolationValue) {
- const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
- return this.compileAndRun(program, [image, boxes, boxIndex], 'float32');
- }
- depthToSpace(x, blockSize, dataFormat) {
- assert(blockSize > 1, () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);
- const batchSize = x.shape[0];
- const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2];
- const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3];
- const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1];
- const outputHeight = inputHeight * blockSize;
- const outputWidth = inputWidth * blockSize;
- const outputDepth = inputDepth / (blockSize * blockSize);
- const outputShape = (dataFormat === 'NHWC') ?
- [batchSize, outputHeight, outputWidth, outputDepth] :
- [batchSize, outputDepth, outputHeight, outputWidth];
- const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
- return this.compileAndRun(program, [x]);
- }
- split(x, sizeSplits, axis) {
- return split$5(x, sizeSplits, axis);
- }
- scatterND(indices, updates, shape) {
- const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
- const flattenShape = [outputSize / sliceSize, sliceSize];
- const flattenIndices = indices.reshape([numUpdates, sliceRank]);
- const flattenX = updates.reshape([numUpdates, sliceSize]);
- if (outputSize === 0) {
- return reshapeTensor(tensor([]), shape);
- }
- const defaultValue = scalar(0);
- const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.rank, flattenX.rank, strides, flattenShape);
- const res = this.compileAndRun(program, [flattenX, flattenIndices, defaultValue]);
- return res.reshape(shape);
- }
- sparseToDense(sparseIndices, sparseValues, outputShape, defaultValue) {
- const { sliceRank, numUpdates, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
- const sumDupeIndices = false;
- const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.rank, sparseValues.rank, strides, [outputSize, 1], sumDupeIndices);
- const res = this.compileAndRun(program, [sparseValues, sparseIndices, defaultValue]);
- return res.reshape(outputShape);
- }
- gatherND(x, indices) {
- const indicesShape = indices.shape;
- const sliceRank = indicesShape[indicesShape.length - 1];
- const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(x, indices);
- const flattenIndices = indices.reshape([numSlices, sliceRank]);
- const flattenX = x.reshape([x.size / sliceSize, sliceSize]);
- const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
- const res = this.compileAndRun(program, [flattenX, flattenIndices]);
- return res.reshape(resultShape);
- }
- fill(shape, value, dtype) {
- dtype = dtype || inferDtype(value);
- if (dtype === 'string') {
- // String type should be handled in CPU memory.
- const values = getArrayFromDType(dtype, sizeFromShape(shape));
- values.fill(value);
- return engine().makeTensor(values, shape, dtype, this);
- }
- else {
- const program = new FillProgram(shape, value);
- const customSetup = program.getCustomSetupFunc(value);
- return this.compileAndRun(program, [], dtype, customSetup);
- }
- }
- onesLike(x) {
- if (x.dtype === 'string') {
- throw new Error('onesLike is not supported under string dtype');
- }
- else {
- // TODO(cais, smilkov): Add WebGL shader for onesLike:
- // https://github.com/tensorflow/tfjs/issues/1293
- return this.fill(x.shape, 1, x.dtype);
- }
- }
- zerosLike(x) {
- return this.fill(x.shape, x.dtype === 'string' ? '' : 0, x.dtype);
- }
- linspace(start, stop, num) {
- // TODO: Use CPU implementation due to the precision problem in Safari.
- return linspaceImpl(start, stop, num);
- }
- makeTensorInfo(shape, dtype, values) {
- const dataId = this.write(values, shape, dtype);
- this.texData.get(dataId).usage = null;
- return { dataId, shape, dtype };
- }
- makeOutput(shape, dtype, values) {
- const { dataId } = this.makeTensorInfo(shape, dtype, values);
- return engine().makeTensorFromDataId(dataId, shape, dtype, this);
- }
- unpackTensor(input) {
- const program = new UnpackProgram(input.shape);
- return this.runWebGLProgram(program, [input], input.dtype);
- }
- packTensor(input) {
- const program = new PackProgram(input.shape);
- const preventEagerUnpackingOutput = true;
- return this.runWebGLProgram(program, [input], input.dtype, null /* customSetup */, preventEagerUnpackingOutput);
- }
- packedReshape(input, afterShape) {
- const input3DShape = [
- getBatchDim(input.shape),
- ...getRowsCols(input.shape)
- ];
- const input3D = {
- dtype: input.dtype,
- shape: input3DShape,
- dataId: input.dataId
- };
- const afterShapeAs3D = [
- getBatchDim(afterShape), ...getRowsCols(afterShape)
- ];
- const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
- const preventEagerUnpackingOfOutput = true;
- const output = this.runWebGLProgram(program, [input3D], input.dtype, null /* customSetup */, preventEagerUnpackingOfOutput);
- return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
- }
- decode(dataId) {
- const texData = this.texData.get(dataId);
- const { isPacked, shape, dtype } = texData;
- const shapeAs3D = getShapeAs3D(shape);
- let program;
- if (isPacked) {
- program = new DecodeMatrixPackedProgram(shapeAs3D);
- }
- else {
- program = new DecodeMatrixProgram(shapeAs3D);
- }
- const preventEagerUnpackingOfOutput = true;
- const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, null /* customSetup */, preventEagerUnpackingOfOutput);
- return { dtype, shape, dataId: out.dataId };
- }
- runWebGLProgram(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput = false) {
- const output = this.makeTensorInfo(program.outputShape, outputDtype);
- const outData = this.texData.get(output.dataId);
- if (program.packedOutput) {
- outData.isPacked = true;
- }
- if (program.outPackingScheme === PackingScheme.DENSE) {
- const texelShape = getDenseTexShape(program.outputShape);
- // For a densely packed output, we explicitly set texShape
- // so it doesn't get assigned later according to our typical packing
- // scheme wherein a single texel can only contain values from adjacent
- // rows/cols.
- outData.texShape = texelShape.map(d => d * 2);
- }
- if (program.outTexUsage != null) {
- outData.usage = program.outTexUsage;
- }
- if (sizeFromShape(output.shape) === 0) {
- // Short-circuit the computation since the result is empty (has 0 in its
- // shape).
- outData.values =
- getTypedArrayFromDType(output.dtype, 0);
- return output;
- }
- const dataToDispose = [];
- const inputsData = inputs.map(input => {
- if (input.dtype === 'complex64') {
- throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
- `dtypes, please separate the program into real and imaginary ` +
- `parts.`);
- }
- let texData = this.texData.get(input.dataId);
- if (texData.texture == null) {
- if (!program.packedInputs &&
- sizeFromShape(input.shape) <=
- env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
- // Upload small tensors that live on the CPU as uniforms, not as
- // textures. Do this only when the environment supports 32bit floats
- // due to problems when comparing 16bit floats with 32bit floats.
- // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
- // possible for packed shaders to sample from uniforms.
- return {
- shape: input.shape,
- texData: null,
- isUniform: true,
- uniformValues: texData.values
- };
- }
- // This ensures that if a packed program's inputs have not yet been
- // uploaded to the GPU, they get uploaded as packed right off the bat.
- if (program.packedInputs) {
- texData.isPacked = true;
- texData.shape = input.shape;
- }
- }
- else if (!!texData.isPacked !== !!program.packedInputs) {
- input = texData.isPacked ? this.unpackTensor(input) :
- this.packTensor(input);
- dataToDispose.push(input);
- texData = this.texData.get(input.dataId);
- }
- else if (texData.isPacked &&
- !isReshapeFree(texData.shape, input.shape)) {
- // This is a special case where a texture exists for a tensor
- // but the shapes are incompatible (due to packing constraints) because
- // the tensor did not have a chance to go through the packed reshape
- // shader. This only happens when we reshape the *same* tensor to form
- // *distinct* inputs to an op, e.g. dotting a vector with itself. This
- // case will disappear once packed uploading is the default.
- const savedInput = input;
- const targetShape = input.shape;
- input.shape = texData.shape;
- input = this.packedReshape(input, targetShape);
- dataToDispose.push(input);
- texData = this.texData.get(input.dataId);
- savedInput.shape = targetShape;
- }
- this.uploadToGPU(input.dataId);
- return { shape: input.shape, texData, isUniform: false };
- });
- this.uploadToGPU(output.dataId);
- const outputData = { shape: output.shape, texData: outData, isUniform: false };
- const key = makeShaderKey(program, inputsData, outputData);
- const binary = this.getAndSaveBinary(key, () => {
- return compileProgram(this.gpgpu, program, inputsData, outputData);
- });
- const shouldTimeProgram = this.activeTimers != null;
- let query;
- if (shouldTimeProgram) {
- query = this.startTimer();
- }
- runProgram(this.gpgpu, binary, inputsData, outputData, customSetup);
- dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
- if (shouldTimeProgram) {
- query = this.endTimer(query);
- this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
- }
- if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
- preventEagerUnpackingOfOutput === false) {
- const unpacked = this.unpackTensor(output);
- this.disposeIntermediateTensorInfo(output);
- return unpacked;
- }
- return output;
- }
- compileAndRun(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput = false) {
- outputDtype = outputDtype || inputs[0].dtype;
- const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput);
- return engine().makeTensorFromDataId(outInfo.dataId, outInfo.shape, outInfo.dtype);
- }
- getAndSaveBinary(key, getBinary) {
- if (!(key in this.binaryCache)) {
- this.binaryCache[key] = getBinary();
- }
- return this.binaryCache[key];
- }
- getTextureManager() {
- return this.textureManager;
- }
- dispose() {
- if (this.disposed) {
- return;
- }
- // Avoid disposing the compiled webgl programs during unit testing because
- // it slows down test execution.
- if (!env().getBool('IS_TEST')) {
- const allKeys = Object.keys(this.binaryCache);
- allKeys.forEach(key => {
- this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
- delete this.binaryCache[key];
- });
- }
- this.textureManager.dispose();
- if (this.canvas != null &&
- (typeof (HTMLCanvasElement) !== 'undefined' &&
- this.canvas instanceof HTMLCanvasElement)) {
- this.canvas.remove();
- }
- else {
- this.canvas = null;
- }
- if (this.gpgpuCreatedLocally) {
- this.gpgpu.program = null;
- this.gpgpu.dispose();
- }
- this.disposed = true;
- }
- floatPrecision() {
- if (this.floatPrecisionValue == null) {
- this.floatPrecisionValue = tidy(() => {
- if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
- // Momentarily switching DEBUG flag to false so we don't throw an
- // error trying to upload a small value.
- const debugFlag = env().getBool('DEBUG');
- env().set('DEBUG', false);
- const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
- env().set('DEBUG', debugFlag);
- if (underflowCheckValue > 0) {
- return 32;
- }
- }
- return 16;
- });
- }
- return this.floatPrecisionValue;
- }
- /** Returns the smallest representable number. */
- epsilon() {
- return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
- }
- uploadToGPU(dataId) {
- const texData = this.texData.get(dataId);
- const { shape, dtype, values, texture, usage, isPacked } = texData;
- if (texture != null) {
- // Array is already on GPU. No-op.
- return;
- }
- const shouldTimeProgram = this.activeTimers != null;
- let start;
- if (shouldTimeProgram) {
- start = now();
- }
- let texShape = texData.texShape;
- if (texShape == null) {
- texShape = getTextureShapeFromLogicalShape(shape, isPacked);
- texData.texShape = texShape;
- }
- if (values != null) {
- const shapeAs3D = getShapeAs3D(shape);
- let program;
- let width = texShape[1], height = texShape[0];
- const isByteArray = values instanceof Uint8Array;
- if (isPacked) {
- [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
- program = new EncodeMatrixPackedProgram(shapeAs3D, [height, width], isByteArray);
- }
- else {
- program =
- new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray);
- }
- const tempDenseInputHandle = this.makeTensorInfo([height, width], dtype);
- if (isByteArray) {
- this.texData.get(tempDenseInputHandle.dataId).usage =
- TextureUsage.PIXELS;
- }
- else {
- this.texData.get(tempDenseInputHandle.dataId).usage =
- TextureUsage.UPLOAD;
- }
- this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
- // We want the output to remain packed regardless of the value of
- // WEBGL_PACK.
- const preventEagerUnpacking = true;
- const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, null, preventEagerUnpacking);
- // Have the original texture assume the identity of the encoded output.
- const outputTexData = this.texData.get(encodedOutputTarget.dataId);
- texData.texture = outputTexData.texture;
- texData.texShape = outputTexData.texShape;
- texData.isPacked = outputTexData.isPacked;
- texData.usage = outputTexData.usage;
- this.disposeIntermediateTensorInfo(tempDenseInputHandle);
- this.texData.delete(encodedOutputTarget.dataId);
- // Once uploaded, don't store the values on cpu.
- texData.values = null;
- if (shouldTimeProgram) {
- this.uploadWaitMs += now() - start;
- }
- }
- else {
- const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
- texData.texture = newTexture;
- }
- }
- convertAndCacheOnCPU(dataId, float32Values) {
- const texData = this.texData.get(dataId);
- const { dtype } = texData;
- this.releaseGPUData(dataId);
- if (float32Values != null) {
- texData.values = float32ToTypedArray(float32Values, dtype);
- }
- return texData.values;
- }
- acquireTexture(texShape, texType, dtype, isPacked) {
- this.numBytesInGPU += this.computeBytes(texShape, dtype);
- if (!this.warnedAboutMemory &&
- this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
- const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
- this.warnedAboutMemory = true;
- console.warn(`High memory usage in GPU: ${mb} MB, ` +
- `most likely due to a memory leak`);
- }
- return this.textureManager.acquireTexture(texShape, texType, isPacked);
- }
- computeBytes(shape, dtype) {
- return shape[0] * shape[1] * bytesPerElement(dtype);
- }
- tryRunOnCpuOrThrow(inputs, fn) {
- if (this.shouldExecuteOnCPU(inputs)) {
- try {
- return fn();
- }
- catch (e) {
- if (env().getBool('IS_TEST')) {
- throw new Error('CPU forwarding failed');
- }
- }
- }
- return null;
- }
- }
- function float32ToTypedArray(a, dtype) {
- if (dtype === 'float32' || dtype === 'complex64') {
- return a;
- }
- else if (dtype === 'int32' || dtype === 'bool') {
- const result = (dtype === 'int32') ? new Int32Array(a.length) :
- new Uint8Array(a.length);
- for (let i = 0; i < result.length; ++i) {
- result[i] = Math.round(a[i]);
- }
- return result;
- }
- else {
- throw new Error(`Unknown dtype ${dtype}`);
- }
- }
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$5 = '0.0.0';
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Enforce use of half precision textures if available on the platform.
- *
- * @doc {heading: 'Environment', namespace: 'webgl'}
- */
- function forceHalfFloat() {
- env().set('WEBGL_FORCE_F16_TEXTURES', true);
- }
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- if (isBrowser()) {
- registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
- }
- const webgl = { forceHalfFloat };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function identity$2(args) {
- const { inputs, backend } = args;
- const { x } = inputs;
- backend.incRef(x.dataId);
- return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
- }
- const identityConfig$1 = {
- kernelName: Identity,
- backendName: 'webgl',
- kernelFunc: identity$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * In WebGL data is stored in GPU textures which can't be efficiently copied, so
- * complex tensors share data with their real and imaginary components. Complex
- * tensors increment the `complexParentRefCount` properties of the underlying
- * data buckets to prevent them from being disposed, as the engine's disposal
- * logic does not account for data sharing by complex tensors.
- *
- * When a complex tensor is disposed, it will explicitly decrease the
- * `complexParentRefCount` properties of its underlying components.
- */
- function complex$2(args) {
- const { inputs, backend } = args;
- const { real, imag } = inputs;
- const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
- const complex = backend.texData.get(complexInfo.dataId);
- const realTensorInfo = identity$2({ inputs: { x: real }, backend });
- const realData = backend.texData.get(realTensorInfo.dataId);
- realData.complexParentRefCount++;
- const imagTensorInfo = identity$2({ inputs: { x: imag }, backend });
- const imagData = backend.texData.get(imagTensorInfo.dataId);
- imagData.complexParentRefCount++;
- complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo };
- return complexInfo;
- }
- const complexConfig$1 = {
- kernelName: Complex,
- backendName: 'webgl',
- kernelFunc: complex$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;
- const CHECK_NAN_SNIPPET_BINARY = `
- if (isnan(a)) return a;
- if (isnan(b)) return b;
-`;
- const CHECK_NAN_SNIPPET_BINARY_PACKED = `
- result.r = isNaN.r > 0. ? NAN : result.r;
- result.g = isNaN.g > 0. ? NAN : result.g;
- result.b = isNaN.b > 0. ? NAN : result.b;
- result.a = isNaN.a > 0. ? NAN : result.a;
-`;
- /**
- * Template that creates a `KernelFunc` for unary ops.
- * @param opSnippets Op snippet to create `UnaryOpProgram`.
- */
- function unaryKernelFunc$1(opSnippet) {
- return ({ inputs, backend }) => {
- const { x } = inputs;
- const webglBackend = backend;
- const program = new UnaryOpProgram(x.shape, opSnippet);
- return webglBackend.runWebGLProgram(program, [x], x.dtype);
- };
- }
- /**
- * Template that creates a `KernelFunc` for binary ops.
- * @param opSnippet Op snippet to create `BinaryOpProgram`.
- * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
- * @param checkOutOfBoundsForPackedProgram Whether to set checkOutOfBounds=true
- * when creating BinaryOpPackedProgram.
- * @param dtype Optional. If set, the result has this dtype. Otherwise, the
- * result has the same dtype as the first input. This is mainly used in
- * comparison kernels, such as Equal, Less, Greater, etc.
- */
- function binaryKernelFunc$1({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
- return ({ inputs, backend }) => {
- const { a, b } = inputs;
- const webglBackend = backend;
- if (supportsComplex && a.dtype === 'complex64') {
- const aData = webglBackend.texData.get(a.dataId);
- const bData = webglBackend.texData.get(b.dataId);
- const [real, imag] = [
- [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
- [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
- ].map(complexParts => {
- const [aPart, bPart] = complexParts;
- const aHandle = {
- dataId: aPart.dataId,
- dtype: aPart.dtype,
- shape: a.shape
- };
- const bHandle = {
- dataId: bPart.dataId,
- dtype: bPart.dtype,
- shape: b.shape
- };
- const program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
- return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
- });
- const complexOutput = complex$2({ inputs: { real, imag }, backend: webglBackend });
- webglBackend.disposeIntermediateTensorInfo(real);
- webglBackend.disposeIntermediateTensorInfo(imag);
- // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
- return complexOutput;
- }
- const $dtype = dtype || upcastType(a.dtype, b.dtype);
- if (webglBackend.shouldExecuteOnCPU([a, b]) && cpuKernelImpl != null) {
- const aData = webglBackend.texData.get(a.dataId);
- const bData = webglBackend.texData.get(b.dataId);
- const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, aData.values, bData.values, $dtype);
- const out = webglBackend.makeTensorInfo(outShape, $dtype);
- const outData = webglBackend.texData.get(out.dataId);
- outData.values = outValues;
- return out;
- }
- const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
- packedOpSnippet != null;
- let program;
- if (shouldUsePackedProgram) {
- program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
- }
- else {
- program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
- }
- return webglBackend.runWebGLProgram(program, [a, b], $dtype);
- };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ADD = 'return a + b;';
- const addKernelFunc = binaryKernelFunc$1({
- opSnippet: ADD,
- packedOpSnippet: ADD,
- supportsComplex: true,
- cpuKernelImpl: addImplCPU
- });
- const addConfig$1 = {
- kernelName: Add,
- backendName: 'webgl',
- kernelFunc: addKernelFunc
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const ATAN2 = CHECK_NAN_SNIPPET_BINARY + `
- return atan(a, b);
-`;
- const ATAN2_PACKED = `
- vec4 result = atan(a, b);
- vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
- ` +
- CHECK_NAN_SNIPPET_BINARY_PACKED + `
- return result;
-`;
- const atan2$1 = binaryKernelFunc$1({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
- const atan2Config = {
- kernelName: Atan2,
- backendName: 'webgl',
- kernelFunc: atan2$1,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function avgPool$2(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- assertNotComplex$1(x, 'avgPool');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const dilations = 1;
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- return identity$2({ inputs: { x }, backend });
- }
- const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false);
- return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
- }
- const avgPoolConfig$1 = {
- kernelName: AvgPool,
- backendName: 'webgl',
- kernelFunc: avgPool$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function avgPoolBackprop$2(args) {
- const { inputs, backend, attrs } = args;
- const { dy, input } = inputs;
- const x = input;
- assertNotComplex$1([dy, input], 'avgPoolBackprop');
- const { filterSize, strides, pad } = attrs;
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
- const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
- return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
- }
- const avgPoolBackpropConfig$1 = {
- kernelName: AvgPoolBackprop,
- backendName: 'webgl',
- kernelFunc: avgPoolBackprop$2
- };
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class BatchNormProgram {
- constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
- this.outputShape = [];
- this.variableNames = ['x', 'mean', 'variance'];
- assertAndGetBroadcastShape(xShape, meanShape);
- assertAndGetBroadcastShape(xShape, varianceShape);
- let offsetSnippet = '0.0';
- if (offsetShape != null) {
- assertAndGetBroadcastShape(xShape, offsetShape);
- this.variableNames.push('offset');
- offsetSnippet = 'getOffsetAtOutCoords()';
- }
- let scaleSnippet = '1.0';
- if (scaleShape != null) {
- assertAndGetBroadcastShape(xShape, scaleShape);
- this.variableNames.push('scale');
- scaleSnippet = 'getScaleAtOutCoords()';
- }
- this.outputShape = xShape;
- this.userCode = `
- void main() {
- float x = getXAtOutCoords();
- float mean = getMeanAtOutCoords();
- float variance = getVarianceAtOutCoords();
- float offset = ${offsetSnippet};
- float scale = ${scaleSnippet};
- float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
- setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class BatchNormPackedProgram {
- constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
- this.packedInputs = true;
- this.packedOutput = true;
- this.variableNames = ['x', 'mean', 'variance'];
- assertAndGetBroadcastShape(xShape, meanShape);
- assertAndGetBroadcastShape(xShape, varianceShape);
- let offsetSnippet = 'vec4(0.0)';
- if (offsetShape != null) {
- assertAndGetBroadcastShape(xShape, offsetShape);
- this.variableNames.push('offset');
- offsetSnippet = 'getOffsetAtOutCoords()';
- }
- let scaleSnippet = 'vec4(1.0)';
- if (scaleShape != null) {
- assertAndGetBroadcastShape(xShape, scaleShape);
- this.variableNames.push('scale');
- scaleSnippet = 'getScaleAtOutCoords()';
- }
- this.outputShape = xShape;
- this.userCode = `
- void main() {
- vec4 offset = ${offsetSnippet};
- vec4 scale = ${scaleSnippet};
-
- vec4 x = getXAtOutCoords();
- vec4 mean = getMeanAtOutCoords();
- vec4 variance = getVarianceAtOutCoords();
-
- vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
-
- setOutput((x - mean) * inv + offset);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const batchNorm$2 = ({ inputs, backend, attrs }) => {
- const { x, mean, variance, offset, scale } = inputs;
- assert(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
- 'equal ranks.');
- assert(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
- 'equal ranks.');
- assert(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
- 'equal ranks.');
- let { varianceEpsilon } = attrs;
- if (varianceEpsilon == null) {
- varianceEpsilon = 0.001;
- }
- const finalInputs = [x, mean, variance];
- let offsetShape = null;
- if (offset != null) {
- offsetShape = offset.shape;
- finalInputs.push(offset);
- }
- let scaleShape = null;
- if (scale != null) {
- scaleShape = scale.shape;
- finalInputs.push(scale);
- }
- const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
- new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
- new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
- const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
- return output;
- };
- const batchNormConfig$1 = {
- kernelName: FusedBatchNorm,
- backendName: 'webgl',
- kernelFunc: batchNorm$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const NOT_EQUAL$1 = `return float(a != b);`;
- const notEqual$2 = binaryKernelFunc$1({ opSnippet: NOT_EQUAL$1, dtype: 'bool' });
- const notEqualConfig$1 = {
- kernelName: NotEqual,
- backendName: 'webgl',
- kernelFunc: notEqual$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function real$2(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const inputData = backend.texData.get(input.dataId);
- return identity$2({ inputs: { x: inputData.complexTensorInfos.real }, backend });
- }
- const realConfig$1 = {
- kernelName: Real,
- backendName: 'webgl',
- kernelFunc: real$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const TO_INT = `return float(int(x));`;
- function int(input, backend) {
- const program = new UnaryOpProgram(input.shape, TO_INT);
- const output = backend.runWebGLProgram(program, [input], 'int32');
- return { dataId: output.dataId, shape: output.shape, dtype: output.dtype };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function cast$3(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { dtype } = attrs;
- // Casting to complex64.
- if (dtype === 'complex64') {
- if (x.dtype === 'complex64') {
- return identity$2({ inputs: { x }, backend });
- }
- // TODO(annxingyuan): Import kernel function once zeros is modularized.
- const zerosTensor = zeros(x.shape);
- const floatX = cast$3({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
- const result = complex$2({ inputs: { real: floatX, imag: zerosTensor }, backend });
- zerosTensor.dispose();
- backend.disposeIntermediateTensorInfo(floatX);
- return result;
- }
- // Casting from complex64
- if (x.dtype === 'complex64') {
- const realPart = real$2({ inputs: { input: x }, backend });
- const result = cast$3({ inputs: { x: realPart }, backend, attrs: { dtype } });
- backend.disposeIntermediateTensorInfo(realPart);
- return result;
- }
- if (!hasEncodingLoss(x.dtype, dtype)) {
- // We don't change the underlying data, since we cast to higher
- // precision.
- const result = identity$2({ inputs: { x }, backend });
- return { dataId: result.dataId, shape: result.shape, dtype };
- }
- if (dtype === 'int32') {
- return int(x, backend);
- }
- if (dtype === 'bool') {
- const zerosTensorInfo = backend.makeTensorInfo([], 'bool');
- const binaryInputs = { a: x, b: zerosTensorInfo };
- const result = notEqual$2({ inputs: binaryInputs, backend });
- backend.disposeIntermediateTensorInfo(zerosTensorInfo);
- return result;
- }
- throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
- }
- const castConfig$1 = {
- kernelName: Cast,
- backendName: 'webgl',
- kernelFunc: cast$3
- };
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ConcatProgram {
- // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
- constructor(shapes) {
- this.outputShape = [];
- this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
- this.variableNames = shapes.map((_, i) => `T${i}`);
- const offsets = new Array(shapes.length - 1);
- offsets[0] = shapes[0][1];
- for (let i = 1; i < offsets.length; i++) {
- offsets[i] = offsets[i - 1] + shapes[i][1];
- }
- const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
- for (let i = 1; i < offsets.length; i++) {
- const shift = offsets[i - 1];
- snippets.push(`else if (yC < ${offsets[i]}) ` +
- `setOutput(getT${i}(yR, yC-${shift}));`);
- }
- const lastIndex = offsets.length;
- const lastShift = offsets[offsets.length - 1];
- snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
- this.userCode = `
- void main() {
- ivec2 coords = getOutputCoords();
- int yR = coords.x;
- int yC = coords.y;
-
- ${snippets.join('\n ')}
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class ConcatPackedProgram {
- constructor(shapes, axis) {
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = [];
- this.outputShape = computeOutShape$1(shapes, axis);
- const shape = this.outputShape;
- const rank = shape.length;
- const dtype = getCoordsDataType(rank);
- const coords = getChannels('coords', rank);
- const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
- this.variableNames = shapes.map((_, i) => `T${i}`);
- const offsets = new Array(shapes.length - 1);
- offsets[0] = shapes[0][axis];
- for (let i = 1; i < offsets.length; i++) {
- offsets[i] = offsets[i - 1] + shapes[i][axis];
- }
- const channel = channels[axis];
- const lastChannels = channels.slice(-2);
- const allChannels = channels.join();
- let getValueSnippet = `if (${channel} < ${offsets[0]}) {
- return getChannel(
- getT0(${allChannels}), vec2(${lastChannels.join()}));
- }`;
- for (let i = 1; i < offsets.length; i++) {
- const shift = offsets[i - 1];
- // Note: the >= comparison below may seem unnecessary given the check
- // above but is needed to workaround branch execution issues on some
- // devices. It makes all the conditions exclusive without relying on
- // execution order.
- getValueSnippet += `
- if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
- return getChannel(
- getT${i}(${shiftedChannels(channels, channel, shift)}),
- vec2(${shiftedChannels(lastChannels, channel, shift)}));
- }`;
- }
- const lastIndex = offsets.length;
- const shift = offsets[offsets.length - 1];
- getValueSnippet += `
- return getChannel(
- getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),
- vec2(${shiftedChannels(lastChannels, channel, shift)}));`;
- this.userCode = `
- float getValue(${channels.map(x => 'int ' + x)}) {
- ${getValueSnippet}
- }
-
- void main() {
- ${dtype} coords = getOutputCoords();
- vec4 result = vec4(getValue(${coords}), 0., 0., 0.);
-
- ${coords[rank - 1]} = ${coords[rank - 1]} + 1;
- if (${coords[rank - 1]} < ${shape[rank - 1]}) {
- result.g = getValue(${coords});
- }
-
- ${coords[rank - 2]} = ${coords[rank - 2]} + 1;
- if (${coords[rank - 2]} < ${shape[rank - 2]}) {
- result.a = getValue(${coords});
- }
-
- ${coords[rank - 1]} = ${coords[rank - 1]} - 1;
- if (${coords[rank - 2]} < ${shape[rank - 2]} &&
- ${coords[rank - 1]} < ${shape[rank - 1]}) {
- result.b = getValue(${coords});
- }
- setOutput(result);
- }
- `;
- }
- }
- /**
- * Return an expression for coordinates into a vector where a given channel
- * will be offset by [shift].
- *
- * @param channels the channels to consider
- * @param channel the channel we want shifted
- * @param shift the amount to subtract from the channel.
- *
- * @returns a string of the form 'x, y-[shift], z' where any one channel can
- * have the shift applied.
- */
- function shiftedChannels(channels, channel, shift) {
- const channelIdx = channels.indexOf(channel);
- const res = channels.map((c, idx) => {
- if (idx === channelIdx) {
- return `${c} - ${shift}`;
- }
- else {
- return c;
- }
- });
- return res.join();
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function imag$2(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- const inputData = backend.texData.get(input.dataId);
- return identity$2({ inputs: { x: inputData.complexTensorInfos.imag }, backend });
- }
- const imagConfig$1 = {
- kernelName: Imag,
- backendName: 'webgl',
- kernelFunc: imag$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function packedReshape(input, afterShape, backend) {
- const input3DShape = [getBatchDim(input.shape),
- ...getRowsCols(input.shape)];
- const input3D = {
- dtype: input.dtype,
- shape: input3DShape,
- dataId: input.dataId
- };
- const afterShapeAs3D = [getBatchDim(afterShape),
- ...getRowsCols(afterShape)];
- const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
- const preventEagerUnpackingOfOutput = true;
- const output = backend.runWebGLProgram(program, [input3D], input.dtype, null /* customSetup */, preventEagerUnpackingOfOutput);
- return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function reshape$3(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- const { shape } = attrs;
- const webglBackend = backend;
- const xSize = sizeFromShape(x.shape);
- const $shape = inferFromImplicitShape(shape, xSize);
- const $xSize = sizeFromShape($shape);
- assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
- `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
- `shape must have the same number of elements.`);
- const xTexData = webglBackend.texData.get(x.dataId);
- if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
- !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
- return packedReshape(x, $shape, webglBackend);
- }
- webglBackend.incRef(x.dataId);
- return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
- }
- const reshapeConfig$1 = {
- kernelName: Reshape,
- backendName: 'webgl',
- kernelFunc: reshape$3
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function concatImpl(inputs, axis, backend) {
- const dtype = inputs[0].dtype;
- if (dtype === 'complex64') {
- const reals = inputs.map((t) => real$2({ inputs: { input: t }, backend }));
- const imags = inputs.map((t) => imag$2({ inputs: { input: t }, backend }));
- const realConcated = concatImpl(reals, axis, backend);
- const imagConcated = concatImpl(imags, axis, backend);
- const result = complex$2({ inputs: { real: realConcated, imag: imagConcated }, backend });
- reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
- imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
- backend.disposeIntermediateTensorInfo(realConcated);
- backend.disposeIntermediateTensorInfo(imagConcated);
- return result;
- }
- if (inputs.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
- const midIndex = Math.floor(inputs.length / 2);
- const leftSide = concatImpl(inputs.slice(0, midIndex), axis, backend);
- const rightSide = concatImpl(inputs.slice(midIndex), axis, backend);
- const result = concatImpl([leftSide, rightSide], axis, backend);
- backend.disposeIntermediateTensorInfo(leftSide);
- backend.disposeIntermediateTensorInfo(rightSide);
- return result;
- }
- if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
- inputs[0].shape.length > 1) {
- const program = new ConcatPackedProgram(inputs.map(t => t.shape), axis);
- return backend.runWebGLProgram(program, inputs, dtype);
- }
- // Any concat of n-dimensional tensors across any axis can be reduced to
- // a concatenation of two-dimensional tensors across the axis 1 by first
- // partitioning the axes of the original tensors into those less than the
- // axis to be concatenated and the rest. Then reshape the tensors
- // into a two-dimensional tensor by collapsing these two sets of axes and
- // concatenate the resulting matrices across the axis 1, finally reshaping
- // the result to have the proper shape.
- const outShape = computeOutShape$1(inputs.map(t => t.shape), axis);
- const tensors2D = inputs.map(x => reshape$3({
- inputs: { x },
- attrs: { shape: [-1, sizeFromShape(x.shape.slice(axis))] },
- backend
- }));
- const program = new ConcatProgram(tensors2D.map(t => t.shape));
- const result = backend.runWebGLProgram(program, tensors2D, dtype);
- tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r));
- const reshapedResult = reshape$3({ inputs: { x: result }, attrs: { shape: outShape }, backend });
- backend.disposeIntermediateTensorInfo(result);
- return reshapedResult;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function concat$2(args) {
- const { inputs, backend, attrs } = args;
- const { axis } = attrs;
- const $axis = parseAxisParam(axis, inputs[0].shape)[0];
- const outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
- if (sizeFromShape(outShape) === 0) {
- return backend.makeTensorInfo(outShape, inputs[0].dtype);
- }
- // Keep only non-empty tensors (ignore tensors with 0 in their shape).
- const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
- if ($inputs.length === 1) {
- return $inputs[0];
- }
- const shapes = $inputs.map(t => t.shape);
- assertParamsConsistent(shapes, $axis);
- return concatImpl($inputs, $axis, backend);
- }
- const concatConfig$1 = {
- kernelName: Concat,
- backendName: 'webgl',
- kernelFunc: concat$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const COS = CHECK_NAN_SNIPPET_UNARY + `
- return cos(x);
-`;
- const cos$2 = unaryKernelFunc$1(COS);
- const cosConfig$1 = {
- kernelName: Cos,
- backendName: 'webgl',
- kernelFunc: cos$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Without the equality check div produces 0.9999 for a = b, which when
- // floored can cause errors.
- const DIV = `
-if (a == b) {
- return 1.0;
-};
-return a / b;`;
- // We do the same as in ./binaryop_gpu, with vec4 and ivec4.
- // On Linux, the vectorized implementation produces NaNs when a and b are 0.
- const DIV_PACKED = `
- // vec4 one = vec4(equal(a, b));
- // return one + (vec4(1.0) - one) * a / b;
- vec4 result = a / b;
- if(a.x == b.x) {
- result.x = 1.;
- }
- if(a.y == b.y) {
- result.y = 1.;
- }
- if(a.z == b.z) {
- result.z = 1.;
- }
- if(a.w == b.w) {
- result.w = 1.;
- }
-
- return result;
-`;
- const div$2 = binaryKernelFunc$1({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
- const divConfig$1 = {
- kernelName: Div,
- backendName: 'webgl',
- kernelFunc: div$2,
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class FFTProgram {
- constructor(component, inputShape, inverse) {
- this.variableNames = ['real', 'imag'];
- const innerDim = inputShape[1];
- this.outputShape = inputShape;
- const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
- const resultDenominator = inverse ? `${innerDim}.0` : '1.0';
- let opString;
- if (component === 'real') {
- opString = 'return real * expR - imag * expI;';
- }
- else if (component === 'imag') {
- opString = 'return real * expI + imag * expR;';
- }
- else {
- throw new Error(`FFT component must be either "real" or "imag", got ${component}.`);
- }
- this.userCode = `
- const float exponentMultiplier = ${exponentMultiplierSnippet};
-
- float unaryOpComplex(float real, float expR, float imag, float expI) {
- ${opString}
- }
-
- float mulMatDFT(int batch, int index) {
- float indexRatio = float(index) / float(${innerDim});
- float exponentMultiplierTimesIndexRatio =
- exponentMultiplier * indexRatio;
-
- float result = 0.0;
-
- for (int i = 0; i < ${innerDim}; i++) {
- // x = (-2|2 * PI / N) * index * i;
- float x = exponentMultiplierTimesIndexRatio * float(i);
- float expR = cos(x);
- float expI = sin(x);
- float real = getReal(batch, i);
- float imag = getImag(batch, i);
-
- result +=
- unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
- }
-
- return result;
- }
-
- void main() {
- ivec2 coords = getOutputCoords();
- setOutput(mulMatDFT(coords[0], coords[1]));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fftImpl$1(x, inverse, backend) {
- const xData = backend.texData.get(x.dataId);
- const inputSize = sizeFromShape(x.shape);
- // Collapse all outer dimensions to a single batch dimension.
- const innerDimensionSize = x.shape[x.shape.length - 1];
- const batch = inputSize / innerDimensionSize;
- const input2D = reshape$3({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } });
- const xShape = input2D.shape;
- const realProgram = new FFTProgram('real', xShape, inverse);
- const imagProgram = new FFTProgram('imag', xShape, inverse);
- const inputs = [
- {
- dataId: xData.complexTensorInfos.real.dataId,
- dtype: xData.complexTensorInfos.real.dtype,
- shape: xShape
- },
- {
- dataId: xData.complexTensorInfos.imag.dataId,
- dtype: xData.complexTensorInfos.imag.dtype,
- shape: xShape
- }
- ];
- const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
- const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
- const complexOutput = complex$2({ inputs: { real: realPart, imag: imagPart }, backend });
- backend.disposeIntermediateTensorInfo(realPart);
- backend.disposeIntermediateTensorInfo(imagPart);
- const complexOutputReshaped = reshape$3({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } });
- backend.disposeIntermediateTensorInfo(complexOutputReshaped);
- return complexOutputReshaped;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fft$2(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- return fftImpl$1(input, false /* inverse */, backend);
- }
- const fftConfig$1 = {
- kernelName: FFT,
- backendName: 'webgl',
- kernelFunc: fft$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class FlipLeftRightProgram {
- constructor(imageShape) {
- this.variableNames = ['Image'];
- this.outputShape = [];
- const imageWidth = imageShape[2];
- this.outputShape = imageShape;
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int x = coords[2];
-
- int coordX = ${imageWidth} - x;
- float outputValue;
- if(coordX >= 0 && coordX < ${imageWidth}) {
- outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
- } else {
- outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
- }
- setOutput(outputValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const flipLeftRightConfig$1 = {
- kernelName: FlipLeftRight,
- backendName: 'webgl',
- kernelFunc: ({ inputs, backend }) => {
- const { image } = inputs;
- const webglBackend = backend;
- const program = new FlipLeftRightProgram(image.shape);
- const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
- return output;
- }
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class FromPixelsProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- const glsl = getGlslDifferences();
- const [height, width,] = outputShape;
- this.outputShape = outputShape;
- this.userCode = `
- void main() {
- ivec3 coords = getOutputCoords();
- int texR = coords[0];
- int texC = coords[1];
- int depth = coords[2];
- vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
-
- vec4 values = ${glsl.texture2D}(A, uv);
- float value;
- if (depth == 0) {
- value = values.r;
- } else if (depth == 1) {
- value = values.g;
- } else if (depth == 2) {
- value = values.b;
- } else if (depth == 3) {
- value = values.a;
- }
-
- setOutput(floor(value * 255.0 + 0.5));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class FromPixelsPackedProgram {
- constructor(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- const glsl = getGlslDifferences();
- const [height, width,] = outputShape;
- this.outputShape = outputShape;
- this.userCode = `
- void main() {
- ivec3 coords = getOutputCoords();
- int texR = coords[0];
- int texC = coords[1];
- int depth = coords[2];
-
- vec4 result = vec4(0.);
-
- for(int row=0; row<=1; row++) {
- for(int col=0; col<=1; col++) {
- texC = coords[1] + row;
- depth = coords[2] + col;
-
- vec2 uv = (vec2(texC, texR) + halfCR) /
- vec2(${width}.0, ${height}.0);
- vec4 values = ${glsl.texture2D}(A, uv);
- float value;
- if (depth == 0) {
- value = values.r;
- } else if (depth == 1) {
- value = values.g;
- } else if (depth == 2) {
- value = values.b;
- } else if (depth == 3) {
- value = values.a;
- }
-
- result[row * 2 + col] = floor(value * 255.0 + 0.5);
- }
- }
-
- ${glsl.output} = result;
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const fromPixelsConfig = {
- kernelName: FromPixels,
- backendName: 'webgl',
- kernelFunc: fromPixels$1,
- };
- let fromPixels2DContext$1;
- function fromPixels$1(args) {
- const { inputs, backend, attrs } = args;
- let { pixels } = inputs;
- const { numChannels } = attrs;
- const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
- pixels instanceof HTMLVideoElement;
- const isImage = typeof (HTMLImageElement) !== 'undefined' &&
- pixels instanceof HTMLImageElement;
- const [width, height] = isVideo ?
- [
- pixels.videoWidth,
- pixels.videoHeight
- ] :
- [pixels.width, pixels.height];
- const texShape = [height, width];
- const outShape = [height, width, numChannels];
- if (isImage || isVideo) {
- if (fromPixels2DContext$1 == null) {
- fromPixels2DContext$1 = document.createElement('canvas').getContext('2d');
- }
- fromPixels2DContext$1.canvas.width = width;
- fromPixels2DContext$1.canvas.height = height;
- fromPixels2DContext$1.drawImage(pixels, 0, 0, width, height);
- pixels = fromPixels2DContext$1.canvas;
- }
- const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
- // This is a byte texture with pixels.
- backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
- backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
- const program = env().getBool('WEBGL_PACK') ?
- new FromPixelsPackedProgram(outShape) :
- new FromPixelsProgram(outShape);
- const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
- backend.disposeData(tempPixelHandle.dataId);
- return res;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function ifft$2(args) {
- const { inputs, backend } = args;
- const { input } = inputs;
- return fftImpl$1(input, true /* inverse */, backend);
- }
- const ifftConfig$1 = {
- kernelName: IFFT,
- backendName: 'webgl',
- kernelFunc: ifft$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class MeanProgram {
- constructor(reduceInfo, divisor) {
- this.variableNames = ['x'];
- const { windowSize, batchSize, inSize, outSize } = reduceInfo;
- this.outputShape = [batchSize, outSize];
- const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
- const windowSizeVec4Remainder = windowSize % 4;
- let updateSnippet = `sumValue += dot(values, ones);`;
- if (divisor != null) {
- const denominator = 1 / divisor;
- updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) :
- denominator}, ones);`;
- }
- let checkOutOfBounds = '';
- if (inSize % windowSize > 0) {
- checkOutOfBounds = `
- if (inIdx < 0 || inIdx >= ${inSize}) {
- return 0.0;
- }
- `;
- }
- this.userCode = `
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
-
- float getValue(int batch, int inIdx) {
- ${checkOutOfBounds}
- return getX(batch, inIdx);
- }
-
- void main() {
- ivec2 coords = getOutputCoords();
- int batch = coords[0];
- int outIdx = coords[1];
- int inOffset = outIdx * ${windowSize};
-
- float sumValue = 0.0;
-
- for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
- int inIdx = inOffset + i;
- vec4 values = vec4(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1),
- getValue(batch, inIdx + 2),
- getValue(batch, inIdx + 3)
- );
-
- ${updateSnippet}
- }
-
- int inIdx = inOffset + ${windowSizeNearestVec4};
- if (${windowSizeVec4Remainder === 1}) {
- vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
-
- ${updateSnippet}
- } else if (${windowSizeVec4Remainder === 2}) {
- vec4 values = vec4(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1), 0.0, 0.0);
-
- ${updateSnippet}
- } else if (${windowSizeVec4Remainder === 3}) {
- vec4 values = vec4(
- getValue(batch, inIdx),
- getValue(batch, inIdx + 1),
- getValue(batch, inIdx + 2), 0.0);
-
- ${updateSnippet}
- }
- setOutput(sumValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Returns an array of configuration objects that describe each stage of the
- // reduction.
- function getReductionStages(inShape) {
- const stages = [];
- while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
- const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
- const windowSize = computeOptimalWindowSize(outSize);
- stages.push({
- inSize: outSize,
- windowSize,
- outSize: Math.ceil(outSize / windowSize)
- });
- }
- return stages;
- }
- function reduce(x, dtype, reductionType, backend) {
- const reductionStages = getReductionStages(x.shape);
- let result = x;
- for (let i = 0; i < reductionStages.length; i++) {
- const { inSize, windowSize, outSize } = reductionStages[i];
- let program;
- let previousResult;
- if (reductionType === 'mean') {
- program = i === 0 ?
- new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) :
- new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize });
- }
- else {
- program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType);
- }
- previousResult = result;
- result = backend.runWebGLProgram(program, [result], dtype);
- if (previousResult.dataId !== x.dataId) {
- backend.disposeIntermediateTensorInfo(previousResult);
- }
- }
- return result;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxImpl$1(x, reduceShape, outShape, backend) {
- const inSize = sizeFromShape(reduceShape);
- const xSize = sizeFromShape(x.shape);
- const batchSize = xSize / inSize;
- const reshapedInput = reshape$3({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
- const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
- const reshapedOutput = reshape$3({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
- backend.disposeIntermediateTensorInfo(reshapedInput);
- backend.disposeIntermediateTensorInfo(reduced);
- return reshapedOutput;
- }
-
- /**
- * @license
- * Copyright 2017 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class TransposeProgram {
- constructor(aShape, newDim) {
- this.variableNames = ['A'];
- const outputShape = new Array(aShape.length);
- for (let i = 0; i < outputShape.length; i++) {
- outputShape[i] = aShape[newDim[i]];
- }
- this.outputShape = outputShape;
- this.rank = outputShape.length;
- const dtype = getCoordsDataType(this.rank);
- const switched = getSwitchedCoords(newDim);
- this.userCode = `
- void main() {
- ${dtype} resRC = getOutputCoords();
- setOutput(getA(${switched}));
- }
- `;
- }
- }
- function getSwitchedCoords(newDim) {
- const rank = newDim.length;
- if (rank > 6) {
- throw Error(`Transpose for rank ${rank} is not yet supported`);
- }
- const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
- const switchedCoords = new Array(rank);
- for (let i = 0; i < newDim.length; i++) {
- switchedCoords[newDim[i]] = originalOrder[i];
- }
- return switchedCoords.join();
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class TransposePackedProgram {
- constructor(aShape, newDim) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- const outputShape = new Array(aShape.length);
- for (let i = 0; i < outputShape.length; i++) {
- outputShape[i] = aShape[newDim[i]];
- }
- this.outputShape = outputShape;
- this.rank = outputShape.length;
- if (this.rank > 6) {
- throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
- }
- const dtype = getCoordsDataType(this.rank);
- const outputOrder = getVecChannels('rc', this.rank);
- const switchedOrder = new Array(this.rank);
- for (let i = 0; i < newDim.length; i++) {
- switchedOrder[newDim[i]] = outputOrder[i];
- }
- const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;
- const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;
- const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;
- this.userCode = `
- void main() {
- ${dtype} rc = getOutputCoords();
- vec4 result = vec4(0.);
- result[0] = ${getc};
- if(${nextColumn}) {
- result[1] = ${getc};
- }
- --${outputOrder[this.rank - 1]};
- if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {
- result[2] = ${getc};
- if(${nextColumn}) {
- result[3] = ${getc};
- }
- }
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function transposeImpl$1(x, perm, backend) {
- const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
- new TransposePackedProgram(x.shape, perm) :
- new TransposeProgram(x.shape, perm);
- return backend.runWebGLProgram(program, [x], x.dtype);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxConfig$1 = {
- kernelName: Max,
- backendName: 'webgl',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { reductionIndices, keepDims } = attrs;
- const webglBackend = backend;
- const xRank = x.shape.length;
- const origAxes = parseAxisParam(reductionIndices, x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, xRank);
- const maxInputIsTransposed = permutedAxes != null;
- const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
- let maxInput = x;
- if (maxInputIsTransposed) {
- if (shouldExecuteOnCPU) {
- const xTexData = webglBackend.texData.get(maxInput.dataId);
- const values = xTexData.values;
- const newShape = new Array(xRank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = x.shape[permutedAxes[i]];
- }
- const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
- maxInput = webglBackend.makeTensorInfo(newShape, x.dtype);
- const maxInputData = webglBackend.texData.get(maxInput.dataId);
- maxInputData.values = maxInputValues;
- }
- else {
- maxInput = transposeImpl$1(x, permutedAxes, webglBackend);
- }
- axes = getInnerMostAxes(axes.length, xRank);
- }
- assertAxesAreInnerMostDims('max', axes, xRank);
- const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes);
- let outShape = maxOutShape;
- if (keepDims) {
- // rather than reshape at the end, set the target shape here.
- outShape = expandShapeToKeepDim(maxOutShape, origAxes);
- }
- let out;
- if (shouldExecuteOnCPU) {
- const xTexData = webglBackend.texData.get(maxInput.dataId);
- const values = xTexData.values;
- const outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype);
- out = webglBackend.makeTensorInfo(outShape, x.dtype);
- const outData = webglBackend.texData.get(out.dataId);
- outData.values = outValues;
- }
- else {
- out = maxImpl$1(maxInput, reduceShape, outShape, webglBackend);
- }
- if (maxInputIsTransposed) {
- webglBackend.disposeIntermediateTensorInfo(maxInput);
- }
- return out;
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPool$2(args) {
- const { inputs, backend, attrs } = args;
- const { x } = inputs;
- assertNotComplex$1(x, 'maxPool');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const dilations = 1;
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
- if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
- arraysEqual(convInfo.inShape, convInfo.outShape)) {
- return identity$2({ inputs: { x }, backend });
- }
- const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false);
- return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
- }
- const maxPoolConfig$1 = {
- kernelName: MaxPool,
- backendName: 'webgl',
- kernelFunc: maxPool$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPoolBackprop$2(args) {
- const { inputs, backend, attrs } = args;
- const { dy, input, output } = inputs;
- const x = input;
- assertNotComplex$1([input, output], 'maxPoolBackprop');
- const { filterSize, strides, pad, dimRoundingMode } = attrs;
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
- const getPositions = true;
- const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
- const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
- const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
- const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
- backend.disposeIntermediateTensorInfo(maxPoolPositions);
- return result;
- }
- const maxPoolBackpropConfig$1 = {
- kernelName: MaxPoolBackprop,
- backendName: 'webgl',
- kernelFunc: maxPoolBackprop$2
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, backend) {
- let program = new Pool2DProgram(convInfo, 'max', false);
- const poolOutput = backend.runWebGLProgram(program, [x], 'float32');
- program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
- const indexOutput = backend.runWebGLProgram(program, [x], 'float32');
- return [poolOutput, indexOutput];
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const maxPoolWithArgmaxConfig$1 = {
- kernelName: MaxPoolWithArgmax,
- backendName: 'webgl',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { filterSize, strides, pad, includeBatchInIndex } = attrs;
- const webglBackend = backend;
- assert(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);
- const dilations = [1, 1];
- assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
- `Got strides ${strides} and dilations '${dilations}'`);
- const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
- const [result, indexes] = maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, webglBackend);
- return [result, indexes];
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function meanImpl(x, reduceShape, outShape, backend) {
- const inSize = sizeFromShape(reduceShape);
- const xSize = sizeFromShape(x.shape);
- const batchSize = xSize / inSize;
- const reshapedInput = reshape$3({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
- const reduced = reduce(reshapedInput, 'float32', 'mean', backend);
- const reshapedOutput = reshape$3({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
- backend.disposeIntermediateTensorInfo(reshapedInput);
- backend.disposeIntermediateTensorInfo(reduced);
- return reshapedOutput;
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const meanConfig = {
- kernelName: Mean,
- backendName: 'webgl',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { keepDims, axis } = attrs;
- const webglBackend = backend;
- const xRank = x.shape.length;
- const origAxes = parseAxisParam(axis, x.shape);
- let axes = origAxes;
- const permutedAxes = getAxesPermutation(axes, xRank);
- const meanInputIsTransposed = permutedAxes != null;
- const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
- const intermediates = [];
- let meanInput = x;
- if (meanInputIsTransposed) {
- if (shouldExecuteOnCPU) {
- const xTexData = webglBackend.texData.get(meanInput.dataId);
- const values = xTexData.values;
- const newShape = new Array(xRank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = x.shape[permutedAxes[i]];
- }
- const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
- meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
- const meanInputData = webglBackend.texData.get(meanInput.dataId);
- meanInputData.values = meanInputValues;
- }
- else {
- meanInput = transposeImpl$1(x, permutedAxes, webglBackend);
- }
- intermediates.push(meanInput);
- axes = getInnerMostAxes(axes.length, xRank);
- }
- assertAxesAreInnerMostDims('sum', axes, xRank);
- const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes);
- let outShape = meanOutShape;
- if (keepDims) {
- // rather than reshape at the end, set the target shape here.
- outShape = expandShapeToKeepDim(meanOutShape, origAxes);
- }
- const out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
- for (const i of intermediates) {
- webglBackend.disposeIntermediateTensorInfo(i);
- }
- return out;
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class MirrorPadProgram {
- constructor(xShape, paddings, mode) {
- this.variableNames = ['x'];
- this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
- const rank = xShape.length;
- const dtype = getCoordsDataType(rank);
- const start = paddings.map(p => p[0]).join(',');
- const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
- const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
- const offset = mode === 'reflect' ? 0 : 1;
- if (rank === 1) {
- this.userCode = `
- int start = ${start};
- int end = ${end};
-
- void main() {
- int outC = getOutputCoords();
- if (outC < start) {
- outC = start * 2 - outC - ${offset};
- } else if(outC >= end) {
- outC = (end - 1) * 2 - outC + ${offset};
- }
- setOutput(getX(outC - start));
- }
- `;
- return;
- }
- this.userCode = `
- ${dtype} start = ${dtype}(${start});
- ${dtype} end = ${dtype}(${end});
-
- void main() {
- ${dtype} outC = getOutputCoords();
- for (int i = 0; i < ${rank}; i++) {
- if (outC[i] < start[i]) {
- outC[i] = start[i] * 2 - outC[i] - ${offset};
- } else if(outC[i] >= end[i]) {
- outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
- }
- }
- ${dtype} coords = outC - start;
- setOutput(getX(${unpackedCoords}));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Example shader code for
- * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
- * ```
- * const int start = int(2);
- * const int end = int(5);
- *
- * void main() {
- * int outputLoc = getOutputCoords();
- * vec4 result = vec4(0.);
- *
- * int rc = outputLoc;
- *
- * int source = rc;
- * if (source < start) {
- * source = start * 2 - source - 0;
- * } else if (source >= end) {
- * source = (end - 1) * 2 - source + 0;
- * }
- * source -= start;
- *
- * result[0] = getChannel(getX(source), source);
- * rc += 1;
- * if(rc < 6) {
- * int source = rc;
- * if (source < start) {
- * source = start * 2 - source - 0;
- * } else if (source >= end) {
- * source = (end - 1) * 2 - source + 0;
- * }
- * source -= start;
- *
- * result[1] = getChannel(getX(source), source);
- * }
- *
- * setOutput(result);
- * }
- * ```
- */
- class MirrorPadPackedProgram {
- constructor(xShape, paddings, mode) {
- this.variableNames = ['x'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
- const rank = xShape.length;
- const dtype = getCoordsDataType(rank);
- const start = paddings.map(p => p[0]).join(',');
- const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
- const coords = getChannels('rc', rank);
- const source = getChannels('source', rank);
- const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
- const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
- const offset = mode === 'reflect' ? 0 : 1;
- let mainLoop = '';
- if (rank === 1) {
- const padSetup = `
- ${dtype} source = rc;
- if (source < start) {
- source = start * 2 - source - ${offset};
- } else if (source >= end) {
- source = (end - 1) * 2 - source + ${offset};
- }
- source -= start;
- `;
- mainLoop = `
- ${dtype} rc = outputLoc;
- ${padSetup}
- result[0] = getChannel(getX(${source.join()}), ${innerDims});
- ${coords[rank - 1]} += 1;
- if(${cLimit}) {
- ${padSetup}
- result[1] = getChannel(getX(${source.join()}), ${innerDims});
- }
- `;
- }
- else {
- const padSetup = `
- ${dtype} source = rc;
- ${dtype} lt = ${dtype}(lessThan(source, start));
- ${dtype} gte = ${dtype}(greaterThanEqual(source, end));
- ${dtype} orig = 1 - (lt + gte);
- source = orig * source +
- lt * (start * 2 - source - ${offset}) +
- gte * ((end - 1) * 2 - source + ${offset});
- source -= start;
- `;
- mainLoop = `
- ${dtype} rc = outputLoc;
- ${padSetup}
- result[0] = getChannel(getX(${source.join()}), ${innerDims});
- ${coords[rank - 1]} += 1;
- if(${cLimit}) {
- ${padSetup}
- result[1] = getChannel(getX(${source.join()}), ${innerDims});
- }
- rc = outputLoc;
- ${coords[rank - 2]} += 1;
- if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {
- ${padSetup}
- result[2] = getChannel(getX(${source.join()}), ${innerDims});
- ${coords[rank - 1]} += 1;
- if(${cLimit}) {
- ${padSetup}
- result[3] = getChannel(getX(${source.join()}), ${innerDims});
- }
- }
- `;
- }
- this.userCode = `
- const ${dtype} start = ${dtype}(${start});
- const ${dtype} end = ${dtype}(${end});
-
- void main() {
- ${dtype} outputLoc = getOutputCoords();
- vec4 result = vec4(0.);
- ${mainLoop}
- setOutput(result);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
- const { x } = inputs;
- const { paddings, mode } = attrs;
- const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
- new MirrorPadPackedProgram(x.shape, paddings, mode) :
- new MirrorPadProgram(x.shape, paddings, mode);
- const output = backend.runWebGLProgram(program, [x], x.dtype);
- return output;
- };
- const mirrorPadConfig$1 = {
- kernelName: MirrorPad,
- backendName: 'webgl',
- kernelFunc: mirrorPadKernelFunc,
- };
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // (Ar + Ai)(Br + Bi) =
- // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr
- // Yr = ArBr - AB
- // Yi = ArBi + AiBr
- const COMPLEX_MULTIPLY = {
- REAL: 'return areal * breal - aimag * bimag;',
- IMAG: 'return areal * bimag + aimag * breal;'
- };
- class BinaryOpComplexProgram {
- constructor(op, aShape, bShape) {
- this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
- this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
- this.userCode = `
- float binaryOpComplex(
- float areal, float aimag, float breal, float bimag) {
- ${op}
- }
-
- void main() {
- float areal = getARealAtOutCoords();
- float aimag = getAImagAtOutCoords();
- float breal = getBRealAtOutCoords();
- float bimag = getBImagAtOutCoords();
- setOutput(binaryOpComplex(areal, aimag, breal, bimag));
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const MUL = 'return a * b;';
- function multiply$3(args) {
- const { inputs, backend } = args;
- const { a, b } = inputs;
- if (a.dtype === 'complex64') {
- const aData = backend.texData.get(a.dataId);
- const bData = backend.texData.get(b.dataId);
- const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
- const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
- const inputs = [
- {
- dataId: aData.complexTensorInfos.real.dataId,
- dtype: aData.complexTensorInfos.real.dtype,
- shape: a.shape
- },
- {
- dataId: aData.complexTensorInfos.imag.dataId,
- dtype: aData.complexTensorInfos.imag.dtype,
- shape: a.shape
- },
- {
- dataId: bData.complexTensorInfos.real.dataId,
- dtype: bData.complexTensorInfos.real.dtype,
- shape: b.shape
- },
- {
- dataId: bData.complexTensorInfos.imag.dataId,
- dtype: bData.complexTensorInfos.imag.dtype,
- shape: b.shape
- }
- ];
- const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
- const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
- const complexOutput = complex$2({ inputs: { real: realPart, imag: imagPart }, backend });
- backend.disposeIntermediateTensorInfo(realPart);
- backend.disposeIntermediateTensorInfo(imagPart);
- // TODO(annxingyuan): CPU forwarding for complex inputs.
- return complexOutput;
- }
- if (backend.shouldExecuteOnCPU([a, b])) {
- const aData = backend.texData.get(a.dataId);
- const bData = backend.texData.get(b.dataId);
- const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, 'float32');
- const out = backend.makeTensorInfo(outShape, 'float32');
- const outData = backend.texData.get(out.dataId);
- outData.values = outValues;
- return out;
- }
- let program;
- if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
- program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
- }
- else {
- program = new BinaryOpProgram(MUL, a.shape, b.shape);
- }
- return backend.runWebGLProgram(program, [a, b], a.dtype);
- }
- const multiplyConfig$1 = {
- kernelName: Multiply,
- backendName: 'webgl',
- kernelFunc: multiply$3
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV3Config = {
- kernelName: NonMaxSuppressionV3,
- backendName: 'webgl',
- kernelFunc: ({ inputs, backend, attrs }) => {
- warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
- 'Call tf.nonMaxSuppressionAsync() instead');
- const { boxes, scores } = inputs;
- const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
- const gpuBackend = backend;
- const boxesVals = gpuBackend.readSync(boxes.dataId);
- const scoresVals = gpuBackend.readSync(scores.dataId);
- const maxOutputSizeVal = maxOutputSize;
- const iouThresholdVal = iouThreshold;
- const scoreThresholdVal = scoreThreshold;
- return nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal);
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV4Impl$2 = nonMaxSuppressionV4Impl;
- const nonMaxSuppressionV4Config$1 = {
- kernelName: NonMaxSuppressionV4,
- backendName: 'webgl',
- kernelFunc: ({ inputs, backend, attrs }) => {
- warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
- 'Call tf.nonMaxSuppressionAsync() instead');
- const { boxes, scores } = inputs;
- const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
- const gpuBackend = backend;
- const boxesVals = gpuBackend.readSync(boxes.dataId);
- const scoresVals = gpuBackend.readSync(scores.dataId);
- const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
- return [selectedIndices, validOutputs];
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const nonMaxSuppressionV5Impl$2 = nonMaxSuppressionV5Impl;
- const nonMaxSuppressionV5Config$1 = {
- kernelName: NonMaxSuppressionV5,
- backendName: 'webgl',
- kernelFunc: ({ inputs, backend, attrs }) => {
- warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
- 'Call tf.nonMaxSuppressionAsync() instead');
- const { boxes, scores } = inputs;
- const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
- const gpuBackend = backend;
- const boxesVals = gpuBackend.readSync(boxes.dataId);
- const scoresVals = gpuBackend.readSync(scores.dataId);
- const maxOutputSizeVal = maxOutputSize;
- const iouThresholdVal = iouThreshold;
- const scoreThresholdVal = scoreThreshold;
- const softNmsSigmaVal = softNmsSigma;
- const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$2(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
- return [selectedIndices, selectedScores];
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- class RotateProgram {
- constructor(imageShape, radians, fillValue, center) {
- this.variableNames = ['Image'];
- this.outputShape = [];
- const imageHeight = imageShape[1];
- const imageWidth = imageShape[2];
- const sinFactor = Math.sin(radians).toFixed(3);
- const cosFactor = Math.cos(radians).toFixed(3);
- this.outputShape = imageShape;
- const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
- const centerXString = centerX.toFixed(3);
- const centerYString = centerY.toFixed(3);
- let fillSnippet = '';
- if (typeof fillValue === 'number') {
- fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`;
- }
- else {
- fillSnippet = `
- vec3 fill = vec3(${fillValue.join(',')});
- float outputValue = fill[coords[3]];`;
- }
- this.userCode = `
- void main() {
- ivec4 coords = getOutputCoords();
- int x = coords[2];
- int y = coords[1];
- float coordXFloat = (float(x) - ${centerXString}) * ${cosFactor} - (float(y) - ${centerYString}) * ${sinFactor};
- float coordYFloat = (float(x) - ${centerXString}) * ${sinFactor} + (float(y) - ${centerYString}) * ${cosFactor};
- int coordX = int(round(coordXFloat + ${centerXString}));
- int coordY = int(round(coordYFloat + ${centerYString}));
- ${fillSnippet}
- if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
- outputValue = getImage(coords[0], coordY, coordX, coords[3]);
- }
- setOutput(outputValue);
- }
- `;
- }
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const rotateWithOffsetConfig$1 = {
- kernelName: RotateWithOffset,
- backendName: 'webgl',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { image } = inputs;
- const { radians, fillValue, center } = attrs;
- const webglBackend = backend;
- const program = new RotateProgram(image.shape, radians, fillValue, center);
- const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
- return output;
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const SIN = CHECK_NAN_SNIPPET_UNARY + `
- return sin(x);
-`;
- const sin$2 = unaryKernelFunc$1(SIN);
- const sinConfig$1 = {
- kernelName: Sin,
- backendName: 'webgl',
- kernelFunc: sin$2,
- };
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const SQUARE = `return x * x;`;
- const square$2 = unaryKernelFunc$1(SQUARE);
- const squareConfig$1 = {
- kernelName: Square,
- backendName: 'webgl',
- kernelFunc: square$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const SQUARED_DIFFERENCE$1 = 'return (a - b) * (a - b);';
- const squaredDifference$2 = binaryKernelFunc$1({ opSnippet: SQUARED_DIFFERENCE$1, packedOpSnippet: SQUARED_DIFFERENCE$1 });
- const squaredDifferenceConfig$1 = {
- kernelName: SquaredDifference,
- backendName: 'webgl',
- kernelFunc: squaredDifference$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const SUB = 'return a - b;';
- const subKernelFunc = binaryKernelFunc$1({
- opSnippet: SUB,
- packedOpSnippet: SUB,
- supportsComplex: true,
- cpuKernelImpl: subImplCPU
- });
- const subConfig$1 = {
- kernelName: Sub,
- backendName: 'webgl',
- kernelFunc: subKernelFunc
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const TAN = `return tan(x);`;
- const tan$2 = unaryKernelFunc$1(TAN);
- const tanConfig$1 = {
- kernelName: Tan,
- backendName: 'webgl',
- kernelFunc: tan$2,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const transposeConfig$1 = {
- kernelName: Transpose,
- backendName: 'webgl',
- kernelFunc: ({ inputs, attrs, backend }) => {
- const { x } = inputs;
- const { perm } = attrs;
- const webglBackend = backend;
- const xRank = x.shape.length;
- const newShape = new Array(xRank);
- for (let i = 0; i < newShape.length; i++) {
- newShape[i] = x.shape[perm[i]];
- }
- let out;
- if (webglBackend.shouldExecuteOnCPU([x])) {
- const xTexData = webglBackend.texData.get(x.dataId);
- const values = xTexData.values;
- const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
- out = webglBackend.makeTensorInfo(newShape, x.dtype);
- const outData = webglBackend.texData.get(out.dataId);
- outData.values = outValues;
- }
- else {
- out = transposeImpl$1(x, perm, webglBackend);
- }
- return out;
- }
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the License);
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function unique$3(args) {
- const { inputs, attrs, backend } = args;
- const { axis } = attrs;
- const { x } = inputs;
- assertNotComplex$1(x, 'unique');
- // For now, always forward calculation to the CPU backend.
- console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
- const values = backend.readSync(x.dataId);
- const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype);
- return [
- backend.makeTensorInfo(outputShape, x.dtype, outputValues),
- backend.makeTensorInfo([indices.length], 'int32', indices),
- ];
- }
- const uniqueConfig$1 = {
- kernelName: Unique,
- backendName: 'webgl',
- kernelFunc: unique$3,
- };
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // List all kernel configs here
- const kernelConfigs$1 = [
- addConfig$1,
- atan2Config,
- avgPoolConfig$1,
- avgPoolBackpropConfig$1,
- batchNormConfig$1,
- castConfig$1,
- complexConfig$1,
- concatConfig$1,
- cosConfig$1,
- divConfig$1,
- fftConfig$1,
- flipLeftRightConfig$1,
- fromPixelsConfig,
- identityConfig$1,
- ifftConfig$1,
- imagConfig$1,
- maxConfig$1,
- maxPoolConfig$1,
- maxPoolBackpropConfig$1,
- maxPoolWithArgmaxConfig$1,
- meanConfig,
- mirrorPadConfig$1,
- multiplyConfig$1,
- nonMaxSuppressionV3Config,
- nonMaxSuppressionV4Config$1,
- nonMaxSuppressionV5Config$1,
- notEqualConfig$1,
- realConfig$1,
- reshapeConfig$1,
- rotateWithOffsetConfig$1,
- sinConfig$1,
- squareConfig$1,
- subConfig$1,
- squaredDifferenceConfig$1,
- tanConfig$1,
- transposeConfig$1,
- uniqueConfig$1,
- ];
- for (const kernelConfig of kernelConfigs$1) {
- registerKernel(kernelConfig);
- }
-
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
- /** @license See the LICENSE file. */
- // This code is auto-generated, do not modify this file!
- const version$6 = '0.0.0';
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- const version$7 = {
- 'tfjs-core': version,
- 'tfjs-backend-cpu': version$4,
- 'tfjs-backend-webgl': version$5,
- 'tfjs-data': version$3,
- 'tfjs-layers': version$1,
- 'tfjs-converter': version$2,
- 'tfjs': version$6
- };
-
- exports.Abs = Abs;
- exports.Acos = Acos;
- exports.Acosh = Acosh;
- exports.AdadeltaOptimizer = AdadeltaOptimizer;
- exports.AdagradOptimizer = AdagradOptimizer;
- exports.AdamOptimizer = AdamOptimizer;
- exports.AdamaxOptimizer = AdamaxOptimizer;
- exports.Add = Add;
- exports.AddN = AddN;
- exports.All = All;
- exports.Any = Any;
- exports.ArgMax = ArgMax;
- exports.ArgMin = ArgMin;
- exports.Asin = Asin;
- exports.Asinh = Asinh;
- exports.Atan = Atan;
- exports.Atan2 = Atan2;
- exports.Atanh = Atanh;
- exports.AvgPool = AvgPool;
- exports.AvgPool3D = AvgPool3D;
- exports.AvgPool3DBackprop = AvgPool3DBackprop;
- exports.AvgPoolBackprop = AvgPoolBackprop;
- exports.BatchMatMul = BatchMatMul;
- exports.BatchToSpaceND = BatchToSpaceND;
- exports.BroadcastTo = BroadcastTo;
- exports.Callback = Callback;
- exports.CallbackList = CallbackList;
- exports.Cast = Cast;
- exports.Ceil = Ceil;
- exports.ClipByValue = ClipByValue;
- exports.Complex = Complex;
- exports.Concat = Concat;
- exports.Conv2D = Conv2D;
- exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
- exports.Conv2DBackpropInput = Conv2DBackpropInput;
- exports.Conv3D = Conv3D;
- exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
- exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
- exports.Cos = Cos;
- exports.Cosh = Cosh;
- exports.CropAndResize = CropAndResize;
- exports.Cumsum = Cumsum;
- exports.CustomCallback = CustomCallback;
- exports.DataStorage = DataStorage;
- exports.DepthToSpace = DepthToSpace;
- exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
- exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
- exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
- exports.Diag = Diag;
- exports.Dilation2D = Dilation2D;
- exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
- exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
- exports.Div = Div;
- exports.EarlyStopping = EarlyStopping;
- exports.Elu = Elu;
- exports.EluGrad = EluGrad;
- exports.Environment = Environment;
- exports.Equal = Equal;
- exports.Erf = Erf;
- exports.Exp = Exp;
- exports.Expm1 = Expm1;
- exports.FFT = FFT;
- exports.Fill = Fill;
- exports.FlipLeftRight = FlipLeftRight;
- exports.Floor = Floor;
- exports.FloorDiv = FloorDiv;
- exports.FromPixels = FromPixels;
- exports.FusedBatchNorm = FusedBatchNorm;
- exports.FusedConv2D = FusedConv2D;
- exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
- exports.GatherNd = GatherNd;
- exports.GatherV2 = GatherV2;
- exports.GraphModel = GraphModel;
- exports.Greater = Greater;
- exports.GreaterEqual = GreaterEqual;
- exports.History = History;
- exports.IFFT = IFFT;
- exports.Identity = Identity;
- exports.Imag = Imag;
- exports.InputSpec = InputSpec;
- exports.IsFinite = IsFinite;
- exports.IsInf = IsInf;
- exports.IsNan = IsNan;
- exports.KernelBackend = KernelBackend;
- exports.LRN = LRN;
- exports.LRNBackprop = LRNBackprop;
- exports.LayerVariable = LayerVariable;
- exports.LayersModel = LayersModel;
- exports.Less = Less;
- exports.LessEqual = LessEqual;
- exports.LinSpace = LinSpace;
- exports.Log = Log;
- exports.Log1p = Log1p;
- exports.LogSoftmax = LogSoftmax;
- exports.LogicalAnd = LogicalAnd;
- exports.LogicalNot = LogicalNot;
- exports.LogicalOr = LogicalOr;
- exports.Max = Max;
- exports.MaxPool = MaxPool;
- exports.MaxPool3D = MaxPool3D;
- exports.MaxPool3DBackprop = MaxPool3DBackprop;
- exports.MaxPoolBackprop = MaxPoolBackprop;
- exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
- exports.Maximum = Maximum;
- exports.Mean = Mean;
- exports.Min = Min;
- exports.Minimum = Minimum;
- exports.MirrorPad = MirrorPad;
- exports.Mod = Mod;
- exports.MomentumOptimizer = MomentumOptimizer;
- exports.Multiply = Multiply;
- exports.Negate = Negate;
- exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
- exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
- exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
- exports.NotEqual = NotEqual;
- exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
- exports.OneHot = OneHot;
- exports.OnesLike = OnesLike;
- exports.Optimizer = Optimizer;
- exports.PadV2 = PadV2;
- exports.Pool = Pool;
- exports.Pow = Pow;
- exports.Prelu = Prelu;
- exports.Prod = Prod;
- exports.RMSPropOptimizer = RMSPropOptimizer;
- exports.RNN = RNN;
- exports.Range = Range;
- exports.Real = Real;
- exports.Reciprocal = Reciprocal;
- exports.Relu = Relu;
- exports.Relu6 = Relu6;
- exports.Reshape = Reshape;
- exports.ResizeBilinear = ResizeBilinear;
- exports.ResizeBilinearGrad = ResizeBilinearGrad;
- exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
- exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
- exports.Reverse = Reverse;
- exports.RotateWithOffset = RotateWithOffset;
- exports.Round = Round;
- exports.Rsqrt = Rsqrt;
- exports.SGDOptimizer = SGDOptimizer;
- exports.ScatterNd = ScatterNd;
- exports.SelectV2 = SelectV2;
- exports.Selu = Selu;
- exports.Sequential = Sequential;
- exports.Sigmoid = Sigmoid;
- exports.Sign = Sign;
- exports.Sin = Sin;
- exports.Sinh = Sinh;
- exports.Slice = Slice;
- exports.Softmax = Softmax;
- exports.Softplus = Softplus;
- exports.SpaceToBatchND = SpaceToBatchND;
- exports.SparseToDense = SparseToDense;
- exports.SplitV = SplitV;
- exports.Sqrt = Sqrt;
- exports.Square = Square;
- exports.SquaredDifference = SquaredDifference;
- exports.Step = Step;
- exports.StridedSlice = StridedSlice;
- exports.Sub = Sub;
- exports.Sum = Sum;
- exports.SymbolicTensor = SymbolicTensor;
- exports.Tan = Tan;
- exports.Tanh = Tanh;
- exports.Tensor = Tensor;
- exports.TensorBuffer = TensorBuffer;
- exports.Tile = Tile;
- exports.TopK = TopK;
- exports.Transpose = Transpose;
- exports.Unique = Unique;
- exports.Unpack = Unpack;
- exports.UnsortedSegmentSum = UnsortedSegmentSum;
- exports.Variable = Variable;
- exports.ZerosLike = ZerosLike;
- exports._FusedMatMul = _FusedMatMul;
- exports.abs = abs;
- exports.acos = acos;
- exports.acosh = acosh;
- exports.add = add$1;
- exports.addN = addN;
- exports.addStrict = addStrict;
- exports.all = all;
- exports.any = any;
- exports.argMax = argMax;
- exports.argMin = argMin;
- exports.asin = asin;
- exports.asinh = asinh;
- exports.atan = atan;
- exports.atan2 = atan2;
- exports.atanh = atanh;
- exports.avgPool = avgPool;
- exports.avgPool3d = avgPool3d;
- exports.backend = backend;
- exports.backend_util = backend_util;
- exports.basicLSTMCell = basicLSTMCell;
- exports.batchNorm = batchNorm;
- exports.batchNorm2d = batchNorm2d;
- exports.batchNorm3d = batchNorm3d;
- exports.batchNorm4d = batchNorm4d;
- exports.batchToSpaceND = batchToSpaceND;
- exports.booleanMaskAsync = booleanMaskAsync;
- exports.broadcastTo = broadcastTo;
- exports.browser = browser;
- exports.buffer = buffer;
- exports.callbacks = callbacks;
- exports.cast = cast;
- exports.ceil = ceil;
- exports.clipByValue = clipByValue;
- exports.clone = clone;
- exports.complex = complex;
- exports.concat = concat;
- exports.concat1d = concat1d;
- exports.concat2d = concat2d;
- exports.concat3d = concat3d;
- exports.concat4d = concat4d;
- exports.constraints = exports_constraints;
- exports.conv1d = conv1d;
- exports.conv2d = conv2d;
- exports.conv2dTranspose = conv2dTranspose;
- exports.conv3d = conv3d;
- exports.conv3dTranspose = conv3dTranspose;
- exports.copyRegisteredKernels = copyRegisteredKernels;
- exports.cos = cos;
- exports.cosh = cosh;
- exports.cosineWindow = cosineWindow;
- exports.cumsum = cumsum;
- exports.customGrad = customGrad;
- exports.data = index;
- exports.deprecationWarn = deprecationWarn;
- exports.depthToSpace = depthToSpace;
- exports.depthwiseConv2d = depthwiseConv2d;
- exports.deregisterOp = deregisterOp;
- exports.device_util = device_util;
- exports.diag = diag;
- exports.dilation2d = dilation2d;
- exports.disableDeprecationWarnings = disableDeprecationWarnings;
- exports.dispose = dispose;
- exports.disposeVariables = disposeVariables;
- exports.div = div;
- exports.divNoNan = divNoNan;
- exports.divStrict = divStrict;
- exports.dot = dot;
- exports.dropout = dropout;
- exports.elu = elu;
- exports.enableDebugMode = enableDebugMode;
- exports.enableProdMode = enableProdMode;
- exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
- exports.engine = engine;
- exports.env = env;
- exports.equal = equal;
- exports.equalStrict = equalStrict;
- exports.erf = erf;
- exports.exp = exp;
- exports.expandDims = expandDims;
- exports.expm1 = expm1;
- exports.eye = eye;
- exports.fft = fft;
- exports.fill = fill;
- exports.findBackend = findBackend;
- exports.findBackendFactory = findBackendFactory;
- exports.floor = floor;
- exports.floorDiv = floorDiv;
- exports.fused = fused_ops;
- exports.gather = gather;
- exports.gatherND = gatherND;
- exports.gather_util = gather_nd_util;
- exports.getBackend = getBackend;
- exports.getGradient = getGradient;
- exports.getKernel = getKernel;
- exports.getKernelsForBackend = getKernelsForBackend;
- exports.grad = grad;
- exports.grads = grads;
- exports.greater = greater;
- exports.greaterEqual = greaterEqual;
- exports.greaterEqualStrict = greaterEqualStrict;
- exports.greaterStrict = greaterStrict;
- exports.ifft = ifft;
- exports.imag = imag;
- exports.image = image;
- exports.inTopKAsync = inTopKAsync;
- exports.initializers = exports_initializers;
- exports.input = input;
- exports.io = io;
- exports.irfft = irfft;
- exports.isFinite = isFinite$1;
- exports.isInf = isInf;
- exports.isNaN = isNaN$1;
- exports.keep = keep;
- exports.kernel_impls = kernel_impls;
- exports.layers = exports_layers;
- exports.leakyRelu = leakyRelu;
- exports.less = less;
- exports.lessEqual = lessEqual;
- exports.lessEqualStrict = lessEqualStrict;
- exports.lessStrict = lessStrict;
- exports.linalg = linalg;
- exports.linspace = linspace;
- exports.loadGraphModel = loadGraphModel;
- exports.loadLayersModel = loadLayersModel;
- exports.localResponseNormalization = localResponseNormalization;
- exports.log = log;
- exports.log1p = log1p;
- exports.logSigmoid = logSigmoid;
- exports.logSoftmax = logSoftmax;
- exports.logSumExp = logSumExp;
- exports.logicalAnd = logicalAnd;
- exports.logicalNot = logicalNot;
- exports.logicalOr = logicalOr;
- exports.logicalXor = logicalXor;
- exports.losses = losses;
- exports.matMul = matMul;
- exports.math = math;
- exports.max = max;
- exports.maxPool = maxPool;
- exports.maxPool3d = maxPool3d;
- exports.maxPoolWithArgmax = maxPoolWithArgmax;
- exports.maximum = maximum;
- exports.maximumStrict = maximumStrict;
- exports.mean = mean;
- exports.memory = memory;
- exports.metrics = exports_metrics;
- exports.min = min;
- exports.minimum = minimum;
- exports.minimumStrict = minimumStrict;
- exports.mirrorPad = mirrorPad;
- exports.mod = mod;
- exports.modStrict = modStrict;
- exports.model = model;
- exports.models = exports_models;
- exports.moments = moments;
- exports.movingAverage = movingAverage;
- exports.mul = mul;
- exports.mulStrict = mulStrict;
- exports.multiRNNCell = multiRNNCell;
- exports.multinomial = multinomial;
- exports.neg = neg;
- exports.nextFrame = nextFrame;
- exports.norm = norm;
- exports.notEqual = notEqual;
- exports.notEqualStrict = notEqualStrict;
- exports.oneHot = oneHot;
- exports.ones = ones$1;
- exports.onesLike = onesLike;
- exports.op = op;
- exports.outerProduct = outerProduct;
- exports.pad = pad;
- exports.pad1d = pad1d;
- exports.pad2d = pad2d;
- exports.pad3d = pad3d;
- exports.pad4d = pad4d;
- exports.pool = pool;
- exports.pow = pow;
- exports.powStrict = powStrict;
- exports.prelu = prelu;
- exports.print = print;
- exports.prod = prod;
- exports.profile = profile;
- exports.rand = rand;
- exports.randomGamma = randomGamma;
- exports.randomNormal = randomNormal;
- exports.randomUniform = randomUniform;
- exports.range = range;
- exports.ready = ready;
- exports.real = real;
- exports.reciprocal = reciprocal;
- exports.registerBackend = registerBackend;
- exports.registerCallbackConstructor = registerCallbackConstructor;
- exports.registerGradient = registerGradient;
- exports.registerKernel = registerKernel;
- exports.registerOp = registerOp;
- exports.regularizers = exports_regularizers;
- exports.relu = relu;
- exports.relu6 = relu6;
- exports.removeBackend = removeBackend;
- exports.reshape = reshape;
- exports.reverse = reverse;
- exports.reverse1d = reverse1d;
- exports.reverse2d = reverse2d;
- exports.reverse3d = reverse3d;
- exports.reverse4d = reverse4d;
- exports.rfft = rfft;
- exports.round = round;
- exports.rsqrt = rsqrt;
- exports.scalar = scalar;
- exports.scatterND = scatterND;
- exports.scatter_util = scatter_nd_util;
- exports.selu = selu;
- exports.separableConv2d = separableConv2d;
- exports.sequential = sequential;
- exports.serialization = serialization;
- exports.setBackend = setBackend;
- exports.setPlatform = setPlatform;
- exports.setdiff1dAsync = setdiff1dAsync;
- exports.sigmoid = sigmoid;
- exports.sign = sign;
- exports.signal = signal;
- exports.sin = sin;
- exports.sinh = sinh;
- exports.slice = slice;
- exports.slice1d = slice1d;
- exports.slice2d = slice2d;
- exports.slice3d = slice3d;
- exports.slice4d = slice4d;
- exports.slice_util = slice_util;
- exports.softmax = softmax;
- exports.softplus = softplus;
- exports.spaceToBatchND = spaceToBatchND;
- exports.sparseToDense = sparseToDense;
- exports.spectral = spectral;
- exports.split = split;
- exports.sqrt = sqrt;
- exports.square = square;
- exports.squaredDifference = squaredDifference;
- exports.squaredDifferenceStrict = squaredDifferenceStrict;
- exports.squeeze = squeeze;
- exports.stack = stack;
- exports.step = step;
- exports.stridedSlice = stridedSlice;
- exports.sub = sub;
- exports.subStrict = subStrict;
- exports.sum = sum$1;
- exports.sumOutType = sumOutType;
- exports.tan = tan;
- exports.tanh = tanh$1;
- exports.tensor = tensor;
- exports.tensor1d = tensor1d;
- exports.tensor2d = tensor2d;
- exports.tensor3d = tensor3d;
- exports.tensor4d = tensor4d;
- exports.tensor5d = tensor5d;
- exports.tensor6d = tensor6d;
- exports.tensor_util = tensor_util;
- exports.test_util = test_util;
- exports.tidy = tidy;
- exports.tile = tile;
- exports.time = time;
- exports.topk = topk;
- exports.train = train;
- exports.transpose = transpose;
- exports.truncatedNormal = truncatedNormal;
- exports.unique = unique;
- exports.unregisterGradient = unregisterGradient;
- exports.unregisterKernel = unregisterKernel;
- exports.unsortedSegmentSum = unsortedSegmentSum;
- exports.unstack = unstack;
- exports.upcastType = upcastType;
- exports.util = util;
- exports.valueAndGrad = valueAndGrad;
- exports.valueAndGrads = valueAndGrads;
- exports.variable = variable;
- exports.variableGrads = variableGrads;
- exports.version = version$7;
- exports.version_converter = version$2;
- exports.version_core = version;
- exports.version_layers = version$1;
- exports.where = where;
- exports.whereAsync = whereAsync;
- exports.zeros = zeros;
- exports.zerosLike = zerosLike;
-
- Object.defineProperty(exports, '__esModule', { value: true });
-
-})));
-//# sourceMappingURL=tf.esnext.js.map
diff --git a/demo/browser.js b/demo/browser.js
index 81869c61..82c342bd 100644
--- a/demo/browser.js
+++ b/demo/browser.js
@@ -1,31 +1,32 @@
-/* global QuickSettings */
-
import human from '../dist/human.esm.js';
import draw from './draw.js';
+import Menu from './menu.js';
// ui options
const ui = {
- baseColor: 'rgba(255, 200, 255, 0.3)',
- baseLabel: 'rgba(255, 200, 255, 0.9)',
+ baseColor: 'rgba(173, 216, 230, 0.3)', // this is 'lightblue', just with alpha channel
+ baseLabel: 'rgba(173, 216, 230, 0.9)',
baseFontProto: 'small-caps {size} "Segoe UI"',
baseLineWidth: 16,
baseLineHeightProto: 2,
- columns: 3,
+ columns: 2,
busy: false,
- facing: 'user',
+ facing: true,
useWorker: false,
worker: 'worker.js',
- samples: ['../assets/sample1.jpg', '../assets/sample2.jpg', '../assets/sample3.jpg', '../assets/sample4.jpg', '../assets/sample5.jpg', '../assets/sample6.jpg'],
+ samples: ['../assets/sample6.jpg', '../assets/sample1.jpg', '../assets/sample4.jpg', '../assets/sample5.jpg', '../assets/sample3.jpg', '../assets/sample2.jpg'],
drawBoxes: true,
drawPoints: false,
drawPolygons: true,
fillPolygons: true,
useDepth: true,
console: true,
+ maxFrames: 10,
};
// configuration overrides
const config = {
+ backend: 'webgl', // if you want to use 'wasm' backend, enable script load of tf and tf-backend-wasm in index.html
face: {
enabled: true,
detector: { maxFaces: 10, skipFrames: 10, minConfidence: 0.5, iouThreshold: 0.3, scoreThreshold: 0.7 },
@@ -40,7 +41,7 @@ const config = {
};
// global variables
-let settings;
+let menu;
let worker;
let timeStamp;
const fps = [];
@@ -63,12 +64,11 @@ const log = (...msg) => {
};
// draws processed results and starts processing of a next frame
-async function drawResults(input, result, canvas) {
+function drawResults(input, result, canvas) {
// update fps
- settings.setValue('FPS', Math.round(1000 / (performance.now() - timeStamp)));
fps.push(1000 / (performance.now() - timeStamp));
- if (fps.length > 20) fps.shift();
- settings.setValue('FPS', Math.round(10 * fps.reduce((a, b) => a + b) / fps.length) / 10);
+ if (fps.length > ui.maxFrames) fps.shift();
+ menu.updateChart('FPS', fps);
// eslint-disable-next-line no-use-before-define
requestAnimationFrame(() => runHumanDetect(input, canvas)); // immediate loop
@@ -81,7 +81,7 @@ async function drawResults(input, result, canvas) {
draw.body(result.body, canvas, ui);
draw.hand(result.hand, canvas, ui);
// update log
- const engine = await human.tf.engine();
+ const engine = human.tf.engine();
const memory = `${engine.state.numBytes.toLocaleString()} bytes ${engine.state.numDataBuffers.toLocaleString()} buffers ${engine.state.numTensors.toLocaleString()} tensors`;
const gpu = engine.backendInstance ? `GPU: ${engine.backendInstance.numBytesInGPU.toLocaleString()} bytes` : '';
document.getElementById('log').innerText = `
@@ -98,7 +98,7 @@ async function setupCamera() {
const canvas = document.getElementById('canvas');
const output = document.getElementById('log');
const live = video.srcObject ? ((video.srcObject.getVideoTracks()[0].readyState === 'live') && (video.readyState > 2) && (!video.paused)) : false;
- let msg = `Setting up camera: live: ${live} facing: ${ui.facing}`;
+ let msg = `Setting up camera: live: ${live} facing: ${ui.facing ? 'front' : 'back'}`;
output.innerText += `\n${msg}`;
log(msg);
// setup webcam. note that navigator.mediaDevices requires that page is accessed via https
@@ -112,7 +112,7 @@ async function setupCamera() {
try {
stream = await navigator.mediaDevices.getUserMedia({
audio: false,
- video: { facingMode: ui.facing, width: window.innerWidth, height: window.innerHeight },
+ video: { facingMode: (ui.facing ? 'user' : 'environment'), width: window.innerWidth, height: window.innerHeight },
});
} catch (err) {
output.innerText += '\nCamera permission denied';
@@ -150,7 +150,7 @@ function webWorker(input, image, canvas) {
}
// main processing function when input is webcam, can use direct invocation or web worker
-async function runHumanDetect(input, canvas) {
+function runHumanDetect(input, canvas) {
timeStamp = performance.now();
// perform detect if live video or not video at all
if (input.srcObject) {
@@ -170,36 +170,23 @@ async function runHumanDetect(input, canvas) {
// perform detection in worker
webWorker(input, data, canvas);
} else {
- let result = {};
- try {
- // perform detection
- result = await human.detect(input, config);
- } catch (err) {
- log('Error during execution:', err.message);
- }
- if (result.error) log(result.error);
- else drawResults(input, result, canvas);
+ human.detect(input, config).then((result) => {
+ if (result.error) log(result.error);
+ else drawResults(input, result, canvas);
+ });
}
}
}
// main processing function when input is image, can use direct invocation or web worker
async function processImage(input) {
- const cfg = {
- backend: 'webgl',
- console: true,
- face: {
- enabled: true,
- detector: { maxFaces: 10, skipFrames: 0, minConfidence: 0.1, iouThreshold: 0.3, scoreThreshold: 0.3 },
- mesh: { enabled: true },
- iris: { enabled: true },
- age: { enabled: true, skipFrames: 0 },
- gender: { enabled: true },
- emotion: { enabled: true, minConfidence: 0.1, useGrayscale: true },
- },
- body: { enabled: true, maxDetections: 10, scoreThreshold: 0.7, nmsRadius: 20 },
- hand: { enabled: true, skipFrames: 0, minConfidence: 0.5, iouThreshold: 0.3, scoreThreshold: 0.5 },
- };
+ // must be zero for images
+ config.face.detector.skipFrames = 0;
+ config.face.emotion.skipFrames = 0;
+ config.face.age.skipFrames = 0;
+ config.hand.skipFrames = 0;
+
+ timeStamp = performance.now();
return new Promise((resolve) => {
const image = document.getElementById('image');
image.onload = async () => {
@@ -209,11 +196,13 @@ async function processImage(input) {
image.height = image.naturalHeight;
canvas.width = image.naturalWidth;
canvas.height = image.naturalHeight;
- const result = await human.detect(image, cfg);
- await drawResults(image, result, canvas);
+ const result = await human.detect(image, config);
+ drawResults(image, result, canvas);
const thumb = document.createElement('canvas');
- thumb.width = window.innerWidth / (ui.columns + 0.02);
+ thumb.width = (window.innerWidth - menu.width) / (ui.columns + 0.1);
thumb.height = canvas.height / (window.innerWidth / thumb.width);
+ thumb.style.margin = '8px';
+ thumb.style.boxShadow = '4px 4px 4px 0 dimgrey';
const ctx = thumb.getContext('2d');
ctx.drawImage(canvas, 0, 0, canvas.width, canvas.height, 0, 0, thumb.width, thumb.height);
document.getElementById('samples').appendChild(thumb);
@@ -253,74 +242,68 @@ async function detectSampleImages() {
for (const sample of ui.samples) await processImage(sample);
}
-// setup settings panel
-function setupUI() {
- settings = QuickSettings.create(10, 10, 'Settings', document.getElementById('main'));
- const style = document.createElement('style');
- style.innerHTML = `
- .qs_main { font: 1rem "Segoe UI"; }
- .qs_label { font: 0.8rem "Segoe UI"; }
- .qs_content { background: darkslategray; }
- .qs_container { background: transparent; color: white; margin: 6px; padding: 6px; }
- .qs_checkbox_label { top: 2px; }
- .qs_button { width: -webkit-fill-available; font: 1rem "Segoe UI"; cursor: pointer; }
- `;
- document.getElementsByTagName('head')[0].appendChild(style);
- settings.addButton('Play/Pause WebCam', () => detectVideo());
- settings.addButton('Process Images', () => detectSampleImages());
- settings.addDropDown('Backend', ['webgl', 'wasm', 'cpu'], async (val) => config.backend = val.value);
- settings.addHTML('title', 'Enabled Models'); settings.hideTitle('title');
- settings.addBoolean('Face Detect', config.face.enabled, (val) => config.face.enabled = val);
- settings.addBoolean('Face Mesh', config.face.mesh.enabled, (val) => config.face.mesh.enabled = val);
- settings.addBoolean('Face Iris', config.face.iris.enabled, (val) => config.face.iris.enabled = val);
- settings.addBoolean('Face Age', config.face.age.enabled, (val) => config.face.age.enabled = val);
- settings.addBoolean('Face Gender', config.face.gender.enabled, (val) => config.face.gender.enabled = val);
- settings.addBoolean('Face Emotion', config.face.emotion.enabled, (val) => config.face.emotion.enabled = val);
- settings.addBoolean('Body Pose', config.body.enabled, (val) => config.body.enabled = val);
- settings.addBoolean('Hand Pose', config.hand.enabled, (val) => config.hand.enabled = val);
- settings.addHTML('title', 'Model Parameters'); settings.hideTitle('title');
- settings.addRange('Max Objects', 1, 20, 5, 1, (val) => {
+function setupMenu() {
+ menu = new Menu(document.body);
+ menu.addButton('Start Video', 'Pause Video', (evt) => detectVideo(evt));
+ menu.addButton('Process Images', 'Process Images', () => detectSampleImages());
+
+ menu.addHTML(' ');
+ menu.addLabel('Enabled Models');
+ menu.addBool('Face Detect', config.face, 'enabled');
+ menu.addBool('Face Mesh', config.face.mesh, 'enabled');
+ menu.addBool('Face Iris', config.face.iris, 'enabled');
+ menu.addBool('Face Age', config.face.age, 'enabled');
+ menu.addBool('Face Gender', config.face.gender, 'enabled');
+ menu.addBool('Face Emotion', config.face.emotion, 'enabled');
+ menu.addBool('Body Pose', config.body, 'enabled');
+ menu.addBool('Hand Pose', config.hand, 'enabled');
+
+ menu.addHTML(' ');
+ menu.addLabel('Model Parameters');
+ menu.addRange('Max Objects', config.face.detector, 'maxFaces', 0, 50, 1, (val) => {
config.face.detector.maxFaces = parseInt(val);
config.body.maxDetections = parseInt(val);
+ config.hand.maxHands = parseInt(val);
});
- settings.addRange('Skip Frames', 1, 20, config.face.detector.skipFrames, 1, (val) => {
+ menu.addRange('Skip Frames', config.face.detector, 'skipFrames', 0, 50, 1, (val) => {
config.face.detector.skipFrames = parseInt(val);
config.face.emotion.skipFrames = parseInt(val);
config.face.age.skipFrames = parseInt(val);
config.hand.skipFrames = parseInt(val);
});
- settings.addRange('Min Confidence', 0.1, 1.0, config.face.detector.minConfidence, 0.05, (val) => {
+ menu.addRange('Min Confidence', config.face.detector, 'minConfidence', 0.0, 1.0, 0.05, (val) => {
config.face.detector.minConfidence = parseFloat(val);
config.face.emotion.minConfidence = parseFloat(val);
config.hand.minConfidence = parseFloat(val);
});
- settings.addRange('Score Threshold', 0.1, 1.0, config.face.detector.scoreThreshold, 0.05, (val) => {
+ menu.addRange('Score Threshold', config.face.detector, 'scoreThreshold', 0.1, 1.0, 0.05, (val) => {
config.face.detector.scoreThreshold = parseFloat(val);
config.hand.scoreThreshold = parseFloat(val);
config.body.scoreThreshold = parseFloat(val);
});
- settings.addRange('IOU Threshold', 0.1, 1.0, config.face.detector.iouThreshold, 0.05, (val) => {
+ menu.addRange('IOU Threshold', config.face.detector, 'iouThreshold', 0.1, 1.0, 0.05, (val) => {
config.face.detector.iouThreshold = parseFloat(val);
config.hand.iouThreshold = parseFloat(val);
});
- settings.addHTML('title', 'UI Options'); settings.hideTitle('title');
- settings.addBoolean('Use Web Worker', ui.useWorker, (val) => ui.useWorker = val);
- settings.addBoolean('Camera Front/Back', true, (val) => {
- ui.facing = val ? 'user' : 'environment';
- setupCamera();
- });
- settings.addBoolean('Use 3D Depth', ui.useDepth, (val) => ui.useDepth = val);
- settings.addBoolean('Draw Boxes', ui.drawBoxes, (val) => ui.drawBoxes = val);
- settings.addBoolean('Draw Points', ui.drawPoints, (val) => ui.drawPoints = val);
- settings.addBoolean('Draw Polygons', ui.drawPolygons, (val) => ui.drawPolygons = val);
- settings.addBoolean('Fill Polygons', ui.fillPolygons, (val) => ui.fillPolygons = val);
- settings.addHTML('line1', ' '); settings.hideTitle('line1');
- settings.addRange('FPS', 0, 100, 0, 1);
+
+ menu.addHTML(' ');
+ menu.addLabel('UI Options');
+ menu.addBool('Use Web Worker', ui, 'useWorker');
+ menu.addBool('Camera Front/Back', ui, 'facing', () => setupCamera());
+ menu.addBool('Use 3D Depth', ui, 'useDepth');
+ menu.addBool('Draw Boxes', ui, 'drawBoxes');
+ menu.addBool('Draw Points', ui, 'drawPoints');
+ menu.addBool('Draw Polygons', ui, 'drawPolygons');
+ menu.addBool('Fill Polygons', ui, 'fillPolygons');
+
+ menu.addHTML(' ');
+ menu.addValue('State', '');
+ menu.addChart('FPS', 'FPS');
}
async function main() {
log('Human demo starting ...');
- setupUI();
+ setupMenu();
const msg = `Human ready: version: ${human.version} TensorFlow/JS version: ${human.tf.version_core}`;
document.getElementById('log').innerText += '\n' + msg;
log(msg);
diff --git a/demo/index.html b/demo/index.html
index 3e6e2955..0d9270c8 100644
--- a/demo/index.html
+++ b/demo/index.html
@@ -14,7 +14,6 @@
-
diff --git a/demo/menu.js b/demo/menu.js
new file mode 100644
index 00000000..c6f47c76
--- /dev/null
+++ b/demo/menu.js
@@ -0,0 +1,166 @@
+const css = `
+ .menu-container { display: block; background: darkslategray; position: fixed; top: 0rem; right: 0; width: fit-content; padding: 0 0.8rem 0 0.8rem; line-height: 1.8rem; z-index: 10; max-height: calc(100% - 4rem); }
+ .menu { display: flex; white-space: nowrap; background: darkslategray; padding: 0.2rem; width: max-content; }
+ .menu-title { padding: 0; }
+ .menu-hr { margin: 0.2rem; border: 1px solid rgba(0, 0, 0, 0.5) }
+ .menu-label { width: 1.3rem; height: 0.8rem; cursor: pointer; position: absolute; top: 0.1rem; left: 0.1rem; z-index: 1; background: lightcoral; border-radius: 1rem; transition: left 0.6s ease; }
+
+ .menu-chart-title { align-items: center; }
+ .menu-chart-canvas { background: transparent; height: 40px; width: 180px; margin: 0.2rem 0.2rem 0.2rem 1rem; }
+
+ .menu-button { border: 0; background: lightblue; width: -webkit-fill-available; padding: 8px; margin: 8px 0 8px 0; cursor: pointer; box-shadow: 4px 4px 4px 0 dimgrey; }
+ .menu-button:hover { background: lightgreen; }
+
+ .menu-checkbox { width: 2.8rem; height: 1rem; background: black; margin: 0.5rem 0.8rem 0 0; position: relative; border-radius: 1rem; }
+ .menu-checkbox:after { content: 'OFF'; color: lightcoral; position: absolute; right: 0.2rem; top: -0.4rem; font-weight: 800; font-size: 0.5rem; }
+ .menu-checkbox:before { content: 'ON'; color: lightgreen; position: absolute; left: 0.3rem; top: -0.4rem; font-weight: 800; font-size: 0.5rem; }
+ input[type=checkbox] { visibility: hidden; }
+ input[type=checkbox]:checked + label { left: 1.4rem; background: lightgreen; }
+
+ .menu-range { margin: 0 0.8rem 0 0; width: 5rem; background: transparent; color: lightblue; }
+ .menu-range:before { content: attr(value); color: white; margin: 0 0.4rem 0 0; font-weight: 800; font-size: 0.6rem; position: relative; top: 0.3rem; }
+ input[type=range] { -webkit-appearance: none; }
+ input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 1rem; cursor: pointer; background: black; border-radius: 1rem; border: 1px; }
+ input[type=range]::-webkit-slider-thumb { border: 1px solid #000000; margin-top: 0.05rem; height: 0.9rem; width: 1.5rem; border-radius: 1rem; background: lightblue; cursor: pointer; -webkit-appearance: none; }
+ `;
+
+function createCSS() {
+ const el = document.createElement('style');
+ el.innerHTML = css;
+ document.getElementsByTagName('head')[0].appendChild(el);
+}
+
+function createElem(parent) {
+ const el = document.createElement('div');
+ el.id = 'menu';
+ el.className = 'menu-container';
+ if (typeof parent === 'object') parent.appendChild(el);
+ else document.getElementById(parent).appendChild(el);
+ return el;
+}
+
+class Menu {
+ constructor(parent) {
+ createCSS();
+ this.menu = createElem(parent);
+ this._id = 0;
+ this._maxFPS = 0;
+ }
+
+ get newID() {
+ this._id++;
+ return `menu-${this._id}`;
+ }
+
+ get ID() {
+ return `menu-${this._id}`;
+ }
+
+ get width() {
+ return this.menu.offsetWidth;
+ }
+
+ get height() {
+ return this.menu.offsetHeight;
+ }
+
+ async addLabel(title) {
+ const el = document.createElement('div');
+ el.className = 'menu menu-title';
+ el.id = this.newID;
+ el.innerHTML = title;
+ this.menu.appendChild(el);
+ }
+
+ async addBool(title, object, variable, callback) {
+ const el = document.createElement('div');
+ el.className = 'menu';
+ el.innerHTML = `${title}`;
+ this.menu.appendChild(el);
+ document.getElementById(this.ID).addEventListener('change', (evt) => {
+ object[variable] = evt.target.checked;
+ if (callback) callback(evt.target.checked);
+ });
+ }
+
+ async addRange(title, object, variable, min, max, step, callback) {
+ const el = document.createElement('div');
+ el.className = 'menu';
+ el.innerHTML = `${title}`;
+ this.menu.appendChild(el);
+ document.getElementById(this.ID).addEventListener('change', (evt) => {
+ object[variable] = evt.target.value;
+ evt.target.setAttribute('value', evt.target.value);
+ if (callback) callback(evt.target.value);
+ });
+ }
+
+ async addHTML(html) {
+ const el = document.createElement('div');
+ el.className = 'menu';
+ el.id = this.newID;
+ if (html) el.innerHTML = html;
+ this.menu.appendChild(el);
+ }
+
+ async addButton(titleOn, titleOff, callback) {
+ const el = document.createElement('button');
+ el.className = 'menu menu-button';
+ el.type = 'button';
+ el.id = this.newID;
+ el.innerText = titleOn;
+ this.menu.appendChild(el);
+ document.getElementById(this.ID).addEventListener('click', () => {
+ if (el.innerText === titleOn) el.innerText = titleOff;
+ else el.innerText = titleOn;
+ if (callback) callback(el.innerText !== titleOn);
+ });
+ }
+
+ async addValue(title, val) {
+ const el = document.createElement('div');
+ el.className = 'menu';
+ el.id = title;
+ el.innerText = `${title}: ${val}`;
+ this.menu.appendChild(el);
+ }
+
+ // eslint-disable-next-line class-methods-use-this
+ async updateValue(title, val) {
+ const el = document.getElementById(title);
+ el.innerText = `${title}: ${val}`;
+ }
+
+ async addChart(title, id) {
+ const el = document.createElement('div');
+ el.className = 'menu menu-chart-title';
+ el.id = this.newID;
+ el.innerHTML = `${title}`;
+ this.menu.appendChild(el);
+ }
+
+ // eslint-disable-next-line class-methods-use-this
+ async updateChart(id, values) {
+ if (!values || (values.length === 0)) return;
+ const canvas = document.getElementById(`menu-canvas-${id}`);
+ if (!canvas) return;
+ const ctx = canvas.getContext('2d');
+ ctx.fillStyle = 'darkslategray';
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
+ const width = canvas.width / values.length;
+ const max = 1 + Math.max(...values);
+ const height = canvas.height / max;
+ for (const i in values) {
+ const gradient = ctx.createLinearGradient(0, (max - values[i]) * height, 0, 0);
+ gradient.addColorStop(0.1, 'lightblue');
+ gradient.addColorStop(0.4, 'darkslategray');
+ ctx.fillStyle = gradient;
+ ctx.fillRect(i * width, 0, width - 4, canvas.height);
+ ctx.fillStyle = 'black';
+ ctx.font = '12px "Segoe UI"';
+ ctx.fillText(Math.round(values[i]), i * width, canvas.height - 2, width);
+ }
+ }
+}
+
+export default Menu;