human/dist/human.esm.js

4109 lines
1.8 MiB
JavaScript
Raw Normal View History

2020-11-17 16:18:15 +01:00
var __create=Object.create;var __defProp=Object.defineProperty;var __getProtoOf=Object.getPrototypeOf;var __hasOwnProp=Object.prototype.hasOwnProperty;var __getOwnPropNames=Object.getOwnPropertyNames;var __getOwnPropDesc=Object.getOwnPropertyDescriptor;var __markAsModule=target=>__defProp(target,"__esModule",{value:true});var __commonJS=(callback,module)=>()=>{if(!module){module={exports:{}};callback(module.exports,module)}return module.exports};var __export=(target,all5)=>{__markAsModule(target);for(var name in all5)__defProp(target,name,{get:all5[name],enumerable:true})};var __exportStar=(target,module,desc)=>{__markAsModule(target);if(typeof module==="object"||typeof module==="function"){for(let key of __getOwnPropNames(module))if(!__hasOwnProp.call(target,key)&&key!=="default")__defProp(target,key,{get:()=>module[key],enumerable:!(desc=__getOwnPropDesc(module,key))||desc.enumerable})}return target};var __toModule=module=>{if(module&&module.__esModule)return module;return __exportStar(__defProp(__create(__getProtoOf(module)),"default",{value:module,enumerable:true}),module)};var require_browser=__commonJS(()=>{});var require_alea=__commonJS((exports3,module)=>{(function(global2,module2,define2){function Alea(seed){var me=this,mash=Mash();me.next=function(){var t=2091639*me.s0+me.c*23283064365386963e-26;me.s0=me.s1;me.s1=me.s2;return me.s2=t-(me.c=t|0)};me.c=1;me.s0=mash(" ");me.s1=mash(" ");me.s2=mash(" ");me.s0-=mash(seed);if(me.s0<0){me.s0+=1}me.s1-=mash(seed);if(me.s1<0){me.s1+=1}me.s2-=mash(seed);if(me.s2<0){me.s2+=1}mash=null}function copy(f,t){t.c=f.c;t.s0=f.s0;t.s1=f.s1;t.s2=f.s2;return t}function impl(seed,opts){var xg=new Alea(seed),state6=opts&&opts.state,prng=xg.next;prng.int32=function(){return xg.next()*4294967296|0};prng.double=function(){return prng()+(prng()*2097152|0)*11102230246251565e-32};prng.quick=prng;if(state6){if(typeof state6=="object")copy(state6,xg);prng.state=function(){return copy(xg,{})}}return prng}function Mash(){var 
n=4022871197;var mash=function(data2){data2=data2.toString();for(var i=0;i<data2.length;i++){n+=data2.charCodeAt(i);var h=.02519603282416938*n;n=h>>>0;h-=n;h*=n;n=h>>>0;h-=n;n+=h*4294967296}return(n>>>0)*23283064365386963e-26};return mash}if(module2&&module2.exports){module2.exports=impl}else if(define2&&define2.amd){define2(function(){return impl})}else{this.alea=impl}})(exports3,typeof module=="object"&&module,typeof define=="function"&&define)});var require_xor128=__commonJS((exports3,module)=>{(function(global2,module2,define2){function XorGen(seed){var me=this,strseed="";me.x=0;me.y=0;me.z=0;me.w=0;me.next=function(){var t=me.x^me.x<<11;me.x=me.y;me.y=me.z;me.z=me.w;return me.w^=me.w>>>19^t^t>>>8};if(seed===(seed|0)){me.x=seed}else{strseed+=seed}for(var k=0;k<strseed.length+64;k++){me.x^=strseed.charCodeAt(k)|0;me.next()}}function copy(f,t){t.x=f.x;t.y=f.y;t.z=f.z;t.w=f.w;return t}function impl(seed,opts){var xg=new XorGen(seed),state6=opts&&opts.state,prng=function(){return(xg.next()>>>0)/4294967296};prng.double=function(){do{var top=xg.next()>>>11,bot=(xg.next()>>>0)/4294967296,result=(top+bot)/(1<<21)}while(result===0);return result};prng.int32=xg.next;prng.quick=prng;if(state6){if(typeof state6=="object")copy(state6,xg);prng.state=function(){return copy(xg,{})}}return prng}if(module2&&module2.exports){module2.exports=impl}else if(define2&&define2.amd){define2(function(){return impl})}else{this.xor128=impl}})(exports3,typeof module=="object"&&module,typeof define=="function"&&define)});var require_xorwow=__commonJS((exports3,module)=>{(function(global2,module2,define2){function XorGen(seed){var me=this,strseed="";me.next=function(){var t=me.x^me.x>>>2;me.x=me.y;me.y=me.z;me.z=me.w;me.w=me.v;return(me.d=me.d+362437|0)+(me.v=me.v^me.v<<4^(t^t<<1))|0};me.x=0;me.y=0;me.z=0;me.w=0;me.v=0;if(seed===(seed|0)){me.x=seed}else{strseed+=seed}for(var 
k=0;k<strseed.length+64;k++){me.x^=strseed.charCodeAt(k)|0;if(k==strseed.length){me.d=me.x<<10^me.x>>>4}me.next()}}function copy(f,t){t.x=f.x;t.y=f.y;t.z=f.z;t.w=f.w;t.v=f.v;t.d=f.d;return t}function impl(seed,opts){var xg=new Xor
Manifest JSON has weights with names: ${allManifestWeightNames.join(", ")}.`)}const groupIndicesToFetch=groupIndicesToFetchMap.reduce((accumulator,shouldFetch,i)=>{if(shouldFetch){accumulator.push(i)}return accumulator},[]);const fetchUrls=[];groupIndicesToFetch.forEach(i=>{manifest[i].paths.forEach(filepath=>{const fetchUrl=filePathPrefix+(!filePathPrefix.endsWith("/")?"/":"")+filepath;fetchUrls.push(fetchUrl)})});const buffers=await fetchWeightsFunction(fetchUrls);const weightsTensorMap={};let bufferIndexOffset=0;groupIndicesToFetch.forEach(i=>{const numBuffers=manifest[i].paths.length;let groupBytes=0;for(let i2=0;i2<numBuffers;i2++){groupBytes+=buffers[bufferIndexOffset+i2].byteLength}const groupBuffer=new ArrayBuffer(groupBytes);const groupByteBuffer=new Uint8Array(groupBuffer);let groupBufferOffset=0;for(let i2=0;i2<numBuffers;i2++){const buffer11=new Uint8Array(buffers[bufferIndexOffset+i2]);groupByteBuffer.set(buffer11,groupBufferOffset);groupBufferOffset+=buffer11.byteLength}const weightsEntries=groupWeightsToFetch[i];weightsEntries.forEach(weightsEntry=>{const byteBuffer=groupBuffer.slice(weightsEntry.groupOffset,weightsEntry.groupOffset+weightsEntry.sizeBytes);const nameToTensorMap=decodeWeights(byteBuffer,[weightsEntry.manifestEntry]);for(const name in nameToTensorMap){weightsTensorMap[name]=nameToTensorMap[name]}});bufferIndexOffset+=numBuffers});return weightsTensorMap}}const OCTET_STREAM_MIME_TYPE="application/octet-stream";const JSON_TYPE="application/json";class HTTPRequest{constructor(path,loadOptions){this.DEFAULT_METHOD="POST";if(loadOptions==null){loadOptions={}}this.weightPathPrefix=loadOptions.weightPathPrefix;this.onProgress=loadOptions.onProgress;this.weightUrlConverter=loadOptions.weightUrlConverter;if(loadOptions.fetchFunc!=null){assert(typeof loadOptions.fetchFunc==="function",()=>"Must pass a function that matches the signature of `fetch` (see 
https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)");this.fetch=loadOptions.fetchFunc}else{this.fetch=env().platform.fetch}assert(path!=null&&path.length>0,()=>"URL path for http must not be null, undefined or empty.");if(Array.isArray(path)){assert(path.length===2,()=>`URL paths for http must have a length of 2, (actual length is ${path.length}).`)}this.path=path;if(loadOptions.requestInit!=null&&loadOptions.requestInit.body!=null){throw new Error("requestInit is expected to have no pre-existing body, but has one.")}this.requestInit=loadOptions.requestInit||{}}async save(modelArtifacts){if(modelArtifacts.modelTopology instanceof ArrayBuffer){throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet.")}const init2=Object.assign({method:this.DEFAULT_METHOD},this.requestInit);init2.body=new FormData;const weightsManifest=[{paths:["./model.weights.bin"],weights:modelArtifacts.weightSpecs}];const modelTopologyAndWeightManifest={modelTopology:modelArtifacts.modelTopology,format:modelArtifacts.format,generatedBy:modelArtifacts.generatedBy,convertedBy:modelArtifacts.convertedBy,userDefinedMetadata:modelArtifacts.userDefinedMetadata,weightsManifest};init2.body.append("model.json",new Blob([JSON.stringify(modelTopologyAndWeightManifest)],{type:JSON_TYPE}),"model.json");if(modelArtifacts.weightData!=null){init2.body.append("model.weights.bin",new Blob([modelArtifacts.weightData],{type:OCTET_STREAM_MIME_TYPE}),"model.weights.bin")}const response=await this.fetch(this.path,init2);if(response.ok){return{modelArtifactsInfo:getModelArtifactsInfoForJSON(modelArtifacts),responses:[response]}}else{throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ${response.status}.`)}}async load(){const modelConfigRequest=await this.fetch(this.path,this.requestInit);if(!modelConfigRequest.ok){throw new Error(`Request to ${this.path} failed with status code ${modelConfigRequest.status}. 
Please verify this URL points to the model JSON of the model to load.`)}let modelConfig;try{modelConfig=await modelConfigRequest.json()}catch(e){let message=`Failed to parse model JSON of response from ${this.path}.`;if(th
2020-11-16 21:51:46 +01:00
Actual: ${actualFlat}.
Expected: ${expectedFlat}.`)}for(let i=0;i<expectedFlat.length;++i){const a=actualFlat[i];const e=expectedFlat[i];if(!predicate(a,e)){throw new Error(`Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.
Actual: ${actualFlat}.
2020-11-17 16:18:15 +01:00
Expected: ${expectedFlat}.`)}}}function expectPromiseToFail(fn,done){fn().then(()=>done.fail(),()=>done())}function expectArraysEqual(actual,expected){const exp13=typeof expected==="string"||typeof expected==="number"||typeof expected==="boolean"?[expected]:expected;if(isString(actual)||isString(actual[0])||isString(expected)||isString(expected[0])){return expectArraysPredicate(actual,exp13,(a,b)=>a==b)}return expectArraysPredicate(actual,expected,(a,b)=>areClose(a,b,0))}function expectNumbersClose(a,e,epsilon3){if(epsilon3==null){epsilon3=testEpsilon()}if(!areClose(a,e,epsilon3)){throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`)}}function areClose(a,e,epsilon3){if(!isFinite(a)&&!isFinite(e)){return true}if(isNaN(a)||isNaN(e)||Math.abs(a-e)>epsilon3){return false}return true}function expectValuesInRange(actual,low,high){for(let i=0;i<actual.length;i++){if(actual[i]<low||actual[i]>high){throw new Error(`Value out of range:${actual[i]} low: ${low}, high: ${high}`)}}}function expectArrayBuffersEqual(actual,expected){expect(new Float32Array(actual)).toEqual(new Float32Array(expected))}const version="2.7.0";function enableProdMode(){env().set("PROD",true)}function enableDebugMode(){env().set("DEBUG",true)}function disableDeprecationWarnings(){env().set("DEPRECATION_WARNINGS_ENABLED",false);console.warn(`TensorFlow.js deprecation warnings have been disabled.`)}function deprecationWarn(msg){if(env().getBool("DEPRECATION_WARNINGS_ENABLED")){console.warn(msg+" You can disable deprecation warnings with tf.disableDeprecationWarnings().")}}setDeprecationWarningFn(deprecationWarn);function disposeVariables(){ENGINE.disposeVariables()}function engine15(){return ENGINE}function memory(){return ENGINE.memory()}function profile(f){return ENGINE.profile(f)}function tidy(nameOrFn,fn){return ENGINE.tidy(nameOrFn,fn)}function dispose(container2){const tensors=getTensorsInContainer(container2);tensors.forEach(tensor168=>tensor168.dispose())}function 
keep(result){return ENGINE.keep(result)}function time(f){return ENGINE.time(f)}function setBackend(backendName){return ENGINE.setBackend(backendName)}function ready(){return ENGINE.ready()}function getBackend(){return ENGINE.backendName}function removeBackend(name){ENGINE.removeBackend(name)}function findBackend(name){return ENGINE.findBackend(name)}function findBackendFactory(name){return ENGINE.findBackendFactory(name)}function registerBackend(name,factory,priority=1){return ENGINE.registerBackend(name,factory,priority)}function backend2(){return ENGINE.backend}function setPlatform(platformName,platform){env().setPlatform(platformName,platform)}function add_(a,b){let $a=convertToTensor(a,"a","add");let $b=convertToTensor(b,"b","add");[$a,$b]=makeTypesMatch($a,$b);const forward=(backend3,save)=>{const res=backend3.add($a,$b);save([$a,$b]);return res};const inputs={a:$a,b:$b};return ENGINE.runKernelFunc(forward,inputs,null,Add)}const add2=op({add_});function floorDiv_(a,b){let $a=convertToTensor(a,"a","floorDiv");let $b=convertToTensor(b,"b","floorDiv");[$a,$b]=makeTypesMatch($a,$b);const forward=(backend3,save)=>{const res=backend3.floorDiv($a,$b);save([$a,$b]);return res};const inputs={a:$a,b:$b};return ENGINE.runKernelFunc(forward,inputs,null,FloorDiv)}const floorDiv=op({floorDiv_});function div_(a,b){let $a=convertToTensor(a,"a","div");let $b=convertToTensor(b,"b","div");[$a,$b]=makeTypesMatch($a,$b);if($a.dtype==="int32"&&$b.dtype==="int32"){return floorDiv($a,$b)}const forward=(backend3,save)=>{const res=backend3.realDivide($a,$b);save([$a,$b]);return res};const inputs={a:$a,b:$b};const attrs={};return ENGINE.runKernelFunc(forward,inputs,null,Div,attrs)}const div=op({div_});function mul_(a,b){let $a=convertToTensor(a,"a","mul");let $b=convertToTensor(b,"b","mul");[$a,$b]=makeTypesMatch($a,$b);const forward=(backend3,save)=>{const res=backend3.multiply($a,$b);save([$a,$b]);return res};const inputs={a:$a,b:$b};return 
ENGINE.runKernelFunc(forward,inputs,null,Multiply)}const mul=op({mul_});function abs_(x){const $x=convertToTensor(x,"x","abs");const inputs={x:$x};return ENGI
with dtype ${tensor168.dtype}. `)}})}const forward=(backend3,save)=>{const $axis=parseAxisParam(axis,$tensors[0].shape)[0];const outShape=computeOutShape2($tensors.map(t=>t.shape),$axis);if(sizeFromShape(outShape)===0){return tensor4([],outShape)}$tensors=$tensors.filter(t=>t.size>0);if($tensors.length===1){return $tensors[0]}const shapes=$tensors.map(t=>t.shape);assertParamsConsistent(shapes,$axis);const res=backend3.concat($tensors,$axis);save($tensors);return res};const inputs=$tensors;const attr={axis};return ENGINE.runKernelFunc(forward,inputs,null,Concat,attr)}const concat=op({concat_});function sigmoid_(x){const $x=convertToTensor(x,"x","sigmoid");const inputs={x:$x};return ENGINE.runKernelFunc((backend3,save)=>{const res=backend3.sigmoid($x);save([res]);return res},inputs,null,Sigmoid)}const sigmoid=op({sigmoid_});function slice_(x,begin,size){const $x=convertToTensor(x,"x","slice");if($x.rank===0){throw new Error("Slicing scalar is not possible")}const forward=(backend3,save)=>{const[begin_,size_]=parseSliceParams($x,begin,size);assertParamsValid($x,begin_,size_);save([$x]);return backend3.slice($x,begin_,size_)};const inputs={x:$x};const attrs={begin,size};return ENGINE.runKernelFunc(forward,inputs,null,Slice,attrs)}const slice=op({slice_});function tanh_(x){const $x=convertToTensor(x,"x","tanh");const inputs={x:$x};return ENGINE.runKernelFunc((backend3,save)=>{const y=backend3.tanh($x);save([y]);return y},inputs,null,Tanh)}const tanh2=op({tanh_});function basicLSTMCell_(forgetBias,lstmKernel,lstmBias,data2,c,h){const $forgetBias=convertToTensor(forgetBias,"forgetBias","basicLSTMCell");const $lstmKernel=convertToTensor(lstmKernel,"lstmKernel","basicLSTMCell");const $lstmBias=convertToTensor(lstmBias,"lstmBias","basicLSTMCell");const $data=convertToTensor(data2,"data","basicLSTMCell");const $c=convertToTensor(c,"c","basicLSTMCell");const $h=convertToTensor(h,"h","basicLSTMCell");const combined=concat([$data,$h],1);const 
weighted=matMul(combined,$lstmKernel);const res=add2(weighted,$lstmBias);const batchSize=res.shape[0];const sliceCols=res.shape[1]/4;const sliceSize=[batchSize,sliceCols];const i=slice(res,[0,0],sliceSize);const j=slice(res,[0,sliceCols],sliceSize);const f=slice(res,[0,sliceCols*2],sliceSize);const o=slice(res,[0,sliceCols*3],sliceSize);const newC=add2(mul(sigmoid(i),tanh2(j)),mul($c,sigmoid(add2($forgetBias,f))));const newH=mul(tanh2(newC),sigmoid(o));return[newC,newH]}const basicLSTMCell=op({basicLSTMCell_});function batchToSpaceND_(x,blockShape,crops){const $x=convertToTensor(x,"x","batchToSpaceND");const prod5=blockShape.reduce((a,b)=>a*b);assert($x.rank>=1+blockShape.length,()=>`input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`);assert(crops.length===blockShape.length,()=>`crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);assert($x.shape[0]%prod5===0,()=>`input tensor batch is ${$x.shape[0]} but is not divisible by the product of the elements of blockShape ${blockShape.join(" * ")} === ${prod5}`);const forward=backend3=>{return backend3.batchToSpaceND($x,blockShape,crops)};const inputs={x:$x};const attrs={blockShape,crops};return ENGINE.runKernelFunc(forward,inputs,null,BatchToSpaceND,attrs)}const batchToSpaceND=op({batchToSpaceND_});function xAs4D(x){let x4D;if(x.rank===0||x.rank===1){x4D=reshape(x,[1,1,1,x.size])}else if(x.rank===2){x4D=reshape(x,[1,1,x.shape[0],x.shape[1]])}else if(x.rank===3){x4D=reshape(x,[1,x.shape[0],x.shape[1],x.shape[2]])}else{x4D=x}return x4D}function batchNorm_(x,mean7,variance,offset,scale2,varianceEpsilon){if(varianceEpsilon==null){varianceEpsilon=.001}const $x=convertToTensor(x,"x","batchNorm");const $mean=convertToTensor(mean7,"mean","batchNorm");const $variance=convertToTensor(variance,"variance","batchNorm");let $scale;if(scale2!=null){$scale=convertToTensor(scale2,"scale","batchNorm")}let 
$offset;if(offset!=null){$offset=convertToTensor(offset,"offset","batchNorm")}assert($mean.rank===$variance.rank,()=>"Batch normalization gradient requires mean and variance to ha
2020-11-16 21:51:46 +01:00
${inputHeight} and ${blockSize} for depthToSpace with input shape
2020-11-17 16:18:15 +01:00
${$x.shape}`);assert(inputWidth*blockSize>=0,()=>`Negative dimension size caused by overflow when multiplying
2020-11-16 21:51:46 +01:00
${inputWidth} and ${blockSize} for depthToSpace with input shape
2020-11-17 16:18:15 +01:00
${$x.shape}`);assert(inputDepth%(blockSize*blockSize)===0,()=>`Dimension size must be evenly divisible by ${blockSize*blockSize} but is ${inputDepth} for depthToSpace with input shape ${$x.shape}`);const forward=backend3=>backend3.depthToSpace($x,blockSize,dataFormat);const inputs={x:$x};const attrs={blockSize,dataFormat};return ENGINE.runKernelFunc(forward,inputs,null,DepthToSpace,attrs)}const depthToSpace=op({depthToSpace_});function depthwiseConv2d_(x,filter,strides,pad11,dataFormat="NHWC",dilations=[1,1],dimRoundingMode){const $x=convertToTensor(x,"x","depthwiseConv2d");const $filter=convertToTensor(filter,"filter","depthwiseConv2d");let x4D=$x;let reshapedTo4D=false;if($x.rank===3){reshapedTo4D=true;x4D=reshape($x,[1,$x.shape[0],$x.shape[1],$x.shape[2]])}assert(x4D.rank===4,()=>`Error in depthwiseConv2d: input must be rank 4, but got rank ${x4D.rank}.`);assert($filter.rank===4,()=>`Error in depthwiseConv2d: filter must be rank 4, but got rank ${$filter.rank}.`);assert(x4D.shape[3]===$filter.shape[2],()=>`Error in depthwiseConv2d: number of input channels (${x4D.shape[3]}) must match the inChannels dimension in filter ${$filter.shape[2]}.`);if(dimRoundingMode!=null){assert(isInt(pad11),()=>`Error in depthwiseConv2d: pad must be an integer when using, dimRoundingMode ${dimRoundingMode} but got pad ${pad11}.`)}const forward=(backend3,save)=>{if(dilations==null){dilations=[1,1]}assert(eitherStridesOrDilationsAreOne(strides,dilations),()=>`Error in depthwiseConv2d: Either strides or dilations must be 1. 
Got strides ${strides} and dilations '${dilations}'`);const convInfo=computeConv2DInfo(x4D.shape,$filter.shape,strides,dilations,pad11,dimRoundingMode,true);const res2=backend3.depthwiseConv2D(x4D,$filter,convInfo);save([x4D,$filter]);return res2};const inputs={x:x4D,filter:$filter};const attrs={strides,pad:pad11,dataFormat,dilations,dimRoundingMode};const res=ENGINE.runKernelFunc(forward,inputs,null,DepthwiseConv2dNative,attrs);if(reshapedTo4D){return reshape(res,[res.shape[1],res.shape[2],res.shape[3]])}return res}const depthwiseConv2d=op({depthwiseConv2d_});function diag_(x){const $x=convertToTensor(x,"x","diag");const forward=backend3=>{const flat=reshape($x,[$x.size]);const result=backend3.diag(flat);const outShape=[...x.shape,...x.shape];return reshape(result,outShape)};const inputs={x:$x};return ENGINE.runKernelFunc(forward,inputs,null,Diag)}const diag=op({diag_});function dilation2d_(x,filter,strides,pad11,dilations=[1,1],dataFormat="NHWC"){const $x=convertToTensor(x,"x","dilation2d");const $filter=convertToTensor(filter,"filter","dilation2d");assert($x.rank===3||$x.rank===4,()=>`Error in dilation2d: input must be rank 3 or 4, but got rank ${$x.rank}.`);assert($filter.rank===3,()=>`Error in dilation2d: filter must be rank 3, but got rank ${$filter.rank}.`);assert(dataFormat==="NHWC",()=>`Error in dilation2d: Only NHWC is currently supported, but got dataFormat of ${dataFormat}`);let x4D=$x;let reshapedTo4D=false;if($x.rank===3){x4D=reshape($x,[1,$x.shape[0],$x.shape[1],$x.shape[2]]);reshapedTo4D=true}const inputs={x:x4D,filter:$filter};const attrs={strides,pad:pad11,dilations};const res=ENGINE.runKernel(Dilation2D,inputs,attrs);if(reshapedTo4D){return reshape(res,[res.shape[1],res.shape[2],res.shape[3]])}return res}const dilation2d=op({dilation2d_});function getBroadcastDims(inShape,outShape){const inRank=inShape.length;const dims=[];for(let i=0;i<inRank;i++){const dim=inRank-1-i;const a=inShape[dim]||1;const 
b=outShape[outShape.length-1-i]||1;if(b>1&&a===1){dims.unshift(dim)}}return dims}function getReductionAxes(inShape,outShape){const result=[];for(let i=0;i<outShape.length;i++){const inDim=inShape[inShape.length-i-1];const outAxis=outShape.length-i-1;const outDim=outShape[outAxis];if(inDim==null||inDim===1&&outDim>1){result.unshift(outAxis)}}return result}function assertAndGetBroadcastShape(shapeA,shapeB){const result=[];const l=Math.max(shapeA.length,shapeB.length);for(let i=0;i<l;i++){let a=shapeA[shapeA.length-i-1];if(a==null){a=1}let b=shapeB[shapeB.length-i-1];if(b==null){b=1}if(a===1){result.uns
rank ${$x.rank}.`);assert(isInt(depthRadius),()=>`Error in localResponseNormalization: depthRadius must be an integer but got depthRadius ${depthRadius}.`);let x4D=$x;let reshapedTo4D=false;if($x.rank===3){reshapedTo4D=true;x4D=reshape($x,[1,$x.shape[0],$x.shape[1],$x.shape[2]])}const forward=(backend3,save)=>{const y=backend3.localResponseNormalization4D(x4D,depthRadius,bias,alpha,beta);save([x4D,y]);return y};const inputs={x:x4D};const attrs={depthRadius,bias,alpha,beta};const res=ENGINE.runKernelFunc(forward,inputs,null,LRN,attrs);if(reshapedTo4D){return reshape(res,[res.shape[1],res.shape[2],res.shape[3]])}else{return res}}const localResponseNormalization=op({localResponseNormalization_});function log_(x){const $x=convertToTensor(x,"x","log");const inputs={x:$x};return ENGINE.runKernelFunc((backend3,save)=>{const res=backend3.log($x);save([$x]);return res},inputs,null,Log)}const log=op({log_});function log1p_(x){const $x=convertToTensor(x,"x","log1p");const inputs={x:$x};return ENGINE.runKernelFunc((backend3,save)=>{const res=backend3.log1p($x);save([$x]);return res},inputs,null,Log1p)}const log1p=op({log1p_});function grad(f){assert(isFunction(f),()=>"The f passed in grad(f) must be a function");return(x,dy)=>{const $x=convertToTensor(x,"x","tf.grad",null);const $dy=dy!=null?convertToTensor(dy,"dy","tf.grad"):null;return ENGINE.tidy(()=>{const{value,grads:grads2}=ENGINE.gradients(()=>f($x),[$x],$dy);if($dy!=null){assertShapesMatch(value.shape,$dy.shape,"The shape of dy passed in grad(f)(x, dy) must match the shape returned by f(x)")}checkGrads(grads2);return grads2[0]})}}function grads(f){assert(isFunction(f),()=>"The f passed in grads(f) must be a function");return(args,dy)=>{assert(Array.isArray(args),()=>"The args passed in grads(f)(args) must be an array of `Tensor`s or `TensorLike`s");const $args=convertToTensorArray(args,"args","tf.grads",null);const $dy=dy!=null?convertToTensor(dy,"dy","tf.grads"):null;return 
ENGINE.tidy(()=>{const{value,grads:grads2}=ENGINE.gradients(()=>f(...$args),$args,$dy);if($dy!=null){assertShapesMatch(value.shape,$dy.shape,"The shape of dy passed in grads(f)([x1,...], dy) must match the shape returned by f([x1,...])")}checkGrads(grads2);return grads2})}}function valueAndGrad(f){assert(isFunction(f),()=>"The f passed in valueAndGrad(f) must be a function");return(x,dy)=>{assert(x instanceof Tensor,()=>"The x passed in valueAndGrad(f)(x) must be a tensor");assert(dy==null||dy instanceof Tensor,()=>"The dy passed in valueAndGrad(f)(x, dy) must be a tensor");const{grads:grads2,value}=ENGINE.gradients(()=>f(x),[x],dy);checkGrads(grads2);return{grad:grads2[0],value}}}function valueAndGrads(f){assert(isFunction(f),()=>"The f passed in valueAndGrads(f) must be a function");return(args,dy)=>{assert(Array.isArray(args)&&args.every(arg=>arg instanceof Tensor),()=>"The args passed in valueAndGrads(f)(args) must be array of tensors");assert(dy==null||dy instanceof Tensor,()=>"The dy passed in valueAndGrads(f)(args, dy) must be a tensor");const res=ENGINE.gradients(()=>f(...args),args,dy);if(dy!=null){assertShapesMatch(res.value.shape,dy.shape,"The shape of dy passed in valueAndGrads(f)([x1,...], dy) must match the shape returned by f([x1,...])")}checkGrads(res.grads);return res}}function variableGrads(f,varList){assert(isFunction(f),()=>"The f passed in variableGrads(f) must be a function");assert(varList==null||Array.isArray(varList)&&varList.every(v=>v instanceof Variable),()=>"The varList passed in variableGrads(f, varList) must be an array of variables");const specifiedVarList=varList!=null;if(!specifiedVarList){varList=[];for(const varName in ENGINE.registeredVariables){varList.push(ENGINE.registeredVariables[varName])}}const specifiedNonTrainable=specifiedVarList?varList.filter(variable3=>!variable3.trainable):null;const originalVarCount=varList.length;varList=varList.filter(variable3=>variable3.trainable);assert(varList.length>0,()=>`variableGrads() 
expects at least one of the input variables to be trainable, but none of the ${originalVarCount} variables is trainable.`);cons
the f you passed encloses all operations that lead from x to y.`)}}function neg_(x){const $x=convertToTensor(x,"x","neg");const inputs={x:$x};return ENGINE.runKernelFunc(backend3=>backend3.neg($x),inputs,null,Negate)}const neg=op({neg_});function softplus_(x){const $x=convertToTensor(x,"x","softplus");const inputs={x:$x};return ENGINE.runKernelFunc((backend3,save)=>{const res=backend3.softplus($x);save([$x]);return res},inputs,null,Softplus)}const softplus=op({softplus_});function logSigmoid_(x){const $x=convertToTensor(x,"x","logSigmoid");const customOp=customGrad(x2=>{const value=neg(softplus(neg(x2)));const gradFunc=dy=>{const derX=mul(dy,sigmoid(neg(x2)));return derX};return{value,gradFunc}});return customOp($x)}const logSigmoid=op({logSigmoid_});function max_(x,axis=null,keepDims=false){const $x=convertToTensor(x,"x","max");const forward=(backend3,save)=>{const origAxes=parseAxisParam(axis,$x.shape);let axes=origAxes;const permutedAxes=getAxesPermutation(axes,$x.rank);let maxInput=$x;if(permutedAxes!=null){maxInput=transpose($x,permutedAxes);axes=getInnerMostAxes(axes.length,maxInput.rank)}const y=backend3.max(maxInput,axes);if(permutedAxes!=null){maxInput.dispose()}let res=y;if(keepDims){const expandedShape=expandShapeToKeepDim(res.shape,parseAxisParam(axis,$x.shape));res=reshape(res,expandedShape);y.dispose()}save([$x,res]);return res};const inputs={x:$x};const attrs={reductionIndices:axis,keepDims};return ENGINE.runKernelFunc(forward,inputs,null,Max,attrs)}const max=op({max_});function sub_(a,b){let $a=convertToTensor(a,"a","sub");let $b=convertToTensor(b,"b","sub");[$a,$b]=makeTypesMatch($a,$b);const forward=(backend3,save)=>{const res=backend3.subtract($a,$b);save([$a,$b]);return res};const inputs={a:$a,b:$b};return ENGINE.runKernelFunc(forward,inputs,null,Sub)}const sub=op({sub_});function sum_(x,axis=null,keepDims=false){let $x=convertToTensor(x,"x","sum");if($x.dtype==="bool"){$x=cast($x,"int32")}const forward=(backend3,save)=>{save([$x]);const 
axes=parseAxisParam(axis,$x.shape);const permutation=getAxesPermutation(axes,$x.rank);let reductionAxes=axes;let permutedX=$x;if(permutation!=null){permutedX=transpose($x,permutation);reductionAxes=getInnerMostAxes(reductionAxes.length,$x.rank)}let value=backend3.sum(permutedX,reductionAxes);if(keepDims){const newShape=expandShapeToKeepDim(value.shape,axes);value=reshape(value,newShape)}return value};const inputs={x:$x};const attrs={axis,keepDims};return ENGINE.runKernelFunc(forward,inputs,null,Sum,attrs)}const sum2=op({sum_});function logSoftmax_(logits,axis=-1){const $logits=convertToTensor(logits,"logits","logSoftmax");if(axis===-1){axis=$logits.rank-1}if(axis!==$logits.rank-1){throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${$logits.rank} and axis was ${axis}`)}const forward=(backend3,save)=>{const keepDims=true;const xMax=max(logits,axis,true);const shifted=sub(logits,xMax);const value=sub(cast(shifted,"float32"),log(sum2(exp(shifted),axis,keepDims)));save([value]);return value};const inputs={logits:$logits};const attrs={axis};return ENGINE.runKernelFunc(forward,inputs,null,LogSoftmax,attrs)}const logSoftmax=op({logSoftmax_});function logSumExp_(x,axis=null,keepDims=false){const $x=convertToTensor(x,"x","logSumExp");const axes=parseAxisParam(axis,$x.shape);const xMax=max($x,axes,true);const a=sub($x,xMax);const b=exp(a);const c=sum2(b,axes);const d=log(c);const res=add2(reshape(xMax,d.shape),d);if(keepDims){const newShape=expandShapeToKeepDim(res.shape,axes);return reshape(res,newShape)}return res}const logSumExp=op({logSumExp_});function logicalAnd_(a,b){const $a=convertToTensor(a,"a","logicalAnd","bool");const $b=convertToTensor(b,"b","logicalAnd","bool");assertAndGetBroadcastShape($a.shape,$b.shape);const inputs={a:$a,b:$b};return ENGINE.runKernelFunc(backend3=>backend3.logicalAnd($a,$b),inputs,null,LogicalAnd)}const logicalAnd=op({logicalAnd_});function logicalNot_(x){const 
$x=convertToTensor(x,"x","logicalNot","bool");const inputs={x:$x};return ENGINE.runKernelFunc(backend3=>backend3.logicalNot($x),inputs,null,
2020-11-16 21:51:46 +01:00
1. The ${printableModuleName} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
2. The custom ${printableModuleName} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`)}}return fn}else{const config2=identifier;if(config2["className"]==null||config2["config"]==null){throw new ValueError(`${printableModuleName}: Improper config format: ${JSON.stringify(config2)}.
'className' and 'config' must set.`)}const className=config2["className"];let cls,fromConfig;if(className in customObjects){[cls,fromConfig]=customObjects[className]}else if(className in _GLOBAL_CUSTOM_OBJECTS){[cls,fromConfig]=_GLOBAL_CUSTOM_OBJECTS["className"]}else if(className in moduleObjects){[cls,fromConfig]=moduleObjects[className]}if(cls==null){throw new ValueError(`Unknown ${printableModuleName}: ${className}. This may be due to one of the following reasons:
1. The ${printableModuleName} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
2020-11-17 16:18:15 +01:00
2. The custom ${printableModuleName} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`)}if(fromConfig!=null){const customObjectsCombined={};for(const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)){customObjectsCombined[key]=_GLOBAL_CUSTOM_OBJECTS[key]}for(const key of Object.keys(customObjects)){customObjectsCombined[key]=customObjects[key]}const nestedConfig=config2["config"];nestedConfig["customObjects"]=customObjectsCombined;const backupCustomObjects=Object.assign({},_GLOBAL_CUSTOM_OBJECTS);for(const key of Object.keys(customObjects)){_GLOBAL_CUSTOM_OBJECTS[key]=customObjects[key]}convertNDArrayScalarsInConfig(config2["config"]);const returnObj=fromConfig(cls,config2["config"],customObjects,fastWeightInit);_GLOBAL_CUSTOM_OBJECTS=Object.assign({},backupCustomObjects);return returnObj}else{const backupCustomObjects=Object.assign({},_GLOBAL_CUSTOM_OBJECTS);for(const key of Object.keys(customObjects)){_GLOBAL_CUSTOM_OBJECTS[key]=customObjects[key]}const returnObj=new cls(config2["config"]);_GLOBAL_CUSTOM_OBJECTS=Object.assign({},backupCustomObjects);return returnObj}}}function numberCompare(a,b){return a<b?-1:a>b?1:0}function reverseNumberCompare(a,b){return-1*numberCompare(a,b)}function unique5(xs){if(xs==null){return xs}const out=[];for(const x of xs){if(out.indexOf(x)===-1){out.push(x)}}return out}function isObjectEmpty(obj){if(obj==null){throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`)}for(const key in obj){if(obj.hasOwnProperty(key)){return false}}return true}function checkStringTypeUnionValue(values,label,value){if(value==null){return}if(values.indexOf(value)<0){throw new ValueError(`${value} is not a valid ${label}. 
Valid values are ${values} or null/undefined.`)}}function checkArrayTypeAndLength(x,expectedType,minLength=0,maxLength=Infinity){assert2(minLength>=0);assert2(maxLength>=minLength);return Array.isArray(x)&&x.length>=minLength&&x.length<=maxLength&&x.every(e=>typeof e===expectedType)}function assertPositiveInteger(value,name){if(Array.isArray(value)){util_exports.assert(value.length>0,()=>`${name} is unexpectedly an empty array.`);value.forEach((v,i)=>assertPositiveInteger(v,`element ${i+1} of ${name}`))}else{util_exports.assert(Number.isInteger(value)&&value>0,()=>`Expected ${name} to be a positive integer, but got ${formatAsFriendlyString(value)}.`)}}function formatAsFriendlyString(value){if(value===null){return"null"}else if(Array.isArray(value)){return"["+value.map(v=>formatAsFriendlyString(v)).join(",")+"]"}else if(typeof value==="string"){return`"${value}"`}else{return`${value}`}}function debounce(f,waitMs){let lastTime=util_exports.now();let lastResult;const f2=(...args)=>{const now3=util_exports.now();if(now3-lastTime<waitMs){return lastResult}lastTime=now3;lastResult=f(...args);return lastResult};return f2}function mapActivationToFusedKernel(activationName){if(activationName==="relu"){return"relu"}if(activationName==="linear"){return"linear"}if(activationName==="elu"){return"elu"}return null}function calcL2Norms(w,axis){return tidy(()=>sqrt(sum2(mul(w,w),axis,true)))}class Constraint extends serialization_exports.Serializable{getConfig(){return{}}}class MaxNorm extends Constraint{constructor(args){super();this.defaultMaxValue=2;this.defaultAxis=0;this.maxValue=args.maxValue!=null?args.maxValue:this.defaultMaxValue;this.axis=args.axis!=null?args.axis:this.defaultAxis}apply(w){return tidy(()=>{const norms=calcL2Norms(w,this.axis);const desired=clipByValue(norms,0,this.maxValue);return 
mul(w,div(desired,add2(epsilon(),norms)))})}getConfig(){return{maxValue:this.maxValue,axis:this.axis}}}MaxNorm.className="MaxNorm";serialization_exports.registerClass(MaxNorm);class UnitNorm extends Constraint{constructor(args){super();this.defaultAxis=0;this.axis=args.axis!=null?args.axis:this.defaultAxis}apply(w){return tidy(()=>div(w,add2(epsilon(),calcL2Norms(w,this.axis))))}getConfig(){return{axis:this.axis}}}UnitNorm.className="UnitNorm";serialization_exports.registerClass(UnitNorm);class NonNeg extends Constraint{apply(w){return relu(w)}}NonNeg.
because the value dtype is ${tensor168.dtype}, but TensorArray dtype is ${this.dtype}.`)}if(this.size()===0&&(this.elementShape==null||this.elementShape.length===0)){this.elementShape=tensor168.shape}assertShapesMatchAllowUndefinedSize(this.elementShape,tensor168.shape,`TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);if(t.read){throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`)}if(t.written){throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`)}t.tensor=tensor168;keep(tensor168);t.written=true;this.tensors[index]=t}writeMany(indices,tensors){if(indices.length!==tensors.length){throw new Error(`TensorArray ${this.name}: could not write multiple tensors,because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`)}indices.forEach((i,index)=>this.write(i,tensors[index]))}gather(indices,dtype){if(!!dtype&&dtype!==this.dtype){throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`)}if(!indices){indices=[];for(let i=0;i<this.size();i++){indices.push(i)}}else{indices=indices.slice(0,this.size())}if(indices.length===0){return tensor4([],[0].concat(this.elementShape))}const tensors=this.readMany(indices);assertShapesMatchAllowUndefinedSize(this.elementShape,tensors[0].shape,"TensorArray shape mismatch: ");return stack(tensors,0)}concat(dtype){if(!!dtype&&dtype!==this.dtype){throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`)}if(this.size()===0){return tensor4([],[0].concat(this.elementShape))}const indices=[];for(let i=0;i<this.size();i++){indices.push(i)}const tensors=this.readMany(indices);assertShapesMatchAllowUndefinedSize(this.elementShape,tensors[0].shape,`TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);return 
concat(tensors,0)}scatter(indices,tensor168){if(tensor168.dtype!==this.dtype){throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor168.dtype}`)}if(indices.length!==tensor168.shape[0]){throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor168.shape[0]}`)}const maxIndex=Math.max(...indices);if(!this.dynamicSize&&maxIndex>=this.maxSize){throw new Error(`Max index must be < array size (${maxIndex} vs. ${this.maxSize})`)}this.writeMany(indices,unstack(tensor168,0))}split(length,tensor168){if(tensor168.dtype!==this.dtype){throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor168.dtype}`)}let totalLength=0;const cumulativeLengths=length.map(len=>{totalLength+=len;return totalLength});if(totalLength!==tensor168.shape[0]){throw new Error(`Expected sum of lengths to be equal to
2020-11-16 21:51:46 +01:00
tensor.shape[0], but sum of lengths is
2020-11-17 16:18:15 +01:00
${totalLength}, and tensor's shape is: ${tensor168.shape}`)}if(!this.dynamicSize&&length.length!==this.maxSize){throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), and the TensorArray is not marked as dynamically resizeable`)}const elementPerRow=totalLength===0?0:tensor168.size/totalLength;const tensors=[];tidy(()=>{tensor168=reshape(tensor168,[1,totalLength,elementPerRow]);for(let i=0;i<length.length;++i){const previousLength=i===0?0:cumulativeLengths[i-1];const indices2=[0,previousLength,0];const sizes=[1,length[i],elementPerRow];tensors[i]=reshape(slice(tensor168,indices2,sizes),this.elementShape)}return tensors});const indices=[];for(let i=0;i<length.length;i++){indices[i]=i}this.writeMany(indices,tensors)}}class TensorList{constructor(tensors,elementShape,elementDtype,maxNumElements=-1){this.tensors=tensors;this.elementShape=elementShape;this.elementDtype=elementDtype;if(tensors!=null){tensors.forEach(tensor168=>{if(elementDtype!==tensor168.dtype){throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor168.dtype}`)}assertShapesMatchAllowUndefinedSize(elementShape,tensor168.shape,"TensorList shape mismatch: ");keep(tensor168)})}this.idTensor=scalar(0);this.maxNumElements=maxNumElements;keep(this.idTensor)}get id(){return this.idTensor.id}copy(){return new TensorList([...this.tensors],this.elementShape,this.elementDtype)}clearAndClose(keepIds){this.tensors.forEach(tensor168=>{if(keepIds==null||!keepIds.has(tensor168.id)){tensor168.dispose()}});this.tensors.length=0;this.idTensor.dispose()}size(){return this.tensors.length}stack(elementShape,elementDtype,numElements=-1){if(elementDtype!==this.elementDtype){throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`)}if(numElements!==-1&&this.tensors.length!==numElements){throw new Error(`Operation expected a list with ${numElements} elements but got a list with 
${this.tensors.length} elements.`)}assertShapesMatchAllowUndefinedSize(elementShape,this.elementShape,"TensorList shape mismatch: ");return tidy(()=>{const reshapedTensors=this.tensors.map(tensor168=>reshape(tensor168,elementShape));return stack(reshapedTensors,0)})}popBack(elementShape,elementDtype){if(elementDtype!==this.elementDtype){throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`)}if(this.size()===0){throw new Error("Trying to pop from an empty list.")}const tensor168=this.tensors.pop();assertShapesMatchAllowUndefinedSize(tensor168.shape,elementShape,"TensorList shape mismatch: ");return reshape(tensor168,elementShape)}pushBack(tensor168){if(tensor168.dtype!==this.elementDtype){throw new Error(`Invalid data types; op elements ${tensor168.dtype}, but list elements ${this.elementDtype}`)}assertShapesMatchAllowUndefinedSize(tensor168.shape,this.elementShape,"TensorList shape mismatch: ");if(this.maxNumElements===this.size()){throw new Error(`Trying to push element into a full list.`)}keep(tensor168);this.tensors.push(tensor168)}resize(size){if(size<0){throw new Error(`TensorListResize expects size to be non-negative. 
Got: ${size}`)}if(this.maxNumElements!==-1&&size>this.maxNumElements){throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`)}this.tensors.length=size}getItem(elementIndex,elementShape,elementDtype){if(elementDtype!==this.elementDtype){throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`)}if(elementIndex<0||elementIndex>this.tensors.length){throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`)}if(this.tensors[elementIndex]==null){throw new Error(`element at index ${elementIndex} is null.`)}assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape,elementShape,"TensorList shape mismatch: ");return this.tensors[elementIndex]}setItem(elementIndex,tensor168){if(tensor168.dtype!==this.elementDtype){throw new Error(`Invalid data types;
2020-11-16 21:51:46 +01:00
tensor.shape[0], but sum of lengths is
2020-11-17 16:18:15 +01:00
${totalLength}, and tensor's shape is: ${tensor168.shape}`)}const elementPerRow=totalLength===0?0:tensor168.size/totalLength;const tensors=tidy(()=>{const tensors2=[];tensor168=reshape(tensor168,[1,totalLength,elementPerRow]);for(let i=0;i<length.length;++i){const previousLength=i===0?0:cumulativeLengths[i-1];const indices=[0,previousLength,0];const sizes=[1,length[i],elementPerRow];tensors2[i]=reshape(slice(tensor168,indices,sizes),elementShape)}tensor168.dispose();return tensors2});const list=new TensorList([],elementShape,tensor168.dtype,length.length);for(let i=0;i<tensors.length;i++){list.setItem(i,tensors[i])}return list}const executeOp3=async(node,tensorMap,context)=>{switch(node.op){case"If":case"StatelessIf":{const thenFunc=getParamValue("thenBranch",node,tensorMap,context);const elseFunc=getParamValue("elseBranch",node,tensorMap,context);const cond=getParamValue("cond",node,tensorMap,context);const args=getParamValue("args",node,tensorMap,context);const condValue=await cond.data();if(condValue[0]){return context.functionMap[thenFunc].executeFunctionAsync(args,context.tensorArrayMap,context.tensorListMap)}else{return context.functionMap[elseFunc].executeFunctionAsync(args,context.tensorArrayMap,context.tensorListMap)}}case"While":case"StatelessWhile":{const bodyFunc=getParamValue("body",node,tensorMap,context);const condFunc=getParamValue("cond",node,tensorMap,context);const args=getParamValue("args",node,tensorMap,context);const condResult=await context.functionMap[condFunc].executeFunctionAsync(args,context.tensorArrayMap,context.tensorListMap);const argIds=args.map(tensor168=>tensor168.id);let condValue=await condResult[0].data();condResult.forEach(tensor168=>{if(!tensor168.kept&&argIds.indexOf(tensor168.id)===-1){tensor168.dispose()}});let result=args;while(condValue[0]){const origResult=result;result=await context.functionMap[bodyFunc].executeFunctionAsync(result,context.tensorArrayMap,context.tensorListMap);const 
resultIds=result.map(tensor168=>tensor168.id);origResult.forEach(tensor168=>{if(!tensor168.kept&&argIds.indexOf(tensor168.id)===-1&&resultIds.indexOf(tensor168.id)===-1){tensor168.dispose()}});const condResult2=await context.functionMap[condFunc].executeFunctionAsync(result,context.tensorArrayMap,context.tensorListMap);condValue=await condResult2[0].data();condResult2.forEach(tensor168=>{if(!tensor168.kept&&argIds.indexOf(tensor168.id)===-1&&resultIds.indexOf(tensor168.id)===-1){tensor168.dispose()}})}return result}case"LoopCond":{const pred=getParamValue("pred",node,tensorMap,context);return[cloneTensor(pred)]}case"Switch":{const pred=getParamValue("pred",node,tensorMap,context);let data2=getParamValue("data",node,tensorMap,context);if(!data2.kept){data2=cloneTensor(data2)}return(await pred.data())[0]?[void 0,data2]:[data2,void 0]}case"Merge":{const inputName=node.inputNames.find(name=>getTensor(name,tensorMap,context)!==void 0);if(inputName){const data2=getTensor(inputName,tensorMap,context);return[cloneTensor(data2)]}return void 0}case"Enter":{const frameId=getParamValue("frameName",node,tensorMap,context);const data2=getParamValue("tensor",node,tensorMap,context);context.enterFrame(frameId);return[cloneTensor(data2)]}case"Exit":{const data2=getParamValue("tensor",node,tensorMap,context);context.exitFrame();return[cloneTensor(data2)]}case"NextIteration":{const data2=getParamValue("tensor",node,tensorMap,context);context.nextIteration();return[cloneTensor(data2)]}case"TensorArrayV3":{const size=getParamValue("size",node,tensorMap,context);const dtype=getParamValue("dtype",node,tensorMap,context);const elementShape=getParamValue("elementShape",node,tensorMap,context);const dynamicSize=getParamValue("dynamicSize",node,tensorMap,context);const clearAfterRead=getParamValue("clearAfterRead",node,tensorMap,context);const identicalElementShapes=getParamValue("identicalElementShapes",node,tensorMap,context);const name=getParamValue("name",node,tensorMap,context);const 
tensorArray=new TensorArray(name,dtype,size,elementShape,identicalElementShapes,dynamicSize,clearAfterRead);context.addTensorArr
${batchSize}`);let size;if(this.size===Infinity||this.size==null){size=this.size}else if(smallLastBatch){size=Math.ceil(this.size/batchSize)}else{size=Math.floor(this.size/batchSize)}return datasetFromIteratorFn(async()=>{return(await base2.iterator()).columnMajorBatch(batchSize,smallLastBatch,deepBatchConcat)},size)}concatenate(dataset5){const base2=this;let size;if(this.size===Infinity||dataset5.size===Infinity){size=Infinity}else if(this.size!=null&&dataset5.size!=null){size=this.size+dataset5.size}else{size=null}return datasetFromIteratorFn(async()=>(await base2.iterator()).concatenate(await dataset5.iterator()),size)}filter(predicate){const base2=this;let size;if(this.size===Infinity){size=Infinity}else{size=null}return datasetFromIteratorFn(async()=>{return(await base2.iterator()).filter(x=>tidy(()=>predicate(x)))},size)}async forEachAsync(f){return(await this.iterator()).forEachAsync(f)}map(transform){const base2=this;return datasetFromIteratorFn(async()=>{return(await base2.iterator()).map(x=>tidy(()=>transform(x)))},this.size)}mapAsync(transform){const base2=this;return datasetFromIteratorFn(async()=>{return(await base2.iterator()).mapAsync(transform)},this.size)}prefetch(bufferSize){if(bufferSize==null){throw new RangeError("`Dataset.prefetch()` requires bufferSize to be specified.")}const base2=this;return datasetFromIteratorFn(async()=>(await base2.iterator()).prefetch(bufferSize),this.size)}repeat(count2){const base2=this;let size;if(this.size!=null&&count2>0){size=this.size*count2}else if(count2===0){size=0}else if(this.size!=null&&(count2===void 0||count2<0)){size=Infinity}else{size=null}return datasetFromIteratorFn(async()=>{const iteratorIterator=iteratorFromFunction(async()=>({value:await base2.iterator(),done:false}));return iteratorFromConcatenated(iteratorIterator.take(count2))},size)}skip(count2){const base2=this;let size;if(this.size!=null&&count2>=0&&this.size>=count2){size=this.size-count2}else 
if(this.size!=null&&(this.size<count2||count2===void 0||count2<0)){size=0}else{size=null}return datasetFromIteratorFn(async()=>(await base2.iterator()).skip(count2),size)}shuffle(bufferSize,seed,reshuffleEachIteration=true){if(bufferSize==null||bufferSize<0){if(this.size==null){throw new RangeError("`Dataset.shuffle()` requires bufferSize to be specified.")}else{throw new RangeError(`\`Dataset.shuffle()\` requires bufferSize to be specified. If your data fits in main memory (for regular JS objects), and/or GPU memory (for \`tf.Tensor\`s), consider setting bufferSize to the dataset size (${this.size} elements)`)}}const base2=this;const random=seedrandom3.alea(seed||util_exports.now().toString());return datasetFromIteratorFn(async()=>{let seed2=random.int32();if(reshuffleEachIteration){seed2+=random.int32()}return(await base2.iterator()).shuffle(bufferSize,seed2.toString())},this.size)}take(count2){const base2=this;let size;if(this.size!=null&&this.size>count2){size=count2}else if(this.size!=null&&this.size<=count2){size=this.size}else{size=null}return datasetFromIteratorFn(async()=>(await base2.iterator()).take(count2),size)}async toArray(){if(this.size===Infinity){throw new Error("Can not convert infinite data stream to array.")}return(await this.iterator()).toArray()}async toArrayForTest(){if(this.size===Infinity){throw new Error("Can not convert infinite data stream to array.")}return(await this.iterator()).toArrayForTest()}}Dataset.MAX_BUFFER_SIZE=1e4;function datasetFromIteratorFn(iteratorFn,size=null){return new class extends Dataset{constructor(){super(...arguments);this.size=size}async iterator(){return iteratorFn()}}}function array(items){return datasetFromIteratorFn(async()=>iteratorFromItems(items),items.length)}function zip(datasets){if(!isIterable2(datasets)){throw new Error("The argument to zip() must be an object or array.")}let size;if(Array.isArray(datasets)){for(let 
i=0;i<datasets.length;i++){size=size==null?datasets[i].size:Math.min(size,datasets[i].size)}}else if(datasets instanceof Object){for(const ds in datasets){size=size==null?datasets[ds].size:Math.min(size,datasets[ds].siz
2020-11-16 21:51:46 +01:00
void main() {
${snippets.join("\n ")}
2020-11-17 16:18:15 +01:00
float result = ${operation211};
2020-11-16 21:51:46 +01:00
setOutput(result);
}
2020-11-17 16:18:15 +01:00
`}}class AddNPackedProgram{constructor(outputShape,shapes){this.outputShape=[];this.packedInputs=true;this.packedOutput=true;this.outputShape=outputShape;this.variableNames=shapes.map((_,i)=>`T${i}`);const snippets=[];this.variableNames.forEach(variable3=>{snippets.push(`vec4 v${variable3} = get${variable3}AtOutCoords();`)});const operation211=this.variableNames.map(variable3=>{return`v${variable3}`}).join(" + ");this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
${snippets.join("\n ")}
2020-11-17 16:18:15 +01:00
vec4 result = ${operation211};
2020-11-16 21:51:46 +01:00
setOutput(result);
}
2020-11-17 16:18:15 +01:00
`}}class ArgMinMaxProgram{constructor(reduceInfo,op2,firstPass){this.variableNames=["A"];const{windowSize,batchSize,outSize}=reduceInfo;if(!firstPass){this.variableNames.push("bestIndicesA")}this.outputShape=[batchSize,outSize];const compOp=op2==="max"?">":"<";const indexSnippet=firstPass?"inOffset + i;":"round(getBestIndicesA(batch, inOffset + i));";this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
int bestIndex = inOffset;
float bestValue = getA(batch, bestIndex);
for (int i = 0; i < ${windowSize}; i++) {
int inIdx = ${indexSnippet};
float candidate = getA(batch, inIdx);
if (candidate ${compOp} bestValue) {
bestValue = candidate;
bestIndex = inIdx;
}
}
setOutput(float(bestIndex));
}
2020-11-17 16:18:15 +01:00
`}}function getVecChannels(name,rank){return["x","y","z","w","u","v"].slice(0,rank).map(d=>`${name}.${d}`)}function getChannels(name,rank){if(rank===1){return[name]}return getVecChannels(name,rank)}function getSourceCoords(rank,dims){if(rank===1){return"rc"}let coords2="";for(let i=0;i<rank;i++){coords2+=dims[i];if(i<rank-1){coords2+=","}}return coords2}function getGlslDifferences(){let version20;let attribute;let varyingVs;let varyingFs;let texture2D;let output;let defineOutput;let defineSpecialNaN;let defineSpecialInf;let defineRound;if(env().getNumber("WEBGL_VERSION")===2){version20="#version 300 es";attribute="in";varyingVs="out";varyingFs="in";texture2D="texture";output="outputColor";defineOutput="out vec4 outputColor;";defineSpecialNaN=`
2020-11-16 21:51:46 +01:00
bool isnan_custom(float val) {
return (val > 0.0 || val < 0.0) ? false : val != 0.0;
}
bvec4 isnan_custom(vec4 val) {
return bvec4(isnan_custom(val.x),
isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
}
#define isnan(value) isnan_custom(value)
`;defineSpecialInf=``;defineRound=`
#define round(value) newRound(value)
int newRound(float value) {
return int(floor(value + 0.5));
}
ivec4 newRound(vec4 value) {
return ivec4(floor(value + vec4(0.5)));
}
2020-11-17 16:18:15 +01:00
`}else{version20="";attribute="attribute";varyingVs="varying";varyingFs="varying";texture2D="texture2D";output="gl_FragColor";defineOutput="";defineSpecialNaN=`
2020-11-16 21:51:46 +01:00
#define isnan(value) isnan_custom(value)
bool isnan_custom(float val) {
return (val > 0. || val < 1. || val == 0.) ? false : true;
}
bvec4 isnan_custom(vec4 val) {
return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
}
`;defineSpecialInf=`
uniform float INFINITY;
bool isinf(float val) {
return abs(val) == INFINITY;
}
bvec4 isinf(vec4 val) {
return equal(abs(val), vec4(INFINITY));
}
`;defineRound=`
int round(float value) {
return int(floor(value + 0.5));
}
ivec4 round(vec4 value) {
return ivec4(floor(value + vec4(0.5)));
}
2020-11-17 16:18:15 +01:00
`}return{version:version20,attribute,varyingVs,varyingFs,texture2D,output,defineOutput,defineSpecialNaN,defineSpecialInf,defineRound}}function getLogicalCoordinatesFromFlatIndex(coords2,shape,index="index"){const strides=util_exports.computeStrides(shape);return strides.map((stride,i)=>{const line1=`int ${coords2[i]} = ${index} / ${stride}`;const line2=i===strides.length-1?`int ${coords2[i+1]} = ${index} - ${coords2[i]} * ${stride}`:`index -= ${coords2[i]} * ${stride}`;return`${line1}; ${line2};`}).join("")}function getFlatIndexFrom3D(shape){const strides=util_exports.computeStrides(shape).map(d=>d.toString());return`
2020-11-16 21:51:46 +01:00
int getFlatIndex(ivec3 coords) {
return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}const ENCODE_FLOAT_SNIPPET=`
const float FLOAT_MAX = 1.70141184e38;
const float FLOAT_MIN = 1.17549435e-38;
2020-11-08 18:32:31 +01:00
2020-11-16 21:51:46 +01:00
lowp vec4 encode_float(highp float v) {
if (isnan(v)) {
return vec4(255, 255, 255, 255);
2020-11-10 15:54:07 +01:00
}
2020-11-08 18:32:31 +01:00
2020-11-16 21:51:46 +01:00
highp float av = abs(v);
if(av < FLOAT_MIN) {
return vec4(0.0, 0.0, 0.0, 0.0);
} else if(v > FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
} else if(v < -FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
highp vec4 c = vec4(0,0,0,0);
highp float e = floor(log2(av));
highp float m = exp2(fract(log2(av))) - 1.0;
c[2] = floor(128.0 * m);
m -= c[2] / 128.0;
c[1] = floor(32768.0 * m);
m -= c[1] / 32768.0;
c[0] = floor(8388608.0 * m);
highp float ebias = e + 127.0;
c[3] = floor(ebias / 2.0);
ebias -= c[3] * 2.0;
c[2] += floor(ebias) * 128.0;
c[3] += 128.0 * step(0.0, -v);
return c / 255.0;
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`;const{getBroadcastDims:getBroadcastDims2}=backend_util_exports;function makeShader(inputsInfo,outputShape,userCode,usesPackedTextures){const prefixSnippets=[];inputsInfo.forEach(x=>{const size=util_exports.sizeFromShape(x.shapeInfo.logicalShape);if(x.shapeInfo.isUniform){prefixSnippets.push(`uniform float ${x.name}${size>1?`[${size}]`:""};`)}else{prefixSnippets.push(`uniform sampler2D ${x.name};`);prefixSnippets.push(`uniform int offset${x.name};`)}});const inputPrefixSnippet=prefixSnippets.join("\n");const inputSamplingSnippet=inputsInfo.map(x=>getInputSamplingSnippet(x,outputShape,usesPackedTextures)).join("\n");const outTexShape=outputShape.texShape;const glsl=getGlslDifferences();const floatTextureSampleSnippet=getFloatTextureSampleSnippet(glsl);let outputSamplingSnippet;let floatTextureSetOutputSnippet;let shaderPrefix=getShaderPrefix(glsl);if(outputShape.isPacked){outputSamplingSnippet=getPackedOutputSamplingSnippet(outputShape.logicalShape,outTexShape);floatTextureSetOutputSnippet=getFloatTextureSetRGBASnippet(glsl)}else{outputSamplingSnippet=getOutputSamplingSnippet(outputShape.logicalShape,outTexShape);floatTextureSetOutputSnippet=getFloatTextureSetRSnippet(glsl)}if(usesPackedTextures){shaderPrefix+=SHADER_PACKED_PREFIX}const source=[shaderPrefix,floatTextureSampleSnippet,floatTextureSetOutputSnippet,inputPrefixSnippet,outputSamplingSnippet,inputSamplingSnippet,userCode].join("\n");return source}function getSamplerFromInInfo(inInfo){const shape=inInfo.shapeInfo.logicalShape;switch(shape.length){case 0:return getSamplerScalar(inInfo);case 1:return getSampler1D(inInfo);case 2:return getSampler2D(inInfo);case 3:return getSampler3D(inInfo);case 4:return getSampler4D(inInfo);case 5:return getSampler5D(inInfo);case 6:return getSampler6D(inInfo);default:throw new Error(`${shape.length}-D input sampling is not yet supported`)}}function getPackedSamplerFromInInfo(inInfo){const shape=inInfo.shapeInfo.logicalShape;switch(shape.length){case 0:return 
getPackedSamplerScalar(inInfo);case 1:return getPackedSampler1D(inInfo);case 2:return getPackedSampler2D(inInfo);case 3:return getPackedSampler3D(inInfo);default:return getPackedSamplerND(inInfo)}}function getInputSamplingSnippet(inInfo,outShapeInfo,usesPackedTextures=false){let res="";if(usesPackedTextures){res+=getPackedSamplerFromInInfo(inInfo)}else{res+=getSamplerFromInInfo(inInfo)}const inShape=inInfo.shapeInfo.logicalShape;const outShape=outShapeInfo.logicalShape;if(inShape.length<=outShape.length){if(usesPackedTextures){res+=getPackedSamplerAtOutputCoords(inInfo,outShapeInfo)}else{res+=getSamplerAtOutputCoords(inInfo,outShapeInfo)}}return res}function getPackedOutputSamplingSnippet(outShape,outTexShape){switch(outShape.length){case 0:return getOutputScalarCoords();case 1:return getOutputPacked1DCoords(outShape,outTexShape);case 2:return getOutputPacked2DCoords(outShape,outTexShape);case 3:return getOutputPacked3DCoords(outShape,outTexShape);default:return getOutputPackedNDCoords(outShape,outTexShape)}}function getOutputSamplingSnippet(outShape,outTexShape){switch(outShape.length){case 0:return getOutputScalarCoords();case 1:return getOutput1DCoords(outShape,outTexShape);case 2:return getOutput2DCoords(outShape,outTexShape);case 3:return getOutput3DCoords(outShape,outTexShape);case 4:return getOutput4DCoords(outShape,outTexShape);case 5:return getOutput5DCoords(outShape,outTexShape);case 6:return getOutput6DCoords(outShape,outTexShape);default:throw new Error(`${outShape.length}-D output sampling is not yet supported`)}}function getFloatTextureSampleSnippet(glsl){return`
2020-11-16 21:51:46 +01:00
float sampleTexture(sampler2D textureSampler, vec2 uv) {
return ${glsl.texture2D}(textureSampler, uv).r;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}function getFloatTextureSetRSnippet(glsl){return`
void setOutput(float val) {
${glsl.output} = vec4(val, 0, 0, 0);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}function getFloatTextureSetRGBASnippet(glsl){return`
void setOutput(vec4 val) {
${glsl.output} = val;
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
`}function getShaderPrefix(glsl){const SHADER_PREFIX=`${glsl.version}
precision highp float;
precision highp int;
precision highp sampler2D;
${glsl.varyingFs} vec2 resultUV;
${glsl.defineOutput}
const vec2 halfCR = vec2(0.5, 0.5);
struct ivec5
{
int x;
int y;
int z;
int w;
int u;
};
struct ivec6
{
int x;
int y;
int z;
int w;
int u;
int v;
};
uniform float NAN;
${glsl.defineSpecialNaN}
${glsl.defineSpecialInf}
${glsl.defineRound}
2020-11-08 18:32:31 +01:00
2020-11-16 21:51:46 +01:00
int imod(int x, int y) {
return x - y * (x / y);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
int idiv(int a, int b, float sign) {
int res = a / b;
int mod = imod(a, b);
if (sign < 0. && mod != 0) {
res -= 1;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
return res;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
//Based on the work of Dave Hoskins
//https://www.shadertoy.com/view/4djSRW
#define HASHSCALE1 443.8975
float random(float seed){
vec2 p = resultUV * seed;
vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
p3 += dot(p3, p3.yzx + 19.19);
return fract((p3.x + p3.y) * p3.z);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
${SAMPLE_1D_SNIPPET}
${SAMPLE_2D_SNIPPET}
${SAMPLE_3D_SNIPPET}
`;return SHADER_PREFIX}const SAMPLE_1D_SNIPPET=`
vec2 uvFromFlat(int texNumR, int texNumC, int index) {
int texR = index / texNumC;
int texC = index - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
int texelIndex = index / 2;
int texR = texelIndex / texNumC;
int texC = texelIndex - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;const SAMPLE_2D_SNIPPET=`
vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
int texNumC, int row, int col) {
int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
int texR = texelIndex / texNumC;
int texC = texelIndex - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;const SAMPLE_3D_SNIPPET=`
vec2 packedUVfrom3D(int texNumR, int texNumC,
int texelsInBatch, int texelsInLogicalRow, int b,
int row, int col) {
int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
int texR = index / texNumC;
int texC = index - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;const SHADER_PACKED_PREFIX=`
float getChannel(vec4 frag, vec2 innerDims) {
vec2 modCoord = mod(innerDims, 2.);
return modCoord.x == 0. ?
(modCoord.y == 0. ? frag.r : frag.g) :
(modCoord.y == 0. ? frag.b : frag.a);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
float getChannel(vec4 frag, int dim) {
float modCoord = mod(float(dim), 2.);
return modCoord == 0. ? frag.r : frag.g;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`;function getOutputScalarCoords(){return`
int getOutputCoords() {
2020-11-12 18:58:55 +01:00
return 0;
}
2020-11-16 21:51:46 +01:00
`}function getOutputPacked1DCoords(shape,texShape){const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];if(packedTexShape[0]===1){return`
int getOutputCoords() {
return 2 * int(resultUV.x * ${packedTexShape[1]}.0);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(packedTexShape[1]===1){return`
int getOutputCoords() {
return 2 * int(resultUV.y * ${packedTexShape[0]}.0);
}
`}return`
int getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}function getOutput1DCoords(shape,texShape){if(texShape[0]===1){return`
int getOutputCoords() {
return int(resultUV.x * ${texShape[1]}.0);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(texShape[1]===1){return`
int getOutputCoords() {
return int(resultUV.y * ${texShape[0]}.0);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}return`
int getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
return resTexRC.x * ${texShape[1]} + resTexRC.y;
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}
/**
 * Builds GLSL `ivec3 getOutputCoords()` for a packed rank-3 output.
 * The flat texel index on the ceil(texShape / 2) packed grid is split into a
 * batch `b` (texelsInBatch texels each) and a position inside the batch.
 * That position is then split into row/column using texelsInLogicalRow.
 * r and c are both scaled by 2, so they are always even.
 *   texelsInLogicalRow = ceil(shape[2] / 2)
 *   texelsInBatch      = texelsInLogicalRow * ceil(shape[1] / 2)
 * `imod` is a GLSL helper defined elsewhere in the shader source.
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 3)
 * @param {number[]} texShape [rows, cols] of the texture before halving
 * @returns {string} GLSL source snippet
 */
function getOutputPacked3DCoords(shape,texShape){const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];const texelsInLogicalRow=Math.ceil(shape[2]/2);const texelsInBatch=texelsInLogicalRow*Math.ceil(shape[1]/2);return`
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
int b = index / ${texelsInBatch};
index -= b * ${texelsInBatch};
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec3(b, r, c);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}
/**
 * Builds GLSL `ivec3 getOutputCoords()` for an unpacked rank-3 output.
 * It computes the row-major flat texel index from `resultUV`, then splices in
 * the snippet from getLogicalCoordinatesFromFlatIndex (defined elsewhere).
 * That snippet is expected to declare `r`, `c`, `d` from `index` for `shape`;
 * confirm against that helper.
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 3)
 * @param {number[]} texShape [rows, cols] of the output texture
 * @returns {string} GLSL source snippet
 */
function getOutput3DCoords(shape,texShape){const coordsFromIndexSnippet=getLogicalCoordinatesFromFlatIndex(["r","c","d"],shape);return`
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
return ivec3(r, c, d);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}
/**
 * Builds GLSL `ivec<rank> getOutputCoords()` for a packed output of arbitrary
 * rank.
 * - The last two dims become the packed row/column `r`, `c`, scaled by 2.
 * - The third-from-last dim becomes `b`.
 * - Each further leading dim k (k = 2 .. rank-2 in the loop) gets a `b<k>`
 *   variable.
 * Because every new snippet is prepended, the emitted GLSL peels the
 * outermost dim off the flat texel index first. The returned vector lists the
 * coordinates outermost-first, e.g. "b3, b2, b, r, c".
 * texelsInBatchN grows by one leading dim per iteration, so it is the texel
 * count of one slab of the dim being peeled.
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape
 * @param {number[]} texShape [rows, cols] of the texture before halving
 * @returns {string} GLSL source snippet
 */
function getOutputPackedNDCoords(shape,texShape){const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];const texelsInLogicalRow=Math.ceil(shape[shape.length-1]/2);const texelsInBatch=texelsInLogicalRow*Math.ceil(shape[shape.length-2]/2);let texelsInBatchN=texelsInBatch;let batches=``;let coords2="b, r, c";for(let b=2;b<shape.length-1;b++){texelsInBatchN*=shape[shape.length-b-1];batches=`
int b${b} = index / ${texelsInBatchN};
index -= b${b} * ${texelsInBatchN};
`+batches;coords2=`b${b}, `+coords2}return`
ivec${shape.length} getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
${batches}
int b = index / ${texelsInBatch};
index -= b * ${texelsInBatch};
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec${shape.length}(${coords2});
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}
/**
 * Builds GLSL `ivec4 getOutputCoords()` for an unpacked rank-4 output.
 * It computes the row-major flat texel index from `resultUV`, then splices in
 * the snippet from getLogicalCoordinatesFromFlatIndex (defined elsewhere),
 * expected to declare `r`, `c`, `d`, `d2` from `index`; confirm against that
 * helper.
 * @param {number[]} shape logical output shape (rank 4)
 * @param {number[]} texShape [rows, cols] of the output texture
 * @returns {string} GLSL source snippet
 */
function getOutput4DCoords(shape,texShape){const coordsFromIndexSnippet=getLogicalCoordinatesFromFlatIndex(["r","c","d","d2"],shape);return`
ivec4 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
return ivec4(r, c, d, d2);
}
`}
/**
 * Builds GLSL `ivec5 getOutputCoords()` for an unpacked rank-5 output. It uses
 * the same flat-index approach as the rank-3/4 variants, with `r`, `c`, `d`,
 * `d2`, `d3` declared by the getLogicalCoordinatesFromFlatIndex snippet
 * (defined elsewhere).
 * `ivec5` is not a built-in GLSL type; presumably a struct/constructor is
 * defined in the shader prelude elsewhere. TODO confirm.
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 5)
 * @param {number[]} texShape [rows, cols] of the output texture
 * @returns {string} GLSL source snippet
 */
function getOutput5DCoords(shape,texShape){const coordsFromIndexSnippet=getLogicalCoordinatesFromFlatIndex(["r","c","d","d2","d3"],shape);return`
ivec5 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
${texShape[1]}));
2020-11-08 18:32:31 +01:00
2020-11-16 21:51:46 +01:00
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
ivec5 outShape = ivec5(r, c, d, d2, d3);
return outShape;
}
`}
/**
 * Builds GLSL `ivec6 getOutputCoords()` for an unpacked rank-6 output, with
 * `r`, `c`, `d`, `d2`, `d3`, `d4` declared by the
 * getLogicalCoordinatesFromFlatIndex snippet (defined elsewhere).
 * `ivec6` is not a built-in GLSL type; presumably defined in the shader
 * prelude elsewhere. TODO confirm.
 * NOTE(review): the date-stamp line inside the returned template is not valid
 * GLSL and looks like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 6)
 * @param {number[]} texShape [rows, cols] of the output texture
 * @returns {string} GLSL source snippet
 */
function getOutput6DCoords(shape,texShape){const coordsFromIndexSnippet=getLogicalCoordinatesFromFlatIndex(["r","c","d","d2","d3","d4"],shape);return`
ivec6 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
ivec6 result = ivec6(r, c, d, d2, d3, d4);
return result;
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds GLSL `ivec2 getOutputCoords()` for a packed rank-2 output.
 * - When the logical shape equals the texture shape, the packed texel
 *   position is read straight from `resultUV` and scaled by 2.
 * - Otherwise the flat packed-texel index is split into row/column using
 *   texelsInLogicalRow = ceil(shape[1] / 2), and both are scaled by 2.
 * NOTE(review): the date-stamp lines inside the returned templates are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 2)
 * @param {number[]} texShape [rows, cols] of the texture before halving
 * @returns {string} GLSL source snippet
 */
function getOutputPacked2DCoords(shape,texShape){const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];if(util_exports.arraysEqual(shape,texShape)){return`
2020-11-16 21:51:46 +01:00
ivec2 getOutputCoords() {
return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const texelsInLogicalRow=Math.ceil(shape[1]/2);return`
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec2(r, c);
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds GLSL `ivec2 getOutputCoords()` for an unpacked rank-2 output.
 * - shape equals texShape: texel (row, col) is the logical coordinate.
 * - shape[1] === 1 (single logical column): flat index becomes the row.
 * - shape[0] === 1 (single logical row): flat index becomes the column.
 * - otherwise: row-major flat index split by shape[1]. The column is computed
 *   as index - r * shape[1] rather than with a modulo.
 * NOTE(review): the date-stamp lines inside the returned templates are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {number[]} shape logical output shape (rank 2)
 * @param {number[]} texShape [rows, cols] of the output texture
 * @returns {string} GLSL source snippet
 */
function getOutput2DCoords(shape,texShape){if(util_exports.arraysEqual(shape,texShape)){return`
2020-11-16 21:51:46 +01:00
ivec2 getOutputCoords() {
return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(shape[1]===1){return`
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
return ivec2(index, 0);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(shape[0]===1){return`
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
return ivec2(0, index);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}return`
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
int r = index / ${shape[1]};
int c = index - r * ${shape[1]};
return ivec2(r, c);
}
`}
/**
 * Returns the name of the GLSL uniform that holds the flat element offset for
 * `texName`, e.g. "A" -> "offsetA".
 * @param {string} texName input texture name
 * @returns {string} uniform name
 */
function getFlatOffsetUniformName(texName){return`offset${texName}`}
/**
 * Builds the GLSL sampler `vec4 get<Name>()` for a packed scalar input. It
 * samples the texture at `halfCR` using the texture-lookup function name
 * reported by getGlslDifferences() (defined elsewhere).
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify against the real build
 * output.
 * @param {object} inputInfo input description (`name` is used)
 * @returns {string} GLSL source snippet
 */
function getPackedSamplerScalar(inputInfo){const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const glsl=getGlslDifferences();return`
vec4 ${funcName}() {
return ${glsl.texture2D}(${texName}, halfCR);
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
`}
/**
 * Builds the GLSL sampler `float get<Name>()` for an unpacked scalar input.
 * - uniform input: returns the uniform value directly;
 * - 1x1 texture: samples at `halfCR`;
 * - otherwise: samples the texel addressed by the flat-offset uniform via
 *   uvFromFlat (a GLSL helper defined elsewhere).
 * NOTE(review): tNumR/tNumC re-destructure the same texShape as
 * texNumR/texNumC. The date-stamp lines inside the templates are not valid
 * GLSL and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSamplerScalar(inputInfo){const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);if(inputInfo.shapeInfo.isUniform){return`float ${funcName}() {return ${texName};}`}const[texNumR,texNumC]=inputInfo.shapeInfo.texShape;if(texNumR===1&&texNumC===1){return`
float ${funcName}() {
return sampleTexture(${texName}, halfCR);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const[tNumR,tNumC]=inputInfo.shapeInfo.texShape;const offset=getFlatOffsetUniformName(texName);return`
float ${funcName}() {
vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});
return sampleTexture(${texName}, uv);
}
`}
/**
 * Builds the GLSL sampler `vec4 get<Name>(int index)` for a packed rank-1
 * input. The UV comes from packedUVfrom1D (a GLSL helper defined elsewhere)
 * over the ceil(texShape / 2) packed grid.
 * @param {object} inputInfo input description (`name`, `shapeInfo.texShape`)
 * @returns {string} GLSL source snippet
 */
function getPackedSampler1D(inputInfo){const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const texShape=inputInfo.shapeInfo.texShape;const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];const glsl=getGlslDifferences();return`
vec4 ${funcName}(int index) {
vec2 uv = packedUVfrom1D(
${packedTexShape[0]}, ${packedTexShape[1]}, index);
return ${glsl.texture2D}(${texName}, uv);
}
`}
/**
 * Builds the GLSL sampler `float get<Name>(int index)` for an unpacked rank-1
 * input.
 * - uniform input: uses the uniform-array lookup from getUniformSampler, which
 *   reads the `index` parameter;
 * - 1x1 texture: samples at `halfCR`;
 * - single-column texture: V from (index + offset), U fixed at 0.5;
 * - single-row texture: U from (index + offset), V fixed at 0.5;
 * - otherwise: uvFromFlat(rows, cols, index + offset).
 * `offset<Name>` is the flat-offset uniform (see getFlatOffsetUniformName).
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSampler1D(inputInfo){const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int index) {
${getUniformSampler(inputInfo)}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const texShape=inputInfo.shapeInfo.texShape;const tNumR=texShape[0];const tNumC=texShape[1];if(tNumC===1&&tNumR===1){return`
float ${funcName}(int index) {
return sampleTexture(${texName}, halfCR);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const offset=getFlatOffsetUniformName(texName);if(tNumC===1){return`
float ${funcName}(int index) {
vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(tNumR===1){return`
float ${funcName}(int index) {
vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}return`
float ${funcName}(int index) {
vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
return sampleTexture(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler `vec4 get<Name>(int row, int col)` for a packed
 * rank-2 input.
 * - texShape equals the logical shape: UV from (col, row) + halfCR over the
 *   texture dims;
 * - otherwise: packedUVfrom2D (a GLSL helper defined elsewhere) with
 *   valuesPerRow = ceil(shape[1] / 2) over the ceil(texShape / 2) packed
 *   grid.
 * NOTE(review): texShape[0]/texShape[1] are read before the `texShape!=null`
 * check, so that check cannot prevent a null dereference. The date-stamp
 * lines inside the templates are not valid GLSL and look like blame-view
 * residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getPackedSampler2D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const texShape=inputInfo.shapeInfo.texShape;const texNumR=texShape[0];const texNumC=texShape[1];const glsl=getGlslDifferences();if(texShape!=null&&util_exports.arraysEqual(shape,texShape)){return`
2020-11-16 21:51:46 +01:00
vec4 ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
return ${glsl.texture2D}(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];const valuesPerRow=Math.ceil(shape[1]/2);return`
vec4 ${funcName}(int row, int col) {
vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
return ${glsl.texture2D}(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler `float get<Name>(int row, int col)` for an unpacked
 * rank-2 input. Strategies, in order:
 * 1. texShape equals the logical shape: sample texel (col, row) directly.
 * 2. Shape has squeezable dims: emit the lower-rank sampler (via
 *    getSamplerFromInInfo) plus an overload that forwards only the kept
 *    coordinates.
 * 3. Uniform input: integer flat index, then getUniformSampler's lookup.
 * 4. Single-column / single-row texture: flat index (including the offset
 *    uniform) mapped to V or U.
 * 5. Otherwise: integer flat index through uvFromFlat.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSampler2D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const texShape=inputInfo.shapeInfo.texShape;if(texShape!=null&&util_exports.arraysEqual(shape,texShape)){const texNumR2=texShape[0];const texNumC2=texShape[1];return`
2020-11-16 21:51:46 +01:00
float ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC2}.0, ${texNumR2}.0);
return sampleTexture(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}const{newShape,keptDims}=util_exports.squeezeShape(shape);const squeezedShape=newShape;if(squeezedShape.length<shape.length){const newInputInfo=squeezeInputInfo(inputInfo,squeezedShape);const params=["row","col"];return`
2020-11-16 21:51:46 +01:00
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int row, int col) {
int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
${getUniformSampler(inputInfo)}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const texNumR=texShape[0];const texNumC=texShape[1];const offset=getFlatOffsetUniformName(texName);if(texNumC===1){return`
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`}if(texNumR===1){return`
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
return sampleTexture(${texName}, uv);
}
`}return`
float ${funcName}(int row, int col) {
// Explicitly use integer operations as dot() only works on floats.
int index = row * ${shape[1]} + col + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
}
`}
/**
 * Builds the GLSL sampler `vec4 get<Name>(int b, int row, int col)` for a
 * packed rank-3 input.
 * - shape[0] === 1: squeeze the leading dim, emit the packed rank-2 sampler
 *   (via getPackedSamplerFromInInfo) and an overload forwarding (row, col).
 * - otherwise: packedUVfrom3D (a GLSL helper defined elsewhere) with
 *   valuesPerRow = ceil(shape[2] / 2) and
 *   texelsInBatch = valuesPerRow * ceil(shape[1] / 2).
 * NOTE(review): the date-stamp line inside the returned template is not valid
 * GLSL and looks like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getPackedSampler3D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const texShape=inputInfo.shapeInfo.texShape;const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];if(shape[0]===1){const squeezedShape=shape.slice(1);const keptDims=[1,2];const newInputInfo=squeezeInputInfo(inputInfo,squeezedShape);const params=["b","row","col"];return`
${getPackedSamplerFromInInfo(newInputInfo)}
vec4 ${funcName}(int b, int row, int col) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
}
`}const texNumR=packedTexShape[0];const texNumC=packedTexShape[1];const valuesPerRow=Math.ceil(shape[2]/2);const texelsInBatch=valuesPerRow*Math.ceil(shape[1]/2);const glsl=getGlslDifferences();return`
vec4 ${funcName}(int b, int row, int col) {
vec2 uv = packedUVfrom3D(
${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
return ${glsl.texture2D}(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler `float get<Name>(int row, int col, int depth)` for
 * an unpacked rank-3 input, using row-major strides
 * stride0 = shape[1] * shape[2] and stride1 = shape[2]. Strategies, in order:
 * 1. Squeezable dims: delegate to the lower-rank sampler with an overload
 *    that forwards only the kept coordinates.
 * 2. Uniform input: integer flat index, then getUniformSampler's lookup.
 * 3. No flat offset and a texture row holds exactly stride0 elements: texel
 *    row = `row`, column = col * stride1 + depth.
 * 4. No flat offset and a texture row holds exactly stride1 elements: texel
 *    row = row * shape[1] + col, column = depth.
 * 5. Otherwise: integer flat index (plus offset uniform) through uvFromFlat.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSampler3D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const stride0=shape[1]*shape[2];const stride1=shape[2];const{newShape,keptDims}=util_exports.squeezeShape(shape);const squeezedShape=newShape;if(squeezedShape.length<shape.length){const newInputInfo=squeezeInputInfo(inputInfo,squeezedShape);const params=["row","col","depth"];return`
2020-11-16 21:51:46 +01:00
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
}
`}if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int row, int col, int depth) {
int index = round(dot(vec3(row, col, depth),
vec3(${stride0}, ${stride1}, 1)));
${getUniformSampler(inputInfo)}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const texShape=inputInfo.shapeInfo.texShape;const texNumR=texShape[0];const texNumC=texShape[1];const flatOffset=inputInfo.shapeInfo.flatOffset;if(texNumC===stride0&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth) {
float texR = float(row);
float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`}if(texNumC===stride1&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth) {
float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
float texC = float(depth);
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`}const offset=getFlatOffsetUniformName(texName);return`
float ${funcName}(int row, int col, int depth) {
// Explicitly use integer operations as dot() only works on floats.
int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler `vec4 get<Name>(..., int b, int row, int col)` for a
 * packed input of arbitrary rank.
 * The packed-texel index is
 *   b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2),
 * and each further leading dim k adds `b<k> * <slab size>`. The slab size
 * grows by one more dim per loop iteration. Parameters are prepended, so the
 * outermost dim comes first in the GLSL signature.
 * The index is then split into texel row/column on the ceil(texShape / 2)
 * packed grid.
 * NOTE(review): the date-stamp lines inside the returned template are not
 * valid GLSL and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getPackedSamplerND(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const rank=shape.length;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const texShape=inputInfo.shapeInfo.texShape;const packedTexShape=[Math.ceil(texShape[0]/2),Math.ceil(texShape[1]/2)];const texNumR=packedTexShape[0];const texNumC=packedTexShape[1];const valuesPerRow=Math.ceil(shape[rank-1]/2);let texelsInBatch=valuesPerRow*Math.ceil(shape[rank-2]/2);let params=`int b, int row, int col`;let index=`b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;for(let b=2;b<rank-1;b++){params=`int b${b}, `+params;texelsInBatch*=shape[rank-b-1];index=`b${b} * ${texelsInBatch} + `+index}const glsl=getGlslDifferences();return`
2020-11-16 21:51:46 +01:00
vec4 ${funcName}(${params}) {
2020-11-17 16:18:15 +01:00
int index = ${index};
2020-11-16 21:51:46 +01:00
int texR = index / ${texNumC};
int texC = index - texR * ${texNumC};
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
return ${glsl.texture2D}(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler `float get<Name>(row, col, depth, depth2)` for an
 * unpacked rank-4 input, using row-major strides stride2 = shape[3],
 * stride1 = shape[2] * stride2 and stride0 = shape[1] * stride1.
 * Strategies, in order:
 * 1. Squeezable dims: delegate to the lower-rank sampler with a forwarding
 *    overload.
 * 2. Uniform input: integer flat index, then getUniformSampler's lookup.
 * 3. No flat offset and a texture row holds exactly stride0 elements: texel
 *    row = `row`.
 * 4. No flat offset and a texture row holds exactly stride2 elements: texel
 *    column = depth2.
 * 5. Otherwise: integer flat index (plus offset uniform) through uvFromFlat.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSampler4D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const stride2=shape[3];const stride1=shape[2]*stride2;const stride0=shape[1]*stride1;const{newShape,keptDims}=util_exports.squeezeShape(shape);if(newShape.length<shape.length){const newInputInfo=squeezeInputInfo(inputInfo,newShape);const params=["row","col","depth","depth2"];return`
2020-11-16 21:51:46 +01:00
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth, int depth2) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int row, int col, int depth, int depth2) {
int index = round(dot(vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, 1)));
${getUniformSampler(inputInfo)}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const flatOffset=inputInfo.shapeInfo.flatOffset;const texShape=inputInfo.shapeInfo.texShape;const texNumR=texShape[0];const texNumC=texShape[1];if(texNumC===stride0&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth, int depth2) {
float texR = float(row);
float texC =
dot(vec3(col, depth, depth2),
vec3(${stride1}, ${stride2}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(texNumC===stride2&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth, int depth2) {
float texR = dot(vec3(row, col, depth),
vec3(${shape[1]*shape[2]}, ${shape[2]}, 1));
float texC = float(depth2);
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const offset=getFlatOffsetUniformName(texName);return`
float ${funcName}(int row, int col, int depth, int depth2) {
// Explicitly use integer operations as dot() only works on floats.
int index = row * ${stride0} + col * ${stride1} +
depth * ${stride2} + depth2;
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
return sampleTexture(${texName}, uv);
}
2020-11-17 16:18:15 +01:00
`}function getSampler5D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const stride3=shape[4];const stride2=shape[3]*stride3;const stride1=shape[2]*stride2;const stride0=shape[1]*stride1;const{newShape,keptDims}=util_exports.squeezeShape(shape);if(newShape.length<shape.length){const newInputInfo=squeezeInputInfo(inputInfo,newShape);const params=["row","col","depth","depth2","depth3"];return`
2020-11-16 21:51:46 +01:00
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
float index = dot(
vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
depth3;
${getUniformSampler(inputInfo)}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const flatOffset=inputInfo.shapeInfo.flatOffset;const texShape=inputInfo.shapeInfo.texShape;const texNumR=texShape[0];const texNumC=texShape[1];if(texNumC===stride0&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
int texR = row;
float texC = dot(vec4(col, depth, depth2, depth3),
vec4(${stride1}, ${stride2}, ${stride3}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(texNumC===stride3&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
float texR = dot(
vec4(row, col, depth, depth2),
vec4(${shape[1]*shape[2]*shape[3]},
${shape[2]*shape[3]}, ${shape[3]}, 1));
int texC = depth3;
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const offset=getFlatOffsetUniformName(texName);return`
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
// Explicitly use integer operations as dot() only works on floats.
int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
depth2 * ${stride3} + depth3 + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the GLSL sampler
 * `float get<Name>(row, col, depth, depth2, depth3, depth4)` for an unpacked
 * rank-6 input. Strategies, in order:
 * 1. Squeezable dims: delegate to the lower-rank sampler with a forwarding
 *    overload.
 * 2. Uniform input: integer flat index from two dot products (stride0..3,
 *    then stride4 and 1), then getUniformSampler's lookup.
 * 3. No flat offset and a texture row holds exactly stride0 elements: texel
 *    row = `row`.
 * 4. No flat offset and a texture row holds exactly stride4 elements: texel
 *    column = depth4.
 * 5. Otherwise: integer flat index (plus offset uniform) through uvFromFlat.
 * The row-major strides are computed after the squeeze check because only the
 * non-squeezed paths use them.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @returns {string} GLSL source snippet
 */
function getSampler6D(inputInfo){const shape=inputInfo.shapeInfo.logicalShape;const texName=inputInfo.name;const funcName="get"+texName.charAt(0).toUpperCase()+texName.slice(1);const{newShape,keptDims}=util_exports.squeezeShape(shape);if(newShape.length<shape.length){const newInputInfo=squeezeInputInfo(inputInfo,newShape);const params=["row","col","depth","depth2","depth3","depth4"];return`
2020-11-16 21:51:46 +01:00
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
return ${funcName}(${getSqueezedParams(params,keptDims)});
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}const stride4=shape[5];const stride3=shape[4]*stride4;const stride2=shape[3]*stride3;const stride1=shape[2]*stride2;const stride0=shape[1]*stride1;if(inputInfo.shapeInfo.isUniform){return`
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
int index = round(dot(
vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
dot(
vec2(depth3, depth4),
vec2(${stride4}, 1)));
${getUniformSampler(inputInfo)}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}const flatOffset=inputInfo.shapeInfo.flatOffset;const texShape=inputInfo.shapeInfo.texShape;const texNumR=texShape[0];const texNumC=texShape[1];if(texNumC===stride0&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
int texR = row;
float texC = dot(vec4(col, depth, depth2, depth3),
vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
float(depth4);
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}if(texNumC===stride4&&flatOffset==null){return`
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
float texR = dot(vec4(row, col, depth, depth2),
vec4(${shape[1]*shape[2]*shape[3]*shape[4]},
${shape[2]*shape[3]*shape[4]},
${shape[3]*shape[4]},
${shape[4]})) + float(depth3);
int texC = depth4;
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}const offset=getFlatOffsetUniformName(texName);return`
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
// Explicitly use integer operations as dot() only works on floats.
int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds the body of a GLSL sampler for an input uploaded as a uniform.
 * - Fewer than 2 elements: `return <name>;`.
 * - Otherwise: a constant-bound loop over the element count that returns
 *   `<name>[i]` when `i == index`. The enclosing GLSL function must already
 *   declare an int `index`.
 * NOTE(review): the emitted loop has no fallback return after it, so the
 * shader relies on `index` being in range. The date-stamp lines inside the
 * template are not valid GLSL and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`,
 *   `shapeInfo.logicalShape`)
 * @returns {string} GLSL statements (not a full function)
 */
function getUniformSampler(inputInfo){const texName=inputInfo.name;const inSize=util_exports.sizeFromShape(inputInfo.shapeInfo.logicalShape);if(inSize<2){return`return ${texName};`}return`
2020-11-16 21:51:46 +01:00
for (int i = 0; i < ${inSize}; i++) {
if (i == index) {
return ${texName}[i];
2020-11-08 18:32:31 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds GLSL `vec4 get<Name>AtOutCoords()` for a packed input. It reads the
 * input at the current output coordinates, with broadcasting.
 * - The output coords from getOutputCoords() are aligned to the input's
 *   trailing dims (offset by rankDiff). Each broadcast dim is zeroed; if the
 *   output rank is below 2, the whole scalar coord is zeroed instead.
 * - The input is fetched with get<Name>(...). The returned vec4 is then
 *   rearranged:
 *   - rank-1 non-scalar input (output not scalar): duplicate `.xy`;
 *   - scalar input into a non-scalar output: splat `.x` (rank-1 output keeps
 *     zeros in z/w);
 *   - broadcast over both of the input's last two dims: splat `.x`;
 *   - broadcast over the input's row dim only: (x, y, x, y);
 *   - broadcast over the input's column dim only: (x, x, z, z);
 *   - otherwise: the value unchanged.
 * getBroadcastDims2 and getCoordsDataType are defined elsewhere / below.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @param {object} outShapeInfo output shape info (`logicalShape`)
 * @returns {string} GLSL source snippet
 */
function getPackedSamplerAtOutputCoords(inputInfo,outShapeInfo){const texName=inputInfo.name;const texFuncSnippet=texName.charAt(0).toUpperCase()+texName.slice(1);const funcName="get"+texFuncSnippet+"AtOutCoords";const inRank=inputInfo.shapeInfo.logicalShape.length;const outRank=outShapeInfo.logicalShape.length;const broadcastDims=getBroadcastDims2(inputInfo.shapeInfo.logicalShape,outShapeInfo.logicalShape);const type=getCoordsDataType(outRank);const rankDiff=outRank-inRank;let coordsSnippet;const fields=["x","y","z","w","u","v"];if(inRank===0){coordsSnippet=""}else if(outRank<2&&broadcastDims.length>=1){coordsSnippet="coords = 0;"}else{coordsSnippet=broadcastDims.map(d=>`coords.${fields[d+rankDiff]} = 0;`).join("\n")}let unpackedCoordsSnippet="";if(outRank<2&&inRank>0){unpackedCoordsSnippet="coords"}else{unpackedCoordsSnippet=inputInfo.shapeInfo.logicalShape.map((s,i)=>`coords.${fields[i+rankDiff]}`).join(", ")}let output=`return outputValue;`;const inSize=util_exports.sizeFromShape(inputInfo.shapeInfo.logicalShape);const isInputScalar=inSize===1;const outSize=util_exports.sizeFromShape(outShapeInfo.logicalShape);const isOutputScalar=outSize===1;if(inRank===1&&!isInputScalar&&!isOutputScalar){output=`
2020-11-16 21:51:46 +01:00
return vec4(outputValue.xy, outputValue.xy);
`}else if(isInputScalar&&!isOutputScalar){if(outRank===1){output=`
return vec4(outputValue.x, outputValue.x, 0., 0.);
`}else{output=`
return vec4(outputValue.x);
`}}else if(broadcastDims.length){const rows=inRank-2;const cols=inRank-1;if(broadcastDims.indexOf(rows)>-1&&broadcastDims.indexOf(cols)>-1){output=`return vec4(outputValue.x);`}else if(broadcastDims.indexOf(rows)>-1){output=`return vec4(outputValue.x, outputValue.y, outputValue.x, outputValue.y);`}else if(broadcastDims.indexOf(cols)>-1){output=`return vec4(outputValue.xx, outputValue.zz);`}}return`
vec4 ${funcName}() {
${type} coords = getOutputCoords();
${coordsSnippet}
vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
${output}
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Builds GLSL `float get<Name>AtOutCoords()` for an unpacked input. It reads
 * the input at the current output coordinates.
 * - Fast path: the input is not uniform, has the output's rank, has no flat
 *   offset and has the same texture shape as the output. In that case it
 *   samples directly at `resultUV`.
 * - Otherwise: take getOutputCoords() and zero the broadcast dims (aligned to
 *   the input's trailing dims via rankDiff; the whole coord if output rank is
 *   below 2), then call get<Name>(...) with the input's coordinates.
 * NOTE(review): the date-stamp lines inside the templates are not valid GLSL
 * and look like blame-view residue; verify.
 * @param {object} inputInfo input description (`name`, `shapeInfo`)
 * @param {object} outShapeInfo output shape info (`logicalShape`, `texShape`)
 * @returns {string} GLSL source snippet
 */
function getSamplerAtOutputCoords(inputInfo,outShapeInfo){const texName=inputInfo.name;const texFuncSnippet=texName.charAt(0).toUpperCase()+texName.slice(1);const funcName="get"+texFuncSnippet+"AtOutCoords";const outTexShape=outShapeInfo.texShape;const inTexShape=inputInfo.shapeInfo.texShape;const inRank=inputInfo.shapeInfo.logicalShape.length;const outRank=outShapeInfo.logicalShape.length;if(!inputInfo.shapeInfo.isUniform&&inRank===outRank&&inputInfo.shapeInfo.flatOffset==null&&util_exports.arraysEqual(inTexShape,outTexShape)){return`
2020-11-16 21:51:46 +01:00
float ${funcName}() {
return sampleTexture(${texName}, resultUV);
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}const type=getCoordsDataType(outRank);const broadcastDims=getBroadcastDims2(inputInfo.shapeInfo.logicalShape,outShapeInfo.logicalShape);const rankDiff=outRank-inRank;let coordsSnippet;const fields=["x","y","z","w","u","v"];if(inRank===0){coordsSnippet=""}else if(outRank<2&&broadcastDims.length>=1){coordsSnippet="coords = 0;"}else{coordsSnippet=broadcastDims.map(d=>`coords.${fields[d+rankDiff]} = 0;`).join("\n")}let unpackedCoordsSnippet="";if(outRank<2&&inRank>0){unpackedCoordsSnippet="coords"}else{unpackedCoordsSnippet=inputInfo.shapeInfo.logicalShape.map((s,i)=>`coords.${fields[i+rankDiff]}`).join(", ")}return`
2020-11-16 21:51:46 +01:00
float ${funcName}() {
${type} coords = getOutputCoords();
${coordsSnippet}
return get${texFuncSnippet}(${unpackedCoordsSnippet});
}
2020-11-17 16:18:15 +01:00
`}
/**
 * Maps a tensor rank to the GLSL coordinate type name:
 * rank <= 1 -> "int", 2..6 -> "ivec2".."ivec6".
 * @param {number} rank
 * @returns {string}
 * @throws {Error} for rank > 6
 */
function getCoordsDataType(rank){if(rank<=1){return"int"}else if(rank===2){return"ivec2"}else if(rank===3){return"ivec3"}else if(rank===4){return"ivec4"}else if(rank===5){return"ivec5"}else if(rank===6){return"ivec6"}else{throw Error(`GPU for rank ${rank} is not yet supported`)}}
/**
 * Returns a deep copy (via a JSON round-trip) of `inInfo` whose
 * `shapeInfo.logicalShape` is replaced by `squeezedShape`. The caller's
 * object is not modified.
 * @param {object} inInfo input description
 * @param {number[]} squeezedShape new logical shape
 * @returns {object}
 */
function squeezeInputInfo(inInfo,squeezedShape){const newInputInfo=JSON.parse(JSON.stringify(inInfo));newInputInfo.shapeInfo.logicalShape=squeezedShape;return newInputInfo}
/**
 * Selects the parameter names at positions `keptDims` and joins them with
 * ", " to build a GLSL argument list, e.g. (["row","col"], [1]) -> "col".
 * @param {string[]} params parameter names
 * @param {number[]} keptDims indices to keep
 * @returns {string}
 */
function getSqueezedParams(params,keptDims){return keptDims.map(d=>params[d]).join(", ")}
/**
 * WebGL program for packed argMin/argMax along the last axis.
 * - Asserts the input rank is > 2.
 * - The last dim is reduced in windows of `windowSize`, so the output drops
 *   that dim, or keeps ceil(inSize / windowSize) entries when that is > 1.
 * - On a non-first pass (`firstPass` false), candidate indices are read from
 *   the extra input "bestIndicesA" instead of being the raw source index.
 * - The shader compares the 4 channels of a packed texel at once with
 *   greaterThan (op2 === "max") or lessThan (otherwise). NaN candidates never
 *   replace the current best. It writes the best index per channel.
 * getChannels is defined elsewhere.
 * NOTE(review): the date-stamp lines inside the userCode/sourceLocSetup/
 * fetchCandidateIdx templates are not valid GLSL and look like blame-view
 * residue; verify.
 */
class ArgMinMaxPackedProgram{constructor(shape,windowSize,op2,firstPass){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;util_exports.assert(shape.length>2,()=>`Packed arg${op2.charAt(0).toUpperCase()+op2.slice(1)} supports only inputs with rank above 2.`);const inSize=shape[shape.length-1];const outSize=Math.ceil(inSize/windowSize);this.outputShape=shape.slice(0,-1);if(outSize>1){this.outputShape.push(outSize)}if(!firstPass){this.variableNames.push("bestIndicesA")}const outShape=this.outputShape;const rank=outShape.length;const dtype=getCoordsDataType(rank);const coords2=getChannels("coords",rank);let sourceLocSetup;let sourceRank;if(outSize===1){sourceRank=rank+1;const sourceLocDType=getCoordsDataType(sourceRank);sourceLocSetup=`
2020-11-16 21:51:46 +01:00
${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords2.join()}, 0);
++${coords2[rank-1]};
${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords2.join()}, 0);
++${coords2[rank-2]};
${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords2.join()}, 0);
--${coords2[rank-1]};
${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords2.join()}, 0);
--${coords2[rank-2]};`}else{sourceRank=rank;sourceLocSetup=`
${dtype} sourceLocR = coords;
++${coords2[rank-1]};
${dtype} sourceLocG = coords;
++${coords2[rank-2]};
${dtype} sourceLocA = coords;
--${coords2[rank-1]};
${dtype} sourceLocB = coords;
2020-11-17 16:18:15 +01:00
--${coords2[rank-2]};`}const channels=["x","y","z","w","u","v"].slice(0,sourceRank);const inChannel="."+channels[sourceRank-1];const intChannels=channels.map(x=>"int "+x);const srcRCoords=getChannels("sourceLocR",sourceRank-1).concat("inIdx.r");const srcGCoords=getChannels("sourceLocG",sourceRank-1).concat("inIdx.g");const srcBCoords=getChannels("sourceLocB",sourceRank-1).concat("inIdx.b");const srcACoords=getChannels("sourceLocA",sourceRank-1).concat("inIdx.a");const compOp=op2==="max"?"greaterThan":"lessThan";const fetchCandidateIdx=firstPass?"":`
2020-11-16 21:51:46 +01:00
inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
getBestIndicesAChannel(${srcGCoords.join()}),
getBestIndicesAChannel(${srcBCoords.join()}),
getBestIndicesAChannel(${srcACoords.join()})));`;const fetchValue=`vec4(
getAChannel(${srcRCoords.join()}),
hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;const getBestIndicesAChannelSnippet=firstPass?"":`
float getBestIndicesAChannel(${intChannels.join()}) {
return getChannel(getBestIndicesA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
}`;this.userCode=`
float getAChannel(${intChannels.join()}) {
return getChannel(getA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
${getBestIndicesAChannelSnippet}
void main() {
${dtype} coords = getOutputCoords();
bool hasNextCol = ${coords2[rank-1]} < ${outShape[rank-1]-1};
bool hasNextRow = ${coords2[rank-2]} < ${outShape[rank-2]-1};
${sourceLocSetup}
ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
ivec4 inIdx = srcIdx;
vec4 bestIndex = vec4(inIdx);
vec4 bestValue = ${fetchValue};
for (int i = 0; i < ${windowSize}; i++) {
inIdx = srcIdx;
${fetchCandidateIdx}
vec4 candidate = ${fetchValue};
bvec4 nan = isnan(candidate);
bvec4 replace = bvec4(
vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));
bestValue = vec4(replace.x ? candidate.x : bestValue.x,
replace.y ? candidate.y : bestValue.y,
replace.z ? candidate.z : bestValue.z,
replace.w ? candidate.w : bestValue.w);
bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
srcIdx++;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(bestIndex);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}}
/**
 * WebGL program computing the gradient of 2-D average pooling.
 * - The output shape is `convInfo.inShape`, and the single input is "dy".
 * - For each output element (b, xR, xC, d), the shader walks the effective
 *   (dilated) filter window. It maps each tap back to a dy position using the
 *   stride, and skips positions that are negative, past outHeight/outWidth,
 *   or not an exact multiple of the stride.
 * - Every accepted dy value is scaled by 1 / (filterHeight * filterWidth).
 * - The pads are (effectiveFilter - 1 - original pad) for top and left.
 * NOTE(review): the date-stamp lines inside the userCode template are not
 * valid GLSL and look like blame-view residue; verify.
 */
class AvgPool2DBackpropProgram{constructor(convInfo){this.variableNames=["dy"];this.outputShape=convInfo.inShape;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padTop=effectiveFilterHeight-1-convInfo.padInfo.top;const padLeft=effectiveFilterWidth-1-convInfo.padInfo.left;const avgMultiplier=1/(filterHeight*filterWidth);this.userCode=`
const ivec2 pads = ivec2(${padTop}, ${padLeft});
const float avgMultiplier = float(${avgMultiplier});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 dyRCCorner = coords.yz - pads;
int dyRCorner = dyRCCorner.x;
int dyCCorner = dyRCCorner.y;
// Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC+= ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(b, idyR, idyC, d);
dotProd += dyValue * avgMultiplier;
2020-11-08 18:32:31 +01:00
}
2020-11-06 22:21:20 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class AvgPool3DBackpropProgram{constructor(convInfo){this.variableNames=["dy"];this.outputShape=convInfo.inShape;const filterDepth=convInfo.filterDepth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationDepth=convInfo.dilationDepth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const effectiveFilterDepth=convInfo.effectiveFilterDepth;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padFront=effectiveFilterDepth-1-convInfo.padInfo.front;const padTop=effectiveFilterHeight-1-convInfo.padInfo.top;const padLeft=effectiveFilterWidth-1-convInfo.padInfo.left;const avgMultiplier=1/(filterDepth*filterHeight*filterWidth);this.userCode=`
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
const float avgMultiplier = float(${avgMultiplier});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyDCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
// Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get
// dx(xD, xR, xC, ch).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyD = int(dyD);
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(batch, idyD, idyR, idyC, ch);
dotProd += dyValue * avgMultiplier;
}
2020-11-12 18:58:55 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}const CHECK_NAN_SNIPPET=`
if (isnan(a)) return a;
if (isnan(b)) return b;
`;const INT_DIV=`
float s = sign(a) * sign(b);
int ia = round(a);
int ib = round(b);
if (ib != 0) {
// Windows (D3D) wants guaranteed non-zero int division at compile-time.
return float(idiv(ia, ib, s));
} else {
return NAN;
}
`;const POW=`
if(a < 0.0 && floor(b) < b){
return NAN;
}
if (b == 0.0) {
return 1.0;
}
return (round(mod(b, 2.0)) != 1) ?
pow(abs(a), b) : sign(a) * pow(abs(a), b);
2020-11-17 16:18:15 +01:00
`;const EQUAL=`return float(a == b);`;const LESS=`return float(a < b);`;const LESS_EQUAL=`return float(a <= b);`;const GREATER=`return float(a > b);`;const GREATER_EQUAL=`return float(a >= b);`;const LOGICAL_AND=`return float(a >= 1.0 && b >= 1.0);`;const LOGICAL_OR=`return float(a >= 1.0 || b >= 1.0);`;const MAX=CHECK_NAN_SNIPPET+`
2020-11-16 21:51:46 +01:00
return max(a, b);
`;const MIN=CHECK_NAN_SNIPPET+`
return min(a, b);
`;const MOD=`if (b == 0.0) return NAN;
2020-11-17 16:18:15 +01:00
return mod(a, b);`;const ELU_DER=`return (b >= 1.0) ? a : a * (b + 1.0);`;const PRELU=`return (a < 0.) ? b * a : a;`;class BinaryOpProgram{constructor(op2,aShape,bShape){this.variableNames=["A","B"];this.outputShape=backend_util_exports.assertAndGetBroadcastShape(aShape,bShape);this.userCode=`
2020-11-16 21:51:46 +01:00
float binaryOperation(float a, float b) {
2020-11-17 16:18:15 +01:00
${op2}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
float a = getAAtOutCoords();
float b = getBAtOutCoords();
setOutput(binaryOperation(a, b));
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}}const CHECK_NAN_SNIPPET2=`
2020-11-16 21:51:46 +01:00
result.r = isNaN.r > 0. ? NAN : result.r;
result.g = isNaN.g > 0. ? NAN : result.g;
result.b = isNaN.b > 0. ? NAN : result.b;
result.a = isNaN.a > 0. ? NAN : result.a;
2020-11-17 16:18:15 +01:00
`;const INT_DIV2=`
2020-11-16 21:51:46 +01:00
ivec4 ia = round(a);
ivec4 ib = round(b);
bvec4 cond = notEqual(ib, ivec4(0));
ivec4 result = ivec4(0);
vec4 s = sign(a) * sign(b);
// Windows (D3D) wants guaranteed non-zero int division at compile-time.
if (cond[0]) {
result[0] = idiv(ia[0], ib[0], s[0]);
}
if (cond[1]) {
result[1] = idiv(ia[1], ib[1], s[1]);
}
if (cond[2]) {
result[2] = idiv(ia[2], ib[2], s[2]);
}
if (cond[3]) {
result[3] = idiv(ia[3], ib[3], s[3]);
}
return vec4(result);
2020-11-17 16:18:15 +01:00
`;const POW2=`
2020-11-16 21:51:46 +01:00
// isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.
vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
vec4 result = multiplier * pow(abs(a), b);
// Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS
bvec4 isExpZero = equal(b, vec4(0.0));
result.r = isExpZero.r ? 1.0 : result.r;
result.g = isExpZero.g ? 1.0 : result.g;
result.b = isExpZero.b ? 1.0 : result.b;
result.a = isExpZero.a ? 1.0 : result.a;
vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));
2020-11-17 16:18:15 +01:00
`+CHECK_NAN_SNIPPET2+`
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const PRELU2=`
2020-11-16 21:51:46 +01:00
vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
2020-11-17 16:18:15 +01:00
`;const ELU_DER2=`
2020-11-16 21:51:46 +01:00
vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
2020-11-17 16:18:15 +01:00
`;const EQUAL2=`
2020-11-16 21:51:46 +01:00
return vec4(equal(a, b));
2020-11-17 16:18:15 +01:00
`;const LESS2=`
2020-11-16 21:51:46 +01:00
return vec4(lessThan(a, b));
2020-11-17 16:18:15 +01:00
`;const LESS_EQUAL2=`
2020-11-16 21:51:46 +01:00
return vec4(lessThanEqual(a, b));
2020-11-17 16:18:15 +01:00
`;const GREATER2=`
2020-11-16 21:51:46 +01:00
return vec4(greaterThan(a, b));
2020-11-17 16:18:15 +01:00
`;const GREATER_EQUAL2=`
2020-11-16 21:51:46 +01:00
return vec4(greaterThanEqual(a, b));
2020-11-17 16:18:15 +01:00
`;const LOGICAL_AND2=`
2020-11-16 21:51:46 +01:00
return vec4(
vec4(greaterThanEqual(a, vec4(1.0))) *
vec4(greaterThanEqual(b, vec4(1.0))));
2020-11-17 16:18:15 +01:00
`;const LOGICAL_OR2=`
2020-11-16 21:51:46 +01:00
return min(
vec4(greaterThanEqual(a, vec4(1.0))) +
vec4(greaterThanEqual(b, vec4(1.0))),
vec4(1.0));
2020-11-17 16:18:15 +01:00
`;const MAX2=`
2020-11-16 21:51:46 +01:00
vec4 result = vec4(max(a, b));
vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
2020-11-17 16:18:15 +01:00
`+CHECK_NAN_SNIPPET2+`
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const MIN2=`
2020-11-16 21:51:46 +01:00
vec4 result = vec4(min(a, b));
vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
2020-11-17 16:18:15 +01:00
`+CHECK_NAN_SNIPPET2+`
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const MOD2=`
2020-11-16 21:51:46 +01:00
vec4 result = mod(a, b);
vec4 isNaN = vec4(equal(b, vec4(0.0)));
2020-11-17 16:18:15 +01:00
`+CHECK_NAN_SNIPPET2+`
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;class BinaryOpPackedProgram{constructor(op2,aShape,bShape,checkOutOfBounds=false){this.variableNames=["A","B"];this.supportsBroadcasting=true;this.packedInputs=true;this.packedOutput=true;this.outputShape=backend_util_exports.assertAndGetBroadcastShape(aShape,bShape);const rank=this.outputShape.length;let checkOutOfBoundsString="";if(checkOutOfBounds){if(rank===0||util_exports.sizeFromShape(this.outputShape)===1){checkOutOfBoundsString=`
2020-11-16 21:51:46 +01:00
result.y = 0.;
result.z = 0.;
result.w = 0.;
`}else{const dtype=getCoordsDataType(rank);checkOutOfBoundsString=`
${dtype} coords = getOutputCoords();
`;if(rank===1){checkOutOfBoundsString+=`
result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`}else{const channels=getChannels("coords",rank);checkOutOfBoundsString+=`
bool nextRowOutOfBounds =
(${channels[rank-2]} + 1) >= ${this.outputShape[rank-2]};
bool nextColOutOfBounds =
(${channels[rank-1]} + 1) >= ${this.outputShape[rank-1]};
result.y = nextColOutOfBounds ? 0. : result.y;
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`}}}this.userCode=`
vec4 binaryOperation(vec4 a, vec4 b) {
2020-11-17 16:18:15 +01:00
${op2}
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
vec4 a = getAAtOutCoords();
vec4 b = getBAtOutCoords();
vec4 result = binaryOperation(a, b);
${checkOutOfBoundsString}
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class ClipProgram{constructor(aShape){this.variableNames=["A"];this.outputShape=aShape;this.userCode=`
uniform float minVal;
uniform float maxVal;
void main() {
float value = getAAtOutCoords();
if (isnan(value)) {
setOutput(value);
return;
}
setOutput(clamp(value, minVal, maxVal));
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}getCustomSetupFunc(min8,max10){return(gpgpu,webGLProgram)=>{if(this.minLoc==null){this.minLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"minVal");this.maxLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"maxVal")}gpgpu.gl.uniform1f(this.minLoc,min8);gpgpu.gl.uniform1f(this.maxLoc,max10)}}}class ClipPackedProgram{constructor(aShape){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;this.outputShape=aShape;this.userCode=`
2020-11-16 21:51:46 +01:00
uniform float minVal;
uniform float maxVal;
void main() {
vec4 value = getAAtOutCoords();
if (any(isnan(value))) {
setOutput(value);
return;
}
setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}getCustomSetupFunc(min8,max10){return(gpgpu,webGLProgram)=>{if(this.minLoc==null){this.minLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"minVal");this.maxLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"maxVal")}gpgpu.gl.uniform1f(this.minLoc,min8);gpgpu.gl.uniform1f(this.maxLoc,max10)}}}class ComplexAbsProgram{constructor(shape){this.variableNames=["real","imag"];this.outputShape=shape;this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
float re = abs(getRealAtOutCoords());
float im = abs(getImagAtOutCoords());
float mx = max(re, im);
// sadly the length function in glsl is not underflow-safe
// (at least not on Intel GPUs). So the safe solution is
// to ensure underflow-safety in all cases.
setOutput(
mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv2DDerFilterProgram{constructor(convInfo){this.variableNames=["x","dy"];this.outputShape=convInfo.filterShape;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const isChannelsLast=convInfo.dataFormat==="channelsLast";this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int wR = coords.x;
int wC = coords.y;
int d1 = coords.z;
int d2 = coords.w;
// Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
if (${isChannelsLast}) {
float dyValue = getDy(b, yR, yC, d2);
float xValue = getX(b, xR, xC, d1);
dotProd += (xValue * dyValue);
} else {
float dyValue = getDy(b, d2, yR, yC);
float xValue = getX(b, d1, xR, xC);
dotProd += (xValue * dyValue);
}
}
2020-11-08 18:32:31 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv2DDerInputProgram{constructor(convInfo){this.variableNames=["dy","W"];this.outputShape=convInfo.inShape;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const isChannelsLast=convInfo.dataFormat==="channelsLast";const padTop=filterHeight-1-convInfo.padInfo.top;const padLeft=filterWidth-1-convInfo.padInfo.left;const rowDim=isChannelsLast?1:2;const colDim=isChannelsLast?2:3;const channelDim=isChannelsLast?3:1;this.userCode=`
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d1 = coords[${channelDim}];
ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
int dyRCorner = dyCorner.x;
int dyCCorner = dyCorner.y;
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
if (${isChannelsLast}) {
float xValue = getDy(batch, idyR, idyC, d2);
float wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
} else {
float xValue = getDy(batch, d2, idyR, idyC);
float wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
}
2020-11-12 18:58:55 +01:00
}
2020-11-10 02:13:38 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv3DDerFilterProgram{constructor(convInfo){this.variableNames=["x","dy"];this.outputShape=convInfo.filterShape;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const padFront=convInfo.padInfo.front;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;this.userCode=`
void main() {
ivec5 coords = getOutputCoords();
int wF = coords.x;
int wR = coords.y;
int wC = coords.z;
int d1 = coords.w;
int d2 = coords.u;
float dotProd = 0.0;
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
int xF = wF + yF * ${strideDepth} - ${padFront};
if (xF < 0 || xF >= ${convInfo.inDepth}) {
continue;
}
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float dyValue = getDy(b, yF, yR, yC, d2);
float xValue = getX(b, xF, xR, xC, d1);
dotProd += (xValue * dyValue);
}
2020-11-08 18:32:31 +01:00
}
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv3DDerInputProgram{constructor(convInfo){this.variableNames=["dy","W"];this.outputShape=convInfo.inShape;const filterDepth=convInfo.filterDepth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const padFront=filterDepth-1-convInfo.padInfo.front;const padTop=filterHeight-1-convInfo.padInfo.top;const padLeft=filterWidth-1-convInfo.padInfo.left;this.userCode=`
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int d1 = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyFCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
float dotProd = 0.0;
for (int wF = 0; wF < ${filterDepth}; wF++) {
float dyF = float(dyFCorner + wF) / ${strideDepth}.0;
if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
continue;
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
int idyF = int(dyF);
int wFPerm = ${filterDepth} - 1 - wF;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
float xValue = getDy(batch, idyF, idyR, idyC, d2);
float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
}
}
2020-11-12 18:17:57 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class DepthwiseConv2DDerFilterProgram{constructor(convInfo){this.variableNames=["x","dy"];this.outputShape=convInfo.filterShape;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const channelMul=convInfo.outChannels/convInfo.inChannels;this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int wR = coords.x;
int wC = coords.y;
int d1 = coords.z;
int dm = coords.w;
int d2 = d1 * ${channelMul} + dm;
float dotProd = 0.0;
// TO DO: Vec4 over the batch size
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float dyValue = getDy(b, yR, yC, d2);
float xValue = getX(b, xR, xC, d1);
dotProd += (xValue * dyValue);
}
2020-11-12 18:58:55 +01:00
}
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class DepthwiseConv2DDerInputProgram{constructor(convInfo){this.variableNames=["dy","W"];this.outputShape=convInfo.inShape;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const padTop=filterHeight-1-convInfo.padInfo.top;const padLeft=filterWidth-1-convInfo.padInfo.left;const channelMul=convInfo.outChannels/convInfo.inChannels;this.userCode=`
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d1 = coords[3];
ivec2 dyCorner = coords.yz - pads;
int dyRCorner = dyCorner.x;
int dyCCorner = dyCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
// TO DO: Vec4 over the channelMul
for (int dm = 0; dm < ${channelMul}; dm++) {
int d2 = d1 * ${channelMul} + dm;
float xValue = getDy(batch, idyR, idyC, d2);
float wValue = getW(wRPerm, wCPerm, d1, dm);
dotProd += xValue * wValue;
2020-11-12 18:58:55 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv2DProgram{constructor(convInfo,addBias=false,activation2=null,hasPreluActivationWeights=false){this.variableNames=["x","W"];this.outputShape=convInfo.outShape;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const inputDepthNearestVec4=Math.floor(convInfo.inChannels/4)*4;const inputDepthVec4Remainder=convInfo.inChannels%4;const isChannelsLast=convInfo.dataFormat==="channelsLast";const rowDim=isChannelsLast?1:2;const colDim=isChannelsLast?2:3;const channelDim=isChannelsLast?3:1;let activationSnippet="",applyActivationSnippet="";if(activation2){if(hasPreluActivationWeights){activationSnippet=`float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation2}
}`}else{activationSnippet=`
float activation(float x) {
${activation2}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}applyActivationSnippet=`result = activation(result);`}const addBiasSnippet=addBias?"result += getBiasAtOutCoords();":"";if(addBias){this.variableNames.push("bias")}if(hasPreluActivationWeights){this.variableNames.push("preluActivationWeights")}this.userCode=`
${activationSnippet}
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d2 = coords[${channelDim}];
ivec2 xRCCorner =
ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
vec4 wValues = vec4(
getW(wR, wC, d1, d2),
getW(wR, wC, d1 + 1, d2),
getW(wR, wC, d1 + 2, d2),
getW(wR, wC, d1 + 3, d2)
);
if (${isChannelsLast}) {
vec4 xValues = vec4(
getX(batch, xR, xC, d1),
getX(batch, xR, xC, d1 + 1),
getX(batch, xR, xC, d1 + 2),
getX(batch, xR, xC, d1 + 3)
);
dotProd += dot(xValues, wValues);
} else {
vec4 xValues = vec4(
getX(batch, d1, xR, xC),
getX(batch, d1 + 1, xR, xC),
getX(batch, d1 + 2, xR, xC),
getX(batch, d1 + 3, xR, xC)
);
dotProd += dot(xValues, wValues);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
}
if (${inputDepthVec4Remainder===1}) {
if (${isChannelsLast}) {
dotProd +=
getX(batch, xR, xC, ${inputDepthNearestVec4}) *
getW(wR, wC, ${inputDepthNearestVec4}, d2);
} else {
dotProd +=
getX(batch, ${inputDepthNearestVec4}, xR, xC) *
getW(wR, wC, ${inputDepthNearestVec4}, d2);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
} else if (${inputDepthVec4Remainder===2}) {
vec2 wValues = vec2(
getW(wR, wC, ${inputDepthNearestVec4}, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
);
if (${isChannelsLast}) {
vec2 xValues = vec2(
getX(batch, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
);
dotProd += dot(xValues, wValues);
} else {
vec2 xValues = vec2(
getX(batch, ${inputDepthNearestVec4}, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
);
dotProd += dot(xValues, wValues);
}
} else if (${inputDepthVec4Remainder===3}) {
vec3 wValues = vec3(
getW(wR, wC, ${inputDepthNearestVec4}, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
);
if (${isChannelsLast}) {
vec3 xValues = vec3(
getX(batch, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
);
dotProd += dot(xValues, wValues);
} else {
vec3 xValues = vec3(
getX(batch, ${inputDepthNearestVec4}, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
);
dotProd += dot(xValues, wValues);
}
}
2020-11-08 18:32:31 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Conv3DProgram{constructor(convInfo){this.variableNames=["x","W"];this.outputShape=convInfo.outShape;const padFront=convInfo.padInfo.front;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationDepth=convInfo.dilationDepth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const filterDepth=convInfo.filterDepth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const inputDepthNearestVec4=Math.floor(convInfo.inChannels/4)*4;const inputDepthVec4Remainder=convInfo.inChannels%4;this.userCode=`
const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int d2 = coords.u;
ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xFCorner = xFRCCorner.x;
int xRCorner = xFRCCorner.y;
int xCCorner = xFRCCorner.z;
// Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get
// y(yF, yR, yC, d2). ? = to be determined. : = across all
// values in that axis.
float dotProd = 0.0;
for (int wF = 0; wF < ${filterDepth}; wF++) {
int xF = xFCorner + wF * ${dilationDepth};
if (xF < 0 || xF >= ${convInfo.inDepth}) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
vec4 xValues = vec4(
getX(batch, xF, xR, xC, d1),
getX(batch, xF, xR, xC, d1 + 1),
getX(batch, xF, xR, xC, d1 + 2),
getX(batch, xF, xR, xC, d1 + 3)
);
vec4 wValues = vec4(
getW(wF, wR, wC, d1, d2),
getW(wF, wR, wC, d1 + 1, d2),
getW(wF, wR, wC, d1 + 2, d2),
getW(wF, wR, wC, d1 + 3, d2)
);
dotProd += dot(xValues, wValues);
}
if (${inputDepthVec4Remainder===1}) {
dotProd +=
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
} else if (${inputDepthVec4Remainder===2}) {
vec2 xValues = vec2(
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
);
vec2 wValues = vec2(
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
);
dotProd += dot(xValues, wValues);
} else if (${inputDepthVec4Remainder===3}) {
vec3 xValues = vec3(
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
);
vec3 wValues = vec3(
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
);
dotProd += dot(xValues, wValues);
2020-11-12 18:58:55 +01:00
}
}
2020-11-06 22:21:20 +01:00
}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class DepthwiseConv2DProgram{constructor(convInfo,addBias=false,activation2=null,hasPreluActivation=false){this.variableNames=["x","W"];this.outputShape=convInfo.outShape;const xNumRows=convInfo.inHeight;const xNumCols=convInfo.inWidth;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const channelMul=convInfo.outChannels/convInfo.inChannels;let activationSnippet="",applyActivationSnippet="";if(activation2){if(hasPreluActivation){activationSnippet=`float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation2}
}`}else{activationSnippet=`
float activation(float x) {
${activation2}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}applyActivationSnippet=`result = activation(result);`}const addBiasSnippet=addBias?"result += getBiasAtOutCoords();":"";if(addBias){this.variableNames.push("bias")}if(hasPreluActivation){this.variableNames.push("preluActivationWeights")}this.userCode=`
${activationSnippet}
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords.w;
int d1 = d2 / ${channelMul};
int q = d2 - d1 * ${channelMul};
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
// TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${xNumRows}) {
continue;
}
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${xNumCols}) {
continue;
}
float xVal = getX(batch, xR, xC, d1);
float wVal = getW(wR, wC, d1, q);
dotProd += xVal * wVal;
2020-11-12 18:58:55 +01:00
}
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`}}class DepthwiseConvPacked2DProgram{constructor(convInfo,addBias=false,activation2=null,hasPreluActivation=false){this.variableNames=["x","W"];this.packedInputs=true;this.packedOutput=true;this.outputShape=convInfo.outShape;const xNumRows=convInfo.inHeight;const xNumCols=convInfo.inWidth;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const filterHeight=convInfo.filterHeight;const filterWidth=convInfo.filterWidth;const texelsAcross=filterWidth;let mainLoop=`int xR; int xC; int xCOffset;`;for(let r=0;r<filterHeight;r++){for(let c=0;c<filterWidth;c++){mainLoop+=`
vec4 xTexelR${r}C${c*2} = vec4(0.);
vec4 wR${r}C${c} = vec4(0.);
vec4 xR${r}C${c} = vec4(0.);`}}for(let r=0;r<filterHeight;r++){for(let texelC=0;texelC<texelsAcross;texelC++){const c=texelC*2;mainLoop+=`
xR = xRCorner + ${r*dilationHeight};
xC = xCCorner + ${c*dilationWidth};
`;if(strideWidth===1){if(c<filterWidth){if(padLeft%2===1){mainLoop+=`
xCOffset = xC + 1;
if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
// Need to manually clear unused channels in case
// we're reading from recycled texture.
if(xCOffset + 1 >= ${xNumCols}) {
xTexelR${r}C${c}.zw = vec2(0.);
}
} else {
xTexelR${r}C${c} = vec4(0.);
}
xCOffset = xC + 1 - 2;
if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${xNumCols}) {
vec4 previous = getX(batch, xR, xCOffset, d1);
// Need to manually clear unused channels in case
// we're reading from recycled texture.
if(xCOffset + 1 >= ${xNumCols}) {
previous.zw = vec2(0.);
}
xR${r}C${c} = vec4(previous.zw, xTexelR${r}C${c}.xy);
} else {
xR${r}C${c} = vec4(0, 0, xTexelR${r}C${c}.xy);
}
`}else{mainLoop+=`
if(xR >= 0 && xR < ${xNumRows} && xC >= 0 && xC < ${xNumCols}) {
xTexelR${r}C${c} = getX(batch, xR, xC, d1);
} else {
xTexelR${r}C${c} = vec4(0.);
}
xR${r}C${c} = xTexelR${r}C${c};
2020-11-17 16:18:15 +01:00
`}if(c+1<filterWidth){const nextTexelOffset=padLeft%2===0?util_exports.nearestLargerEven(dilationWidth):dilationWidth;if(dilationWidth%2===0&&padLeft%2===1||dilationWidth%2!==0&&padLeft%2!==1){mainLoop+=`
2020-11-16 21:51:46 +01:00
xCOffset = xC + ${padLeft%2} + ${nextTexelOffset};
if(xR >= 0 && xR < ${xNumRows} &&
xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c+2} = getX(batch, xR, xCOffset, d1);
}
`;if(dilationWidth>1){mainLoop+=`
xCOffset -= 2;
if(xR >= 0 && xR < ${xNumRows} &&
xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
} else {
xTexelR${r}C${c} = vec4(0.);
}
`}mainLoop+=`
xR${r}C${c+1} = vec4(
xTexelR${r}C${c}.zw, xTexelR${r}C${c+2}.xy);
`}else{mainLoop+=`
xCOffset = xC + ${nextTexelOffset};
if(xR >= 0 && xR < ${xNumRows} &&
xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c+2} = getX(batch, xR, xCOffset, d1);
}
xR${r}C${c+1} = xTexelR${r}C${c+2};
`}}}}else{if(c<filterWidth){mainLoop+=`
if(xR >= 0 && xR < ${xNumRows}) {
`;if(padLeft%2===1){mainLoop+=`
xCOffset = xC + 1 - ${strideWidth};
if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);
} else {
xTexelR${r}C${c} = vec4(0.);
}
if(xC + 1 >= 0 && xC + 1 < ${xNumCols}) {
xTexelR${r}C${c+2} = getX(batch, xR, xC + 1, d1);
} else {
xTexelR${r}C${c+2} = vec4(0.);
}
xR${r}C${c} = vec4(
xTexelR${r}C${c}.zw, xTexelR${r}C${c+2}.zw);
`;if(c+1<filterWidth){mainLoop+=`
vec4 final = vec4(0.);
xCOffset = xC + 1 + ${strideWidth};
if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
final = getX(batch, xR, xCOffset, d1);
}
xR${r}C${c+1} = vec4(xTexelR${r}C${c+2}.xy, final.xy);
`}}else{mainLoop+=`
if(xC >= 0 && xC < ${xNumCols}) {
xTexelR${r}C${c} = getX(batch, xR, xC, d1);
} else {
xTexelR${r}C${c} = vec4(0.);
}
xCOffset = xC + ${strideWidth};
if(xCOffset >= 0 && xCOffset < ${xNumCols}) {
xTexelR${r}C${c+2} = getX(batch, xR, xCOffset, d1);
} else {
xTexelR${r}C${c+2} = vec4(0.);
}
xR${r}C${c} = vec4(
xTexelR${r}C${c}.xy, xTexelR${r}C${c+2}.xy);
`;if(c+1<filterWidth){mainLoop+=`
xR${r}C${c+1} = vec4(
xTexelR${r}C${c}.zw, xTexelR${r}C${c+2}.zw);
`}}mainLoop+=`}`}}if(c<filterWidth){mainLoop+=`
vec4 wTexelR${r}C${c} = getW(${r}, ${c}, d1, q);
wR${r}C${c} = vec4(wTexelR${r}C${c}.xz, wTexelR${r}C${c}.xz);
`;if(c+1<filterWidth){mainLoop+=`
vec4 wTexelR${r}C${c+1} = getW(${r}, ${c+1}, d1, q);
wR${r}C${c+1} =
vec4(wTexelR${r}C${c+1}.xz, wTexelR${r}C${c+1}.xz);`}}}}for(let r=0;r<filterHeight;r++){for(let c=0;c<filterWidth;c++){mainLoop+=`dotProd += xR${r}C${c} * wR${r}C${c};`}}let activationSnippet="",applyActivationSnippet="";if(activation2){if(hasPreluActivation){activationSnippet=`vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation2}
}`}else{activationSnippet=`vec4 activation(vec4 x) {
${activation2}
}`}applyActivationSnippet=`result = activation(result);`}const addBiasSnippet=addBias?"result += getBiasAtOutCoords();":"";if(addBias){this.variableNames.push("bias")}if(hasPreluActivation){this.variableNames.push("preluActivationWeights")}this.userCode=`
${activationSnippet}
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords.w;
int d1 = d2;
int q = 0;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
vec4 dotProd = vec4(0.);
${mainLoop}
vec4 result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class CropAndResizeProgram{constructor(imageShape,boxShape,cropSize,method,extrapolationValue){this.variableNames=["Image","Boxes","BoxInd"];this.outputShape=[];const[batch,imageHeight,imageWidth,depth]=imageShape;const[numBoxes]=boxShape;const[cropHeight,cropWidth]=cropSize;this.outputShape=[numBoxes,cropHeight,cropWidth,depth];const methodId=method==="bilinear"?1:0;const[inputHeightFloat,inputWidthFloat]=[`${imageHeight-1}.0`,`${imageWidth-1}.0`];const[heightRatio,heightScale,inY]=cropHeight>1?[`${(imageHeight-1)/(cropHeight-1)}`,"(y2-y1) * height_ratio",`y1*${inputHeightFloat} + float(y)*(height_scale)`]:["0.0","0.0",`0.5 * (y1+y2) * ${inputHeightFloat}`];const[widthRatio,widthScale,inX]=cropWidth>1?[`${(imageWidth-1)/(cropWidth-1)}`,"(x2-x1) * width_ratio",`x1*${inputWidthFloat} + float(x)*(width_scale)`]:["0.0","0.0",`0.5 * (x1+x2) * ${inputWidthFloat}`];this.userCode=`
const float height_ratio = float(${heightRatio});
const float width_ratio = float(${widthRatio});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int y = coords[1];
int x = coords[2];
int d = coords[3];
// get box vals
float y1 = getBoxes(b,0);
float x1 = getBoxes(b,1);
float y2 = getBoxes(b,2);
float x2 = getBoxes(b,3);
// get image in batch index
int bInd = round(getBoxInd(b));
if(bInd < 0 || bInd >= ${batch}) {
2020-11-12 18:58:55 +01:00
return;
}
2020-11-16 21:51:46 +01:00
float height_scale = ${heightScale};
float width_scale = ${widthScale};
float in_y = ${inY};
if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
setOutput(float(${extrapolationValue}));
return;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
float in_x = ${inX};
if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
setOutput(float(${extrapolationValue}));
return;
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
vec2 sourceFracIndexCR = vec2(in_x,in_y);
if(${methodId} == 1) {
// Compute the four integer indices.
ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));
float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);
vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);
float top = topLeft + (topRight - topLeft) * fracCR.x;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
float newValue = top + (bottom - top) * fracCR.y;
setOutput(newValue);
} else {
// Compute the coordinators of nearest neighbor point.
ivec2 sourceNearestCR = ivec2(floor(
sourceFracIndexCR + vec2(0.5,0.5)));
float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
setOutput(newValue);
2020-11-12 18:58:55 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}}class CumSumProgram{constructor(shape,exclusive,reverse12){this.variableNames=["x"];this.outputShape=shape;const rank=shape.length;const val=exclusive?"0.0":`getX(${getCoords(rank,"coords")})`;const length=shape[shape.length-1];let condition="";let idxString="";if(exclusive){condition=reverse12?`end != ${length-1}`:"end != 0";idxString=reverse12?"end + 1":"end - 1"}else{condition=reverse12?`end + pow2 < ${length}`:"end >= pow2";idxString=reverse12?"end + pow2":"end - pow2"}this.userCode=`
2020-11-16 21:51:46 +01:00
uniform float index;
void main() {
${getCoordsDataType(rank)} coords = getOutputCoords();
int end = ${getFinalCoord(rank,"coords")};
float val = ${val};
int pow2 = int(pow(2.0, index));
if (${condition}) {
int idx = ${idxString};
${getFinalCoord(rank,"coords")} = idx;
val += getX(${getCoords(rank,"coords")});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(val);
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}getCustomSetupFunc(index){return(gpgpu,webGLProgram)=>{if(this.index==null){this.index=gpgpu.getUniformLocation(webGLProgram,"index")}gpgpu.gl.uniform1f(this.index,index)}}}function getCoords(rank,name){if(rank===1){return`${name}`}else if(rank===2){return`${name}.x, ${name}.y`}else if(rank===3){return`${name}.x, ${name}.y, ${name}.z`}else if(rank===4){return`${name}.x, ${name}.y, ${name}.z, ${name}.w`}else{throw Error(`Cumulative sum for rank ${rank} is not yet supported`)}}function getFinalCoord(rank,name){if(rank===1){return`${name}`}else if(rank===2){return`${name}.y`}else if(rank===3){return`${name}.z`}else if(rank===4){return`${name}.w`}else{throw Error(`Cumulative sum for rank ${rank} is not yet supported`)}}class DecodeMatrixProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=false;this.packedOutput=true;this.outPackingScheme=PackingScheme.DENSE;const texShape=getDenseTexShape(outputShape);const glsl=getGlslDifferences();this.outputShape=outputShape;this.userCode=`
2020-11-16 21:51:46 +01:00
ivec3 outCoordsFromFlatIndex(int index) {
${getLogicalCoordinatesFromFlatIndex(["r","c","d"],outputShape)}
return ivec3(r, c, d);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);
vec4 result = vec4(0.);
for (int i=0; i<4; i++) {
int flatIndex = index + i;
ivec3 rc = outCoordsFromFlatIndex(flatIndex);
result[i] = getA(rc.x, rc.y, rc.z);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
${glsl.output} = result;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class DecodeMatrixPackedProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;this.outPackingScheme=PackingScheme.DENSE;const texShape=getDenseTexShape(outputShape);const glsl=getGlslDifferences();this.outputShape=outputShape;this.userCode=`
ivec3 outCoordsFromFlatIndex(int index) {
${getLogicalCoordinatesFromFlatIndex(["r","c","d"],outputShape)}
return ivec3(r, c, d);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);
vec4 result = vec4(0.);
for (int i=0; i<4; i++) {
int flatIndex = index + i;
ivec3 rc = outCoordsFromFlatIndex(flatIndex);
result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
${glsl.output} = result;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class DepthToSpaceProgram{constructor(outputShape,blockSize,dataFormat){this.variableNames=["x"];this.outputShape=[];this.outputShape=outputShape;this.blockSize=blockSize;this.dataFormat=dataFormat;this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int h = ${this.getHeightCoordString()};
int w = ${this.getWidthCoordString()};
int d = ${this.getDepthCoordString()};
int in_h = h / ${blockSize};
int offset_h = imod(h, ${blockSize});
int in_w = w / ${blockSize};
int offset_w = imod(w, ${blockSize});
int offset_d = (offset_h * ${blockSize} + offset_w) *
${this.getOutputDepthSize()};
int in_d = d + offset_d;
float result = ${this.getInputSamplingString()};
setOutput(result);
}
`}getHeightCoordString(){if(this.dataFormat==="NHWC"){return`coords[1]`}else{return`coords[2]`}}getWidthCoordString(){if(this.dataFormat==="NHWC"){return`coords[2]`}else{return`coords[3]`}}getDepthCoordString(){if(this.dataFormat==="NHWC"){return`coords[3]`}else{return`coords[1]`}}getOutputDepthSize(){if(this.dataFormat==="NHWC"){return this.outputShape[3]}else{return this.outputShape[1]}}getInputSamplingString(){if(this.dataFormat==="NHWC"){return`getX(b, in_h, in_w, in_d)`}else{return`getX(b, in_d, in_h, in_w)`}}}class DiagProgram{constructor(size){this.variableNames=["X"];this.outputShape=[size,size];this.userCode=`
void main() {
ivec2 coords = getOutputCoords();
float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
setOutput(val);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class EncodeFloatProgram{constructor(outputShape){this.variableNames=["A"];this.outTexUsage=TextureUsage.DOWNLOAD;const glsl=getGlslDifferences();this.outputShape=outputShape;this.userCode=`
${ENCODE_FLOAT_SNIPPET}
void main() {
float x = getAAtOutCoords();
${glsl.output} = encode_float(x);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class EncodeFloatPackedProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=false;this.outTexUsage=TextureUsage.DOWNLOAD;const glsl=getGlslDifferences();this.outputShape=outputShape;this.userCode=`
${ENCODE_FLOAT_SNIPPET}
void main() {
ivec3 coords = getOutputCoords();
float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
${glsl.output} = encode_float(x);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class EncodeMatrixProgram{constructor(outputShape,texShape,inputIsUnsignedByte=false){this.variableNames=["A"];const glsl=getGlslDifferences();const[height,width]=texShape;this.outputShape=outputShape;let output=`result`;if(inputIsUnsignedByte){output=`floor(result * 255. + 0.5)`}this.userCode=`
${getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 coords = getOutputCoords();
int flatIndex = getFlatIndex(coords);
int offset = imod(flatIndex, 4);
flatIndex = idiv(flatIndex, 4, 1.);
int r = flatIndex / ${width};
int c = imod(flatIndex, ${width});
vec2 uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);
vec4 values = ${glsl.texture2D}(A, uv);
float result;
if(offset == 0) {
result = values[0];
} else if(offset == 1) {
result = values[1];
} else if(offset == 2) {
result = values[2];
} else {
result = values[3];
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
${glsl.output} = vec4(${output}, 0., 0., 0.);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class EncodeMatrixPackedProgram{constructor(outputShape,texShape,inputIsUnsignedByte=false){this.variableNames=["A"];this.packedInputs=false;this.packedOutput=true;const glsl=getGlslDifferences();const[height,width]=texShape;this.outputShape=outputShape;let mainLoop="";let output="result";if(inputIsUnsignedByte){output="floor(result * 255. + 0.5)"}for(let row=0;row<=1;row++){for(let col=0;col<=1;col++){const channel=row*2+col;mainLoop+=`
localCoords = coords;
if(localCoords[2] + ${col} < ${outputShape[2]}) {
localCoords[2] += ${col};
if(localCoords[1] + ${row} < ${outputShape[1]}) {
localCoords[1] += ${row};
flatIndex = getFlatIndex(localCoords);
offset = imod(flatIndex, 4);
flatIndex = idiv(flatIndex, 4, 1.);
r = flatIndex / ${width};
c = imod(flatIndex, ${width});
uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);
values = ${glsl.texture2D}(A, uv);
if(offset == 0) {
result[${channel}] = values[0];
} else if(offset == 1) {
result[${channel}] = values[1];
} else if(offset == 2) {
result[${channel}] = values[2];
} else {
result[${channel}] = values[3];
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}this.userCode=`
${getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 coords = getOutputCoords();
vec4 result = vec4(0.);
int flatIndex, r, c, offset;
ivec3 localCoords;
vec2 uv;
vec4 values;
${mainLoop}
${glsl.output} = ${output};
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class FillProgram{constructor(shape,value){this.outputShape=[];this.variableNames=["x"];this.outputShape=shape;this.userCode=`
uniform float value;
void main() {
// Input can be obtained from uniform value.
setOutput(value);
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}getCustomSetupFunc(value){return(gpgpu,webGLProgram)=>{if(this.valueLoc==null){this.valueLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"value")}gpgpu.gl.uniform1f(this.valueLoc,value)}}}class GatherProgram{constructor(aShape,indicesLength,axis){this.variableNames=["A","indices"];const outputShape=aShape.slice();outputShape[axis]=indicesLength;this.outputShape=outputShape;this.rank=outputShape.length;const dtype=getCoordsDataType(this.rank);const sourceCoords=getSourceCoords2(aShape,axis);this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
${dtype} resRC = getOutputCoords();
setOutput(getA(${sourceCoords}));
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}}function getSourceCoords2(aShape,axis){const rank=aShape.length;if(rank>4){throw Error(`Gather for rank ${rank} is not yet supported`)}if(rank===1){return`int(getIndices(resRC))`}const currentCoords=["resRC.x","resRC.y","resRC.z","resRC.w"];const sourceCoords=[];for(let i=0;i<aShape.length;i++){if(i===axis){sourceCoords.push(`int(getIndices(${currentCoords[i]}))`)}else{sourceCoords.push(`${currentCoords[i]}`)}}return sourceCoords.join()}class GatherNDProgram{constructor(sliceDim,strides,shape){this.sliceDim=sliceDim;this.strides=strides;this.variableNames=["x","indices"];this.outputShape=shape;const stridesType=getCoordsDataType(strides.length);const dtype=getCoordsDataType(shape.length);const strideString=this.sliceDim>1?"strides[j]":"strides";this.userCode=`
2020-11-16 21:51:46 +01:00
${stridesType} strides = ${stridesType}(${this.strides});
void main() {
${dtype} coords = getOutputCoords();
int flattenIndex = 0;
for (int j = 0; j < ${this.sliceDim}; j++) {
int index = round(getIndices(coords[0], j));
flattenIndex += index * ${strideString};
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(getX(flattenIndex, coords[1]));
2020-11-12 18:58:55 +01:00
}
2020-11-17 16:18:15 +01:00
`}}function createVertexShader2(gl){const glsl=getGlslDifferences();const vertexShaderSource=`${glsl.version}
2020-11-16 21:51:46 +01:00
precision highp float;
${glsl.attribute} vec3 clipSpacePos;
${glsl.attribute} vec2 uv;
${glsl.varyingVs} vec2 resultUV;
void main() {
gl_Position = vec4(clipSpacePos, 1);
resultUV = uv;
2020-11-17 16:18:15 +01:00
}`;return createVertexShader(gl,vertexShaderSource)}function createVertexBuffer(gl){const vertexArray=new Float32Array([-1,1,0,0,1,-1,-1,0,0,0,1,1,0,1,1,1,-1,0,1,0]);return createStaticVertexBuffer(gl,vertexArray)}function createIndexBuffer(gl){const triangleVertexIndices=new Uint16Array([0,1,2,2,1,3]);return createStaticIndexBuffer(gl,triangleVertexIndices)}function createAndConfigureTexture(gl,width,height,internalFormat,textureFormat,textureType){validateTextureSize(width,height);const texture=createTexture(gl);const tex2d=gl.TEXTURE_2D;callAndCheck(gl,()=>gl.bindTexture(tex2d,texture));callAndCheck(gl,()=>gl.texParameteri(tex2d,gl.TEXTURE_WRAP_S,gl.CLAMP_TO_EDGE));callAndCheck(gl,()=>gl.texParameteri(tex2d,gl.TEXTURE_WRAP_T,gl.CLAMP_TO_EDGE));callAndCheck(gl,()=>gl.texParameteri(tex2d,gl.TEXTURE_MIN_FILTER,gl.NEAREST));callAndCheck(gl,()=>gl.texParameteri(tex2d,gl.TEXTURE_MAG_FILTER,gl.NEAREST));callAndCheck(gl,()=>gl.texImage2D(tex2d,0,internalFormat,width,height,0,textureFormat,textureType,null));callAndCheck(gl,()=>gl.bindTexture(gl.TEXTURE_2D,null));return texture}function getInternalFormatForFloat32MatrixTexture(textureConfig){return textureConfig.internalFormatFloat}function createFloat32MatrixTexture(gl,rows,columns,textureConfig){const[width,height]=getUnpackedMatrixTextureShapeWidthHeight(rows,columns);return createAndConfigureTexture(gl,width,height,getInternalFormatForFloat32MatrixTexture(textureConfig),textureConfig.textureFormatFloat,gl.FLOAT)}function getInternalFormatForFloat16MatrixTexture(textureConfig){return textureConfig.internalFormatHalfFloat}function createFloat16MatrixTexture(gl,rows,columns,textureConfig){const[width,height]=getUnpackedMatrixTextureShapeWidthHeight(rows,columns);return createAndConfigureTexture(gl,width,height,getInternalFormatForFloat16MatrixTexture(textureConfig),textureConfig.textureFormatFloat,textureConfig.textureTypeHalfFloat)}function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig){return 
textureConfig.downloadTextureFormat}function createUnsignedBytesMatrixTexture(gl,rows,columns,textureConfig){const[width,height]=getUnpackedMatrixTextureShapeWidthHeight(rows,columns);return createAndConfigureTexture(gl,width,height,getInternalFormatForUnsignedBytesMatrixTexture(textureConfig),gl.RGBA,gl.UNSIGNED_BYTE)}function getInternalFormatForPackedMatrixTexture(textureConfig){return textureConfig.internalFormatPackedFloat}function createPackedMatrixTexture(gl,rows,columns,textureConfig){const[width,height]=getPackedMatrixTextureShapeWidthHeight(rows,columns);return createAndConfigureTexture(gl,width,height,getInternalFormatForPackedMatrixTexture(textureConfig),gl.RGBA,gl.FLOAT)}function getInternalFormatForFloat16PackedMatrixTexture(textureConfig){return textureConfig.internalFormatPackedHalfFloat}function createFloat16PackedMatrixTexture(gl,rows,columns,textureConfig){const[width,height]=getPackedMatrixTextureShapeWidthHeight(rows,columns);return createAndConfigureTexture(gl,width,height,getInternalFormatForFloat16PackedMatrixTexture(textureConfig),gl.RGBA,textureConfig.textureTypeHalfFloat)}function bindVertexProgramAttributeStreams(gl,program,vertexBuffer){const posOffset=0;const uvOffset=3*4;const stride=3*4+2*4;callAndCheck(gl,()=>gl.bindBuffer(gl.ARRAY_BUFFER,vertexBuffer));const success=bindVertexBufferToProgramAttribute(gl,program,"clipSpacePos",vertexBuffer,3,stride,posOffset);return success&&bindVertexBufferToProgramAttribute(gl,program,"uv",vertexBuffer,2,stride,uvOffset)}function uploadDenseMatrixToTexture(gl,texture,width,height,data2,textureConfig){callAndCheck(gl,()=>gl.bindTexture(gl.TEXTURE_2D,texture));let dataForUpload,texelDataType,internalFormat;if(data2 instanceof Uint8Array){dataForUpload=new Uint8Array(width*height*4);texelDataType=gl.UNSIGNED_BYTE;internalFormat=gl.RGBA}else{dataForUpload=new 
Float32Array(width*height*4);texelDataType=gl.FLOAT;internalFormat=textureConfig.internalFormatPackedFloat}dataForUpload.set(data2);callAndCheck(gl,()=>gl.texImage2D(gl.TEXTURE_2D,0,internalFormat,width,height,0,gl.RGBA,texelDataType,dataForUpload
2020-11-16 21:51:46 +01:00
blockIndex = rc.y + ${col};
pos = rc.x + ${row};
if(blockIndex < ${outputShape[1]} && pos < ${outputShape[0]}) {
offsetY = int(blockIndex / (${outWidth})) * ${strideHeight} - ${top};
d0 = offsetY + ${dilationHeight} * (pos / ${itemsPerBlockRow});
if(d0 < ${inputShape[rowDim]} && d0 >= 0) {
offsetX = int(mod(float(blockIndex), ${outWidth}.) * ${strideWidth}. - ${left}.);
d1 = offsetX + ${dilationWidth} * (int(mod(float(pos), ${itemsPerBlockRow}.) / ${inChannels}.));
if(d1 < ${inputShape[colDim]} && d1 >= 0) {
ch = int(mod(float(pos), ${inChannels}.));
if (${isChannelsLast}) {
innerDims = vec2(d1, ch);
result[${row*2+col}] = getChannel(
getA(d0, int(innerDims.x),
int(innerDims.y)), innerDims);
} else {
innerDims = vec2(d0, d1);
result[${row*2+col}] = getChannel(
getA(ch, int(innerDims.x),
int(innerDims.y)), innerDims);
}
2020-11-12 18:58:55 +01:00
}
}
}
2020-11-16 21:51:46 +01:00
`}}this.userCode=`
void main() {
ivec2 rc = getOutputCoords();
vec4 result = vec4(0);
int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
vec2 innerDims;
${unrolled}
${glsl.output} = result;
}
`}}class LRNProgram{constructor(xShape,radius,bias,alpha,beta){this.variableNames=["x"];this.outputShape=[];const rad=radius;const maxD=xShape[3]-1;this.outputShape=xShape;let powOperator;const basis=`float(${bias}) + float(${alpha}) * sum`;if(beta===.5){powOperator=`inversesqrt(${basis})`}else if(beta===1){powOperator=`1.0/(${basis})`}else{powOperator=`exp(log(${basis}) * float(-${beta}));`}this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int r = coords[1];
int c = coords[2];
int d = coords[3];
float x = getX(b, r, c, d);
float sum = 0.0;
for (int j = -${rad}; j <= ${rad}; j++) {
int idx = d + j;
if (idx >= 0 && idx <= ${maxD}) {
float z = getX(b, r, c, idx);
sum += z * z;
2020-11-12 18:17:57 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
float val = x * ${powOperator};
setOutput(val);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class LRNGradProgram{constructor(inputShape,depthRadius,bias,alpha,beta){this.variableNames=["inputImage","outputImage","dy"];this.outputShape=[];this.outputShape=inputShape;this.depth=inputShape[3];this.depthRadius=depthRadius;this.bias=bias;this.alpha=alpha;this.beta=beta;this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int r = coords[1];
int c = coords[2];
float result = 0.0;
for (int d = 0; d < ${this.depth}; ++d) {
int depthBegin = int(max(0.0, float(d - ${depthRadius})));
int depthEnd = int(min(float(${this.depth}),
float(d + ${depthRadius} + 1)));
const int MIN_DEPTH_BEGIN = 0;
const int MAX_DEPTH_END = ${this.depth};
float norm = 0.0;
for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
if (k < depthBegin){
continue;
}
else if (k >= depthBegin && k < depthEnd) {
norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
}
else {
break;
}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
norm = float(${alpha}) * norm + float(${bias});
for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
if (k < depthBegin){
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
else if (k >= depthBegin && k < depthEnd){
float dyi = -2.0 * float(${alpha})
* float(${beta})
* getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)
/ norm;
if (k == d) {
dyi += pow(norm, -1.0 * ${beta});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
if (k == coords[3]) {
dyi *= getDy(b, r, c, d);
result += dyi;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
}
else {
break;
}
2020-11-08 15:56:02 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class LRNPackedProgram{constructor(xShape,radius,bias,alpha,beta){this.variableNames=["x"];this.outputShape=[];this.packedInputs=true;this.packedOutput=true;const rad=radius;const maxD=xShape[3]-1;this.outputShape=xShape;let powOperator;const basis=`float(${bias}) + float(${alpha}) * sum`;if(beta===.5){powOperator=`inversesqrt(${basis})`}else if(beta===1){powOperator=`1.0/(${basis})`}else{powOperator=`exp(log(${basis}) * float(-${beta}));`}this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords.x;
int r = coords.y;
int c = coords.z;
int d = coords.w;
bool hasNextCol = d < ${this.outputShape[3]};
bool hasNextRow = c < ${this.outputShape[2]};
vec4 sum = vec4(0.);
vec4 xFragAtOutputCoords = getX(b, r, c, d);
vec4 xAtOutputCoords = vec4(
getChannel(xFragAtOutputCoords, vec2(c, d)),
hasNextCol ?
getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
hasNextRow ?
getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
(hasNextRow && hasNextCol) ?
getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
);
int firstChannel = d - ${rad};
vec2 cache = vec2(0.);
if(firstChannel >= 0){
vec4 firstChannelFrag = getX(b, r, c, firstChannel);
cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
if(hasNextRow){
cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
ivec2 depth = ivec2(d, d + 1);
for (int j = - ${rad}; j <= ${rad}; j++) {
ivec2 idx = depth + j;
bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));
bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;
if(depthInRange || depthPlusOneInRange){
vec4 z = vec4(0.);
vec4 xFragAtCurrentDepth;
z.xz = cache.xy;
if(depthPlusOneInRange && hasNextCol){
xFragAtCurrentDepth = idx.y != d ?
getX(b, r, c, idx.y) : xFragAtOutputCoords;
z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
if(hasNextRow){
z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
}
}
cache.xy = z.yw;
sum += z * z;
2020-11-12 18:17:57 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
vec4 result = xAtOutputCoords * ${powOperator};
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class MaxPool2DBackpropProgram{constructor(convInfo){this.variableNames=["dy","maxPos"];this.outputShape=convInfo.inShape;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padTop=effectiveFilterHeight-1-convInfo.padInfo.top;const padLeft=effectiveFilterWidth-1-convInfo.padInfo.left;const lastIndex=effectiveFilterHeight*effectiveFilterWidth-1;this.userCode=`
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 dyRCCorner = coords.yz - pads;
int dyRCorner = dyRCCorner.x;
int dyCCorner = dyRCCorner.y;
// Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(b, idyR, idyC, d);
int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));
// Get the current value, check it against the value from the
// position matrix.
int curPosValue = wR * ${effectiveFilterWidth} + wC;
float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
dotProd += dyValue * mask;
2020-11-12 18:58:55 +01:00
}
2020-10-15 15:43:16 +02:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class MaxPool3DBackpropProgram{constructor(convInfo){this.variableNames=["dy","maxPos"];this.outputShape=convInfo.inShape;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationDepth=convInfo.dilationDepth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const effectiveFilterDepth=convInfo.effectiveFilterDepth;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padFront=effectiveFilterDepth-1-convInfo.padInfo.front;const padTop=effectiveFilterHeight-1-convInfo.padInfo.top;const padLeft=effectiveFilterWidth-1-convInfo.padInfo.left;const lastIndex=effectiveFilterDepth*effectiveFilterHeight*effectiveFilterWidth-1;this.userCode=`
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyDCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
// Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get
// dx(xD, xR, xC, ch).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyD = int(dyD);
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(batch, idyD, idyR, idyC, ch);
int maxPosValue = ${lastIndex} -
int(getMaxPos(batch, idyD, idyR, idyC, ch));
// Get the current value, check it against the value from the
// position matrix.
int curPosValue =
wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
wR * ${effectiveFilterWidth} + wC;
float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
dotProd += dyValue * mask;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(dotProd);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class MatMulPackedProgram{constructor(aShape,bShape,outputShape,transposeA=false,transposeB=false,addBias=false,activation2=null,hasPreluActivation=false){this.variableNames=["matrixA","matrixB"];this.packedInputs=true;this.packedOutput=true;this.outputShape=outputShape;const sharedDim=transposeA?aShape[1]:aShape[2];const sharedDimensionPacked=Math.ceil(sharedDim/2);const aSample=transposeA?"i * 2, rc.y":"rc.y, i * 2";const bSample=transposeB?"rc.z, i * 2":"i * 2, rc.z";const aSwizzle=transposeA?["a.xxyy","a.zzww"]:["a.xxzz","a.yyww"];const bSwizzle=transposeB?["b.xzxz","b.ywyw"]:["b.xyxy","b.zwzw"];let activationSnippet="",applyActivationSnippet="";if(activation2){if(hasPreluActivation){activationSnippet=`vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation2}
}`}else{activationSnippet=`vec4 activation(vec4 x) {
${activation2}
}`}applyActivationSnippet=`result = activation(result);`}const addBiasSnippet=addBias?"result += getBiasAtOutCoords();":"";if(addBias){this.variableNames.push("bias")}if(hasPreluActivation){this.variableNames.push("preluActivationWeights")}let batchASnippet="rc.x";let batchBSnippet="rc.x";if(aShape[0]<bShape[0]){batchASnippet=`int(min(float(rc.x), ${aShape[0]-1}.))`}else if(bShape[0]<aShape[0]){batchBSnippet=`int(min(float(rc.x), ${bShape[0]-1}.))`}this.userCode=`
${activationSnippet}
const float sharedDimension = ${sharedDimensionPacked}.0;
vec4 dot2x2ARowBCol(ivec3 rc) {
vec4 result = vec4(0);
for (int i = 0; i < ${sharedDimensionPacked}; i++) {
int batchA = ${batchASnippet};
int batchB = ${batchBSnippet};
vec4 a = getMatrixA(batchA, ${aSample});
vec4 b = getMatrixB(batchB, ${bSample});
// These swizzled products need to be separately added.
// See: https://github.com/tensorflow/tfjs/issues/1735
result += (${aSwizzle[0]} * ${bSwizzle[0]});
result += (${aSwizzle[1]} * ${bSwizzle[1]});
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
return result;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec3 rc = getOutputCoords();
vec4 result = dot2x2ARowBCol(rc);
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class MultinomialProgram{constructor(batchSize,numOutcomes,numSamples){this.variableNames=["probs"];this.outputShape=[batchSize,numSamples];this.userCode=`
uniform float seed;
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
float r = random(seed);
float cdf = 0.0;
for (int i = 0; i < ${numOutcomes-1}; i++) {
cdf += getProbs(batch, i);
if (r < cdf) {
setOutput(float(i));
return;
2020-11-08 18:32:31 +01:00
}
}
2020-11-16 21:51:46 +01:00
// If no other event happened, last event happened.
setOutput(float(${numOutcomes-1}));
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}getCustomSetupFunc(seed){return(gpgpu,webGLProgram)=>{if(this.seedLoc==null){this.seedLoc=gpgpu.getUniformLocation(webGLProgram,"seed")}gpgpu.gl.uniform1f(this.seedLoc,seed)}}}class OneHotProgram{constructor(numIndices,depth,onValue,offValue){this.variableNames=["indices"];this.outputShape=[numIndices,depth];this.userCode=`
void main() {
ivec2 coords = getOutputCoords();
int index = round(getIndices(coords.x));
setOutput(mix(float(${offValue}), float(${onValue}),
float(index == coords.y)));
2020-10-23 00:50:09 +02:00
}
2020-11-16 21:51:46 +01:00
`}}class PackProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=false;this.packedOutput=true;this.outputShape=outputShape;const rank=outputShape.length;if(rank===0){this.userCode=`
void main() {
setOutput(vec4(getA(), 0., 0., 0.));
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}else{const channels=getChannels("rc",rank);const dtype=getCoordsDataType(rank);const outOfBoundsCondition=getOutOfBoundsCondition(rank,outputShape,channels);const setup38=getSetup(rank,outputShape[outputShape.length-1],outputShape[outputShape.length-2],channels);const output=getOutput(outputShape,channels);this.userCode=`
void main() {
${dtype} rc = getOutputCoords();
if(${outOfBoundsCondition}) {
setOutput(vec4(0));
2020-11-12 18:58:55 +01:00
} else {
2020-11-16 21:51:46 +01:00
${setup38}
setOutput(vec4(${output}));
2020-11-10 02:13:38 +01:00
}
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}}}function getSourceCoordsArr(rank,dims){const coords2=[];for(let row=0;row<=1;row++){for(let col=0;col<=1;col++){let coord=`${row===0?"r":"rp1"}, ${col===0?"c":"cp1"}`;for(let d=2;d<rank;d++){coord=`${dims[dims.length-1-d]},`+coord}coords2.push(coord)}}return coords2}function getOutOfBoundsCondition(rank,shape,dims){if(rank===1){return`rc > ${shape[0]}`}let cond="";for(let i=rank-2;i<rank;i++){cond+=`${dims[i]} >= ${shape[i]}`;if(i<rank-1){cond+="||"}}return cond}function getSetup(rank,cols,rows,dims){if(rank===1){return""}const innerDims=dims.slice(-2);return`
int r = ${innerDims[0]};
int c = ${innerDims[1]};
int rp1 = r + 1;
int cp1 = c + 1;
bool cEdge = cp1 >= ${cols};
bool rEdge = rp1 >= ${rows};
`}function getOutput(shape,dims){const rank=shape.length;const sourceCoords=getSourceCoordsArr(rank,dims);if(rank===1){return`getA(rc),
rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),
0, 0`}return`getA(${sourceCoords[0]}),
cEdge ? 0. : getA(${sourceCoords[1]}),
rEdge ? 0. : getA(${sourceCoords[2]}),
rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`}class PadProgram{constructor(xShape,paddings,constantValue){this.variableNames=["x"];this.outputShape=paddings.map((p2,i)=>p2[0]+xShape[i]+p2[1]);const rank=xShape.length;const type=getCoordsDataType(rank);const start=paddings.map(p2=>p2[0]).join(",");const end=paddings.map((p2,i)=>p2[0]+xShape[i]).join(",");const unpackedCoords=["coords[0]","coords[1]","coords[2]","coords[3]"].slice(0,rank);if(rank===1){this.userCode=`
int start = ${start};
int end = ${end};
void main() {
int outC = getOutputCoords();
if (outC < start || outC >= end) {
setOutput(float(${constantValue}));
} else {
setOutput(getX(outC - start));
2020-11-12 18:58:55 +01:00
}
}
2020-11-16 21:51:46 +01:00
`;return}this.userCode=`
${type} start = ${type}(${start});
${type} end = ${type}(${end});
void main() {
${type} outC = getOutputCoords();
if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
setOutput(float(${constantValue}));
} else {
${type} coords = outC - start;
setOutput(getX(${unpackedCoords}));
2020-11-12 18:58:55 +01:00
}
}
2020-11-16 21:51:46 +01:00
`}}class PadPackedProgram{constructor(xShape,paddings,constantValue){this.variableNames=["x"];this.packedInputs=true;this.packedOutput=true;this.outputShape=paddings.map((p2,i)=>p2[0]+xShape[i]+p2[1]);const rank=xShape.length;const dtype=getCoordsDataType(rank);const start=paddings.map(p2=>p2[0]).join(",");const end=paddings.map((p2,i)=>p2[0]+xShape[i]).join(",");const coords2=getChannels("rc",rank);const source=getChannels("source",rank);const cLimit=`${coords2[rank-1]} < ${this.outputShape[rank-1]}`;const innerDims=rank===1?"source":`vec2(${source.slice(-2).join()})`;const componentSetup=[`${dtype} rc = outputLoc;`,`${coords2[rank-1]} += 1;
if(${cLimit}) {
`,rank===1?"":`}
rc = outputLoc;
${coords2[rank-2]} += 1;
if(${coords2[rank-2]} < ${this.outputShape[rank-2]}) {`,rank===1?"":` ${coords2[rank-1]} += 1;
if(${cLimit}) {`];const paddingArea=rank===1?"rc < start || rc >= end":"any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))";let mainLoop="";for(let i=0,j=rank===1?2:4;i<j;i++){mainLoop+=`
${componentSetup[i]}
if (${paddingArea}) {
result[${i}] = float(${constantValue});
2020-11-10 02:13:38 +01:00
} else {
2020-11-16 21:51:46 +01:00
${dtype} source = rc - start;
result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}mainLoop+=rank===1?`} `:`}}`;this.userCode=`
const ${dtype} start = ${dtype}(${start});
const ${dtype} end = ${dtype}(${end});
void main() {
${dtype} outputLoc = getOutputCoords();
vec4 result = vec4(0.);
${mainLoop}
setOutput(result);
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Pool2DProgram{constructor(convInfo,poolType,computePositions,flattenPositions=false,includeBatchInIndex=false){this.variableNames=["x"];if(poolType==="avg"&&computePositions){throw new Error("Cannot compute positions for average pool.")}const filterWidth=convInfo.filterWidth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;this.outputShape=convInfo.outShape;const isAvgPool=poolType==="avg";const batchFlattenPositionStr=`((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;const flattenPositionStr=`(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;let initializationValue="0.0";if(!isAvgPool){initializationValue="-1.0 / 1e-20"}if(computePositions){const compareOp2=">=";this.userCode=`
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
ivec2 xRCCorner = coords.yz * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// max/min x(?, ?, d) to get y(yR, yC, d).
// ? = to be determined
float minMaxValue = 0.0;
float minMaxValueFound = 0.0;
int minMaxPosition = 0;
float avgValue = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
int xC = xCCorner + wC;
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float value = getX(batch, xR, xC, d);
// If a min / max value has already been found, use it. If not,
// use the current value.
float currMinMaxValue = mix(
value, minMaxValue, minMaxValueFound);
if (value ${compareOp2} currMinMaxValue) {
minMaxValue = value;
minMaxValueFound = 1.0;
minMaxPosition = ${flattenPositions?includeBatchInIndex?batchFlattenPositionStr:flattenPositionStr:`wR * ${effectiveFilterWidth} + wC`};
}
}
}
setOutput(float(minMaxPosition));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`;return}const compareOp="max";let returnValue=`${poolType}(${poolType}(${poolType}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;if(poolType==="avg"){returnValue=`avgValue / count`}const filterWidthNearestVec4=Math.floor(filterWidth/4)*4;const filterWidthVec4Remainder=filterWidth%4;const updateSnippet=`
if (${isAvgPool}) {
avgValue += dot(values, ones);
2020-11-12 18:58:55 +01:00
} else {
2020-11-16 21:51:46 +01:00
minMaxValue = ${compareOp}(values, minMaxValue);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`;this.userCode=`
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float count = 0.0;
float getValue(int batch, int xR, int xC, int d) {
if (xC < 0 || xC >= ${convInfo.inWidth}) {
return initializationValue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
count += 1.0;
return getX(batch, xR, xC, d);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
ivec2 xRCCorner = coords.yz * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// max/min x(?, ?, d) to get y(yR, yC, d).
// ? = to be determined
vec4 minMaxValue = vec4(${initializationValue});
float avgValue = 0.0;
count = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
int xC = xCCorner + wC * ${dilationWidth};
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
);
${updateSnippet}
}
int xC = xCCorner + ${filterWidthNearestVec4};
if (${filterWidthVec4Remainder===1}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder===2}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder===3}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
initializationValue
);
${updateSnippet}
2020-11-10 02:13:38 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(${returnValue});
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class Pool3DProgram{constructor(convInfo,poolType,computePositions,flattenPositions=false,includeBatchInIndex=false){this.variableNames=["x"];if(poolType==="avg"&&computePositions){throw new Error("Cannot compute positions for average pool.")}const filterWidth=convInfo.filterWidth;const strideDepth=convInfo.strideDepth;const strideHeight=convInfo.strideHeight;const strideWidth=convInfo.strideWidth;const dilationDepth=convInfo.dilationDepth;const dilationHeight=convInfo.dilationHeight;const dilationWidth=convInfo.dilationWidth;const effectiveFilterDepth=convInfo.effectiveFilterDepth;const effectiveFilterHeight=convInfo.effectiveFilterHeight;const effectiveFilterWidth=convInfo.effectiveFilterWidth;const padFront=convInfo.padInfo.front;const padTop=convInfo.padInfo.top;const padLeft=convInfo.padInfo.left;this.outputShape=convInfo.outShape;const isAvgPool=poolType==="avg";let initializationValue="0.0";if(!isAvgPool){initializationValue="-1.0 / 1e-20"}if(computePositions){const compareOp2=">=";this.userCode=`
const ivec3 strides =
ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xDCorner = xCorner.x;
int xRCorner = xCorner.y;
int xCCorner = xCorner.z;
// max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
// ? = to be determined
float minMaxValue = 0.0;
float minMaxValueFound = 0.0;
int minMaxPosition = 0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
int xD = xDCorner + wD;
if (xD < 0 || xD >= ${convInfo.inDepth}) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
int xC = xCCorner + wC;
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float value = getX(batch, xD, xR, xC, ch);
// If a min / max value has already been found, use it. If not,
// use the current value.
float currMinMaxValue = mix(
value, minMaxValue, minMaxValueFound);
if (value ${compareOp2} currMinMaxValue) {
minMaxValue = value;
minMaxValueFound = 1.0;
minMaxPosition = ${flattenPositions?includeBatchInIndex?`(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`:`((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`:`wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
wR * ${effectiveFilterWidth} + wC`};
}
}
2020-11-12 18:58:55 +01:00
}
}
2020-11-16 21:51:46 +01:00
setOutput(float(minMaxPosition));
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`;return}const compareOp="max";let returnValue=`${poolType}(${poolType}(${poolType}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;if(poolType==="avg"){returnValue=`avgValue / count`}const filterWidthNearestVec4=Math.floor(filterWidth/4)*4;const filterWidthVec4Remainder=filterWidth%4;const updateSnippet=`
if (${isAvgPool}) {
avgValue += dot(values, ones);
} else {
minMaxValue = ${compareOp}(values, minMaxValue);
}
2020-11-16 21:51:46 +01:00
`;this.userCode=`
const ivec3 strides =
ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float count = 0.0;
float getValue(int batch, int xD, int xR, int xC, int ch) {
if (xC < 0 || xC >= ${convInfo.inWidth}) {
return initializationValue;
2020-11-08 18:32:31 +01:00
}
2020-11-16 21:51:46 +01:00
count += 1.0;
return getX(batch, xD, xR, xC, ch);
2020-11-06 19:50:16 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xDCorner = xCorner.x;
int xRCorner = xCorner.y;
int xCCorner = xCorner.z;
// max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).
// ? = to be determined
vec4 minMaxValue = vec4(${initializationValue});
float avgValue = 0.0;
count = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
int xD = xDCorner + wD;
if (xD < 0 || xD >= ${convInfo.inDepth}) {
continue;
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
int xC = xCCorner + wC * ${dilationWidth};
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
);
${updateSnippet}
}
int xC = xCCorner + ${filterWidthNearestVec4};
if (${filterWidthVec4Remainder===1}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder===2}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder===3}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
initializationValue
);
${updateSnippet}
}
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(${returnValue});
2020-11-12 18:58:55 +01:00
}
}
2020-11-16 21:51:46 +01:00
`}}class ReduceProgram{constructor(reduceInfo,reduceType){this.variableNames=["x"];const{windowSize,batchSize,inSize,outSize}=reduceInfo;this.outputShape=[batchSize,outSize];let initializationValue="0.0";let compareOp=``;if(reduceType==="prod"){initializationValue="1.0"}else if(reduceType==="min"){initializationValue="1.0 / 1e-20";compareOp=`min`}else if(reduceType==="max"){initializationValue="-1.0 / 1e-20";compareOp=`max`}let returnValue=`${reduceType}(${reduceType}(${reduceType}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;if(reduceType==="sum"){returnValue=`sumValue`}else if(reduceType==="prod"){returnValue=`prodValue`}else if(reduceType==="all"){returnValue=`allValue`}else if(reduceType==="any"){returnValue=`anyValue`}const windowSizeNearestVec4=Math.floor(windowSize/4)*4;const windowSizeVec4Remainder=windowSize%4;let updateSnippet=`
if (${reduceType==="sum"}) {
sumValue += dot(values, ones);
} else if (${reduceType==="prod"}) {
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
prodValue *= tmp[0] * tmp[1];
2020-11-12 18:58:55 +01:00
} else {
2020-11-16 21:51:46 +01:00
minMaxValue = ${compareOp}(values, minMaxValue);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`;let vecType=`vec4`;if(reduceType==="all"){initializationValue="1.0";updateSnippet=`
bool reducedAllValue = all(values);
float floatedReducedAllValue = float(reducedAllValue);
allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
`;vecType=`bvec4`}else if(reduceType==="any"){initializationValue="0.0";updateSnippet=`
bool reducedAnyValue = any(values);
float floatedReducedAnyValue = float(reducedAnyValue);
anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
`;vecType=`bvec4`}let checkOutOfBounds="";if(inSize%windowSize>0){checkOutOfBounds=`
if (inIdx < 0 || inIdx >= ${inSize}) {
return initializationValue;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}this.userCode=`
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float getValue(int batch, int inIdx) {
${checkOutOfBounds}
return getX(batch, inIdx);
2020-11-12 18:17:57 +01:00
}
2020-11-16 21:51:46 +01:00
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
vec4 minMaxValue = vec4(${initializationValue});
float prodValue = 1.0;
float sumValue = 0.0;
float allValue = 1.0;
float anyValue = 0.0;
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
${updateSnippet}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder===1}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${windowSizeVec4Remainder===2}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${windowSizeVec4Remainder===3}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
initializationValue
);
${updateSnippet}
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(${returnValue});
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class ReshapePackedProgram{constructor(outputShape,inputShape){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;this.outputShape=outputShape;let mainLoop=``;for(let i=0;i<4;i++){let thisRC=`thisRC = rc;`;if(i%2===1){thisRC+=`thisRC.z += 1;`}if(i>1){thisRC+=`thisRC.y += 1;`}mainLoop+=`
${thisRC}
${i>0?`if(thisRC.y < rows && thisRC.z < cols){`:""}
int flatIndex = getFlatIndex(thisRC);
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
result[${i}] =
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
${i>0?"}":""}
`}this.userCode=`
${getReshapedInputCoords(inputShape)}
${getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 rc = getOutputCoords();
vec4 result = vec4(0.);
ivec3 thisRC;
int rows = ${outputShape[1]};
int cols = ${outputShape[2]};
${mainLoop}
setOutput(result);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}function getReshapedInputCoords(shape){const coordsFromIndexSnippet=getLogicalCoordinatesFromFlatIndex(["r","c","d"],shape);return`
ivec3 inputCoordsFromReshapedOutCoords(int index) {
${coordsFromIndexSnippet}
return ivec3(r, c, d);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}class ResizeBilinearBackpropProgram{constructor(dy,x,alignCorners){this.variableNames=["dy"];this.outputShape=[];this.outputShape=x.shape;const[,xHeight,xWidth]=x.shape;const[,yHeight,yWidth]=dy.shape;const effectiveXSize=[alignCorners&&yHeight>1?xHeight-1:xHeight,alignCorners&&yWidth>1?xWidth-1:xWidth];const effectiveYSize=[alignCorners&&yHeight>1?yHeight-1:yHeight,alignCorners&&yWidth>1?yWidth-1:yWidth];const heightScale=effectiveXSize[0]/effectiveYSize[0];const widthScale=effectiveXSize[1]/effectiveYSize[1];const invHeightScale=1/heightScale;const invWidthScale=1/widthScale;const winHeight=Math.ceil(invHeightScale)*2+2;const winWidth=Math.ceil(invWidthScale)*2+2;this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
int r = coords[1];
int c = coords[2];
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float accumulator = 0.0;
2020-11-12 18:17:57 +01:00
2020-11-16 21:51:46 +01:00
const float heightScale = float(${heightScale});
const float widthScale = float(${widthScale});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
const float invHeightScale = float(${invHeightScale});
const float invWidthScale = float(${invWidthScale});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
const int winHeight = int(${winHeight});
const int winWidth = int(${winWidth});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Compute bounds for where in dy we will look
float startRLerp = floor(float(r) * invHeightScale);
int startDyR = int(startRLerp - float(winHeight / 2));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float startCLerp = floor(float(c) * invWidthScale);
int startDyC = int(startCLerp - float(winWidth / 2));
2020-11-12 18:58:55 +01:00
2020-11-16 21:51:46 +01:00
// Loop over dy
for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
int dyR = dyROffset + startDyR;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Guard against the window exceeding the bounds of dy
if (dyR < 0 || dyR >= ${yHeight}) {
continue;
}
2020-11-12 18:58:55 +01:00
2020-11-16 21:51:46 +01:00
for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
int dyC = dyCOffset + startDyC;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Guard against the window exceeding the bounds of dy
if (dyC < 0 || dyC >= ${yWidth}) {
continue;
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float dxR = float(dyR) * heightScale;
int topDxRIndex = int(floor(dxR));
int bottomDxRIndex = int(min(ceil(dxR), ${xHeight-1}.0));
float dxRLerp = dxR - float(topDxRIndex);
float inverseDxRLerp = 1.0 - dxRLerp;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float dxC = float(dyC) * widthScale;
int leftDxCIndex = int(floor(dxC));
int rightDxCIndex = int(min(ceil(dxC), ${xWidth-1}.0));
float dxCLerp = dxC - float(leftDxCIndex);
float inverseDxCLerp = 1.0 - dxCLerp;
2020-11-12 18:58:55 +01:00
2020-11-16 21:51:46 +01:00
if (r == topDxRIndex && c == leftDxCIndex) {
// topLeft
accumulator +=
getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
if (r == topDxRIndex && c == rightDxCIndex) {
// topRight
accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
}
if (r == bottomDxRIndex && c == leftDxCIndex) {
// bottomLeft
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
}
if (r == bottomDxRIndex && c == rightDxCIndex) {
// bottomRight
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
}
}
}
// End loop over dy
setOutput(accumulator);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class ResizeBilinearProgram{constructor(inputShape,newHeight,newWidth,alignCorners){this.variableNames=["A"];this.outputShape=[];const[batch,oldHeight,oldWidth,depth]=inputShape;this.outputShape=[batch,newHeight,newWidth,depth];const effectiveInSize=[alignCorners&&newHeight>1?oldHeight-1:oldHeight,alignCorners&&newWidth>1?oldWidth-1:oldWidth];const effectiveOutSize=[alignCorners&&newHeight>1?newHeight-1:newHeight,alignCorners&&newWidth>1?newWidth-1:newWidth];this.userCode=`
const vec2 effectiveInputOverOutputRatioRC = vec2(
${effectiveInSize[0]/effectiveOutSize[0]},
${effectiveInSize[1]/effectiveOutSize[1]});
const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 yRC = coords.yz;
// Fractional source index.
vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;
// Compute the four integer indices.
ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);
ivec2 sourceCeilRC = ivec2(
min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);
vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
float top = topLeft + (topRight - topLeft) * fracRC.y;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
float newValue = top + (bottom - top) * fracRC.x;
setOutput(newValue);
2020-11-12 18:58:55 +01:00
}
2020-11-16 21:51:46 +01:00
`}}class ResizeBilinearPackedProgram{constructor(inputShape,newHeight,newWidth,alignCorners){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;this.outputShape=[];const[batch,oldHeight,oldWidth,depth]=inputShape;this.outputShape=[batch,newHeight,newWidth,depth];const effectiveInSize=[alignCorners&&newHeight>1?oldHeight-1:oldHeight,alignCorners&&newWidth>1?oldWidth-1:oldWidth];const effectiveOutSize=[alignCorners&&newHeight>1?newHeight-1:newHeight,alignCorners&&newWidth>1?newWidth-1:newWidth];this.userCode=`
const vec3 effectiveInputOverOutputRatioRC = vec3(
${effectiveInSize[0]/effectiveOutSize[0]},
${effectiveInSize[1]/effectiveOutSize[1]},
${effectiveInSize[1]/effectiveOutSize[1]});
const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
${oldWidth}.0);
float getAValue(int b, int r, int c, int d) {
return getChannel(getA(b, r, c, d), vec2(c, d));
2020-11-12 18:58:55 +01:00
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
// Calculate values for next column in yRC.z.
ivec3 yRC = coords.yzz + ivec3(0, 0, 1);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Fractional source index.
vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;
2020-11-12 18:58:55 +01:00
2020-11-16 21:51:46 +01:00
// Compute the four integer indices.
ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);
ivec3 sourceCeilRC = ivec3(
min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Should we calculate next column and row elements in 2x2 packed cell.
bool hasNextCol = d < ${depth-1};
bool hasNextRow = coords.z < ${newWidth-1};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// In parallel, construct four corners for all four components in
// packed 2x2 cell.
vec4 topLeft = vec4(
getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 bottomLeft = vec4(
getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 topRight = vec4(
getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);
vec4 bottomRight = vec4(
getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 top = mix(topLeft, topRight, fracRC.yyzz);
vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
vec4 newValue = mix(top, bottom, fracRC.x);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
setOutput(newValue);
}
`}}class ResizeNearestNeigborBackpropProgram{constructor(dy,x,alignCorners){this.variableNames=["dy"];this.outputShape=[];this.outputShape=x.shape;const[,xHeight,xWidth]=x.shape;const[,yHeight,yWidth]=dy.shape;const effectiveXSize=[alignCorners&&yHeight>1?xHeight-1:xHeight,alignCorners&&yWidth>1?xWidth-1:xWidth];const effectiveYSize=[alignCorners&&yHeight>1?yHeight-1:yHeight,alignCorners&&yWidth>1?yWidth-1:yWidth];const heightScale=effectiveXSize[0]/effectiveYSize[0];const widthScale=effectiveXSize[1]/effectiveYSize[1];const invHeightScale=1/heightScale;const invWidthScale=1/widthScale;const winHeight=Math.ceil(invHeightScale)*2+2;const winWidth=Math.ceil(invWidthScale)*2+2;this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
int r = coords[1];
int c = coords[2];
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float accumulator = 0.0;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
const float heightScale = float(${heightScale});
const float widthScale = float(${widthScale});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
const float invHeightScale = float(${invHeightScale});
const float invWidthScale = float(${invWidthScale});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
const int winHeight = int(${winHeight});
const int winWidth = int(${winWidth});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Compute bounds for where in dy we will look
float startRLerp = floor(float(r) * invHeightScale);
int startDyR = int(floor(startRLerp - float(winHeight / 2)));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float startCLerp = floor(float(c) * invWidthScale);
int startDyC = int(floor(startCLerp - float(winWidth / 2)));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Loop over dy
for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
int dyR = dyROffset + startDyR;
// Guard against the window exceeding the bounds of dy
if (dyR < 0 || dyR >= ${yHeight}) {
continue;
}
for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
int dyC = dyCOffset + startDyC;
// Guard against the window exceeding the bounds of dy
if (dyC < 0 || dyC >= ${yWidth}) {
continue;
}
float sourceFracRow =
float(${effectiveXSize[0]}) *
(float(dyR) / float(${effectiveYSize[0]}));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float sourceFracCol =
float(${effectiveXSize[1]}) *
(float(dyC) / float(${effectiveYSize[1]}));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
int sourceNearestRow = int(min(
float(int(${xHeight}) - 1),
${alignCorners} ? float(round(sourceFracRow)) :
float(floor(sourceFracRow))));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
int sourceNearestCol = int(min(
float(int(${xWidth}) - 1),
${alignCorners} ? float(round(sourceFracCol)) :
float(floor(sourceFracCol))));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
if (r == sourceNearestRow && c == sourceNearestCol) {
accumulator += getDy(b, dyR, dyC, d);
}
}
}
// End loop over dy
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
setOutput(accumulator);
}
`}}class ResizeNearestNeighborProgram{constructor(inputShape,newHeight,newWidth,alignCorners){this.variableNames=["A"];this.outputShape=[];const[batch,oldHeight,oldWidth,depth]=inputShape;this.outputShape=[batch,newHeight,newWidth,depth];const effectiveInSize=[alignCorners&&newHeight>1?oldHeight-1:oldHeight,alignCorners&&newWidth>1?oldWidth-1:oldWidth];const effectiveOutSize=[alignCorners&&newHeight>1?newHeight-1:newHeight,alignCorners&&newWidth>1?newWidth-1:newWidth];const roundBase=alignCorners?"0.5":"0.0";this.userCode=`
const vec2 effectiveInputOverOutputRatioRC = vec2(
${effectiveInSize[0]/effectiveOutSize[0]},
${effectiveInSize[1]/effectiveOutSize[1]});
const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 yRC = coords.yz;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Fractional source index.
vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
// Compute the coordinators of nearest neighbor point.
ivec2 sourceNearestRC = ivec2(
min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
setOutput(newValue);
}
`}}class ReverseProgram{constructor(xShape,axis){this.variableNames=["x"];const rank=xShape.length;if(rank>4){throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`)}this.outputShape=xShape;if(rank===1){this.userCode=`
void main() {
int coord = getOutputCoords();
setOutput(getX(${xShape[0]} - coord - 1));
}
`;return}const getInCoord=i=>{if(axis.indexOf(i)!==-1&&xShape[i]!==1){return`${xShape[i]} - coords[${i}] - 1`}return`coords[${i}]`};const inCoords=xShape.map((_,i)=>getInCoord(i)).join(",");const type=getCoordsDataType(rank);this.userCode=`
void main() {
${type} coords = getOutputCoords();
setOutput(getX(${inCoords}));
}
`}}class ReversePackedProgram{constructor(xShape,axis){this.variableNames=["x"];this.packedInputs=true;this.packedOutput=true;const rank=xShape.length;if(rank>4){throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`)}this.outputShape=xShape;const channels=getChannels("rc",rank);const nextColumn=`${channels[rank-1]} + 1 < ${this.outputShape[rank-1]}`;const nextRow=`${channels[rank-2]} + 1 < ${this.outputShape[rank-2]}`;const type=getCoordsDataType(rank);if(rank===1){this.userCode=`
void main(){
int rc = getOutputCoords();
vec4 result = vec4(0.);
result.r = getChannel(getX(${xShape[0]} - rc - 1),
${xShape[0]} - rc - 1);
if(${nextColumn}){
result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
${xShape[0]} - (rc + 1) - 1);
}
setOutput(result);
}
`}else{this.userCode=`
void main() {
${type} rc = getOutputCoords();
vec4 result = vec4(0.);
result.r = ${getR(channels.slice())};
if(${nextColumn}){
result.g = ${getG(channels.slice())};
}
if(${nextRow}) {
result.b = ${getB(channels.slice())};
if(${nextColumn}) {
result.a = ${getA(channels.slice())};
}
}
setOutput(result);
}
`}function getR(channels2){return getChannel(channels2)}function getG(channels2){channels2[rank-1]="("+channels2[rank-1]+` + 1)`;return getChannel(channels2)}function getB(channels2){channels2[rank-2]="("+channels2[rank-2]+` + 1)`;return getChannel(channels2)}function getA(channels2){channels2[rank-1]="("+channels2[rank-1]+` + 1)`;channels2[rank-2]="("+channels2[rank-2]+` + 1)`;return getChannel(channels2)}function getChannel(channels2){const inCoordsArray=xShape.map((_,i)=>getInCoord(i,channels2));const inCoords=inCoordsArray.join(",");const innerDims=inCoordsArray.slice(-2).join(",");return`getChannel(getX(${inCoords}), vec2(${innerDims}))`}function getInCoord(i,channels1){if(axis.indexOf(i)!==-1&&xShape[i]!==1){return`${xShape[i]} - ${channels1[i]} - 1`}else{return`${channels1[i]}`}}}}class ScatterProgram{constructor(updateSize,sliceDim,indicesRank,updatesRank,strides,shape,summingDupeIndex=true){this.variableNames=["updates","indices","defaultValue"];this.outputShape=shape;const stridesType=getCoordsDataType(strides.length);const dtype=getCoordsDataType(shape.length);let indicesString="";if(indicesRank===1){indicesString="i"}else if(indicesRank===2){indicesString="i, j"}const indicesSnippet=`getIndices(${indicesString})`;let updatesString="";if(updatesRank===1){updatesString="i"}else if(updatesRank===2){updatesString="i, coords[1]"}const updatesSnippet=`getUpdates(${updatesString})`;const strideString=sliceDim>1?"strides[j]":"strides";this.userCode=`
${stridesType} strides = ${stridesType}(${strides});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
${dtype} coords = getOutputCoords();
float sum = 0.0;
bool found = false;
for (int i = 0; i < ${updateSize}; i++) {
int flattenedIndex = 0;
for (int j = 0; j < ${sliceDim}; j++) {
int index = round(${indicesSnippet});
flattenedIndex += index * ${strideString};
}
if (flattenedIndex == coords[0]) {
sum += ${updatesSnippet};
found = true;
}
}
setOutput(mix(getDefaultValue(), sum, float(found)));
}
`}}class SegmentOpProgram{constructor(segOpInfo,segOpType){this.variableNames=["x","segmentIds"];const windowSize=segOpInfo.windowSize;const batchSize=segOpInfo.batchSize;const inSize=segOpInfo.inSize;const numSegments=segOpInfo.numSegments;const outSize=numSegments*Math.ceil(inSize/windowSize);this.outputShape=[batchSize,outSize];const initializationValue="0.0";const returnValue=`sumValue`;const windowSizeNearestVec4=Math.floor(windowSize/4)*4;const windowSizeVec4Remainder=windowSize%4;const updateSnippet=`
sumValue += dot(values, segFilter);
`;let checkValueOutOfBounds="";if(inSize%windowSize>0){checkValueOutOfBounds=`
if (inIdx < 0 || inIdx >= ${inSize}) {
return initializationValue;
}
`}let checkSegmentIdOutOfBounds="";if(inSize%windowSize>0){checkSegmentIdOutOfBounds=`
if (inIdx < 0 || inIdx >= ${inSize}) {
return -1.0;
}
`}this.userCode=`
const float initializationValue = ${initializationValue};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float getValue(int batch, int inIdx) {
${checkValueOutOfBounds}
return getX(batch, inIdx);
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float getSegmentIdAtIndex(int inIdx) {
${checkSegmentIdOutOfBounds}
return getSegmentIds(inIdx);
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = int(floor(float(outIdx) / float(
${numSegments})) * float(${windowSize}));
int currentSeg = int(mod(float(outIdx), float(${numSegments})));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float sumValue = 0.0;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${updateSnippet}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder===1}) {
vec4 values = vec4(
getValue(batch, inIdx),
initializationValue,
initializationValue,
initializationValue
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
int inIdxSeg = int(getSegmentIdAtIndex(inIdx));
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
0,
0,
0
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${updateSnippet}
} else if (${windowSizeVec4Remainder===2}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
initializationValue,
initializationValue
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
0,
0
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${updateSnippet}
} else if (${windowSizeVec4Remainder===3}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
initializationValue
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
0
);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${updateSnippet}
}
setOutput(${returnValue});
}
`}}class SelectProgram{constructor(cRank,shape,rank){this.variableNames=["c","a","b"];this.outputShape=shape;let cCoords;let abCoords;if(rank>4){throw Error(`Where for rank ${rank} is not yet supported`)}if(rank===1){abCoords=`resRC`;cCoords=`resRC`}else{const currentCoords=["resRC.x","resRC.y","resRC.z","resRC.w"];const cCoordVars=[];const abCoordVars=[];for(let i=0;i<shape.length;i++){abCoordVars.push(`${currentCoords[i]}`);if(i<cRank){cCoordVars.push(`${currentCoords[i]}`)}}cCoords=cCoordVars.join();abCoords=abCoordVars.join()}const dtype=getCoordsDataType(rank);this.userCode=`
void main() {
${dtype} resRC = getOutputCoords();
float cVal = getC(${cCoords});
if (cVal >= 1.0) {
setOutput(getA(${abCoords}));
} else {
setOutput(getB(${abCoords}));
}
}
2020-11-17 16:18:15 +01:00
`}}class SliceProgram{constructor(destSize){this.variableNames=["source"];this.outputShape=destSize;this.rank=destSize.length;const dtype=getCoordsDataType(this.rank);const uniformPart=`uniform int start[${this.rank}];`;const sourceCoords=getCoords2(this.rank);let body2;const coordSum=destSize.map((_,i)=>{return`sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`});body2=`
2020-11-16 21:51:46 +01:00
${dtype} sourceLoc;
${dtype} coords = getOutputCoords();
${coordSum.join("\n")}
`;this.userCode=`
${uniformPart}
void main() {
${body2}
setOutput(getSource(${sourceCoords}));
}
2020-11-17 16:18:15 +01:00
`}getCustomSetupFunc(start){if(start.length!==this.rank){throw Error(`The rank (${this.rank}) of the program must match the length of start (${start.length})`)}return(gpgpu,webGLProgram)=>{if(this.startLoc==null){this.startLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"start");if(this.startLoc==null){return}}gpgpu.gl.uniform1iv(this.startLoc,start)}}}const coords=["x","y","z","w","u","v"];function getCoords2(rank){if(rank===1){return"sourceLoc"}else if(rank<=6){return coords.slice(0,rank).map(x=>"sourceLoc."+x).join(",")}else{throw Error(`Slicing for rank ${rank} is not yet supported`)}}class SlicePackedProgram{constructor(destSize){this.variableNames=["source"];this.packedInputs=true;this.packedOutput=true;this.outputShape=destSize;this.rank=destSize.length;const dtype=getCoordsDataType(this.rank);const coords2=getChannels("coords",this.rank);const sourceLoc=getChannels("sourceLoc",this.rank);const innerDims=this.rank===1?"sourceLoc":`vec2(${sourceLoc.slice(-2).join()})`;const getChannel=`getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;const upperRow=`
2020-11-16 21:51:46 +01:00
result.x = ${getChannel};
if (++${coords2[this.rank-1]} < ${destSize[this.rank-1]}) {
++${sourceLoc[this.rank-1]};
result.y = ${getChannel};
--${sourceLoc[this.rank-1]};
}
`;const lowerRow=this.rank===1?"":`
--${coords2[this.rank-1]};
if (++${coords2[this.rank-2]} < ${destSize[this.rank-2]}) {
++${sourceLoc[this.rank-2]};
result.z = ${getChannel};
if (++${coords2[this.rank-1]} < ${destSize[this.rank-1]}) {
++${sourceLoc[this.rank-1]};
result.w = ${getChannel};
}
}
`;const sourceLocSetup=this.rank<=4?`sourceLoc = coords +
${dtype}(${destSize.map((_,i)=>`start[${i}]`).join()});`:destSize.map((_,i)=>`${sourceLoc[i]} = ${coords2[i]} + start[${i}];`).join("\n");this.userCode=`
uniform int start[${this.rank}];
void main() {
${dtype} coords = getOutputCoords();
${dtype} sourceLoc;
${sourceLocSetup}
vec4 result = vec4(0.);
${upperRow}
${lowerRow}
setOutput(result);
}
`}getCustomSetupFunc(start){if(start.length!==this.rank){throw Error(`The rank (${this.rank}) of the program must match the length of start (${start.length})`)}return(gpgpu,webGLProgram)=>{if(this.startLoc==null){this.startLoc=gpgpu.getUniformLocationNoThrow(webGLProgram,"start");if(this.startLoc==null){return}}gpgpu.gl.uniform1iv(this.startLoc,start)}}}class StridedSliceProgram{constructor(begin,strides,size){this.variableNames=["x"];this.outputShape=size;const rank=size.length;const inputDtype=getCoordsDataType(size.length);const dtype=getCoordsDataType(size.length);let newCoords="";if(rank===1){newCoords="coords * strides + begin"}else{let outputAxis=0;newCoords=size.map((_,i)=>{outputAxis++;return size.length===1?`coords * strides[${i}] + begin[${i}]`:`coords[${outputAxis-1}] * strides[${i}] + begin[${i}]`}).join(",")}this.userCode=`
${inputDtype} begin = ${inputDtype}(${begin});
${inputDtype} strides = ${inputDtype}(${strides});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
${dtype} coords = getOutputCoords();
setOutput(getX(${newCoords}));
}
2020-11-17 16:18:15 +01:00
`}}class TextureManager{constructor(gpgpu){this.gpgpu=gpgpu;this.numUsedTextures=0;this.numFreeTextures=0;this._numBytesAllocated=0;this._numBytesFree=0;this.freeTextures={};this.logEnabled=false;this.usedTextures={}}acquireTexture(shapeRC,usage,isPacked){const physicalTexType=getPhysicalFromLogicalTextureType(usage,isPacked);const shapeKey=getKeyFromTextureShape(shapeRC,physicalTexType,isPacked);if(!(shapeKey in this.freeTextures)){this.freeTextures[shapeKey]=[]}if(!(shapeKey in this.usedTextures)){this.usedTextures[shapeKey]=[]}const texBytes=computeBytes(shapeRC,physicalTexType,this.gpgpu.gl,this.gpgpu.textureConfig,isPacked);if(this.freeTextures[shapeKey].length>0){this.numFreeTextures--;this.numUsedTextures++;this._numBytesFree-=texBytes;this.log();const newTexture2=this.freeTextures[shapeKey].shift();this.usedTextures[shapeKey].push(newTexture2);return newTexture2}let newTexture;if(physicalTexType===PhysicalTextureType.PACKED_2X2_FLOAT32){newTexture=this.gpgpu.createPackedMatrixTexture(shapeRC[0],shapeRC[1])}else if(physicalTexType===PhysicalTextureType.PACKED_2X2_FLOAT16){newTexture=this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0],shapeRC[1])}else if(physicalTexType===PhysicalTextureType.UNPACKED_FLOAT32){newTexture=this.gpgpu.createFloat32MatrixTexture(shapeRC[0],shapeRC[1])}else if(physicalTexType===PhysicalTextureType.UNPACKED_FLOAT16){newTexture=this.gpgpu.createFloat16MatrixTexture(shapeRC[0],shapeRC[1])}else if(physicalTexType===PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE){newTexture=this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0],shapeRC[1])}this.usedTextures[shapeKey].push(newTexture);this.numUsedTextures++;this._numBytesAllocated+=texBytes;this.log();return newTexture}releaseTexture(texture,shape,logicalTexType,isPacked){if(this.freeTextures==null){return}const physicalTexType=getPhysicalFromLogicalTextureType(logicalTexType,isPacked);const shapeKey=getKeyFromTextureShape(shape,physicalTexType,isPacked);if(!(shapeKey in 
this.freeTextures)){this.freeTextures[shapeKey]=[]}const texBytes=computeBytes(shape,physicalTexType,this.gpgpu.gl,this.gpgpu.textureConfig,isPacked);const deleteTexThreshold=env().get("WEBGL_DELETE_TEXTURE_THRESHOLD");if(deleteTexThreshold!==-1&&this._numBytesAllocated>deleteTexThreshold){this.gpgpu.deleteMatrixTexture(texture);this._numBytesAllocated-=texBytes}else{this.freeTextures[shapeKey].push(texture);this.numFreeTextures++;this._numBytesFree+=texBytes}this.numUsedTextures--;const texList=this.usedTextures[shapeKey];const texIndex=texList.indexOf(texture);if(texIndex<0){throw new Error("Cannot release a texture that was never provided by this texture manager")}texList.splice(texIndex,1);this.log()}log(){if(!this.logEnabled){return}const total=this.numFreeTextures+this.numUsedTextures;console.log("Free/Used",`${this.numFreeTextures} / ${this.numUsedTextures}`,`(${total})`);const freeRatio=this._numBytesFree/this._numBytesAllocated;console.log(`Bytes allocated: ${this._numBytesAllocated}`);console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100*freeRatio)}%)`)}get numBytesAllocated(){return this._numBytesAllocated}get numBytesFree(){return this._numBytesFree}getNumUsedTextures(){return this.numUsedTextures}getNumFreeTextures(){return this.numFreeTextures}dispose(){if(this.freeTextures==null){return}for(const texShape in this.freeTextures){this.freeTextures[texShape].forEach(tex=>{this.gpgpu.deleteMatrixTexture(tex)})}for(const texShape in this.usedTextures){this.usedTextures[texShape].forEach(tex=>{this.gpgpu.deleteMatrixTexture(tex)})}this.freeTextures=null;this.usedTextures=null;this.numUsedTextures=0;this.numFreeTextures=0;this._numBytesAllocated=0;this._numBytesFree=0}}function numBytesForInternalFormat(gl,internalFormat){const glany=gl;if(internalFormat===glany.R32F){return 4}else if(internalFormat===glany.R16F){return 2}else if(internalFormat===glany.RGBA32F){return 16}else if(internalFormat===gl.RGBA){return 16}else 
if(internalFormat===glany.RGBA16F){return 8}throw new Error(`Unknown internal format ${internalFormat}`)}function computeBytes(shape,p
2020-11-16 21:51:46 +01:00
void main() {
${dtype} resRC = getOutputCoords();
setOutput(getA(${sourceCoords}));
}
2020-11-17 16:18:15 +01:00
`}}function getSourceCoords3(aShape){const rank=aShape.length;if(rank>5){throw Error(`Tile for rank ${rank} is not yet supported`)}if(rank===1){return`imod(resRC, ${aShape[0]})`}const currentCoords=["resRC.x","resRC.y","resRC.z","resRC.w","resRC.u"];const sourceCoords=[];for(let i=0;i<aShape.length;i++){sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`)}return sourceCoords.join()}class UnaryOpProgram{constructor(aShape,opSnippet){this.variableNames=["A"];this.outputShape=aShape;this.userCode=`
2020-11-16 21:51:46 +01:00
float unaryOperation(float x) {
${opSnippet}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
float x = getAAtOutCoords();
float y = unaryOperation(x);
setOutput(y);
}
2020-11-17 16:18:15 +01:00
`}}const CHECK_NAN_SNIPPET3=`if (isnan(x)) return x;`;const LINEAR=`return x;`;const ABS=`return abs(x);`;const RELU=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
return (x < 0.0) ? 0.0 : x;
2020-11-17 16:18:15 +01:00
`;const RELU6=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
return (x < 0.0) ? 0.0 : min(6.0, x);
2020-11-17 16:18:15 +01:00
`;const ELU2=`return (x >= 0.0) ? x : (exp(x) - 1.0);`;const SELU=`
2020-11-16 21:51:46 +01:00
// Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
// see: https://arxiv.org/abs/1706.02515
2020-11-17 16:18:15 +01:00
float scaleAlpha = ${backend_util_exports.SELU_SCALEALPHA};
float scale = ${backend_util_exports.SELU_SCALE};
2020-11-16 21:51:46 +01:00
return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
2020-11-17 16:18:15 +01:00
`;function STEP(alpha=0){return CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
return x > 0.0 ? 1.0 : float(${alpha});
`}const NEG=`return -x;`;const CEIL=`return ceil(x);`;const FLOOR=`return floor(x);`;const SIGN=`
if (isnan(x)) { return 0.0; }
return sign(x);
`;const IS_NAN=`return float(isnan(x));`;const IS_INF=`return float(isinf(x));`;const IS_FINITE=`return float(!isnan(x) && !isinf(x));`;const ROUND=`
// OpenGL ES does not support round function.
// The algorithm is based on banker's rounding.
float base = floor(x);
if ((x - base) < 0.5) {
return floor(x);
} else if ((x - base) > 0.5) {
return ceil(x);
} else {
if (mod(base, 2.0) == 0.0) {
return base;
} else {
return base + 1.0;
}
}
`;const EXP=`return exp(x);`;const EXPM1=`return exp(x) - 1.0;`;const LOG=`if (x < 0.0) return NAN;
return log(x);`;const LOG1P=`return log(1.0 + x);`;const SQRT=`return sqrt(x);`;const RSQRT=`return inversesqrt(x);`;const SIGMOID=`return 1.0 / (1.0 + exp(-1.0 * x));`;const SOFTPLUS=`
float epsilon = 1.1920928955078125e-7;
float threshold = log(epsilon) + 2.0;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
bool too_large = x > -threshold;
bool too_small = x < threshold;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float result;
float exp_x = exp(x);
if (too_large){
result = x;
}
else if (too_small){
result = exp_x;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
else{
result = log(exp_x + 1.0);
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const ASIN=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
if (abs(x) > 1.) {
return NAN;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
return asin(x);
2020-11-17 16:18:15 +01:00
`;const ACOS=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
if (abs(x) > 1.) {
return NAN;
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
return acos(x);
2020-11-17 16:18:15 +01:00
`;const ATAN=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
return atan(x);
`;const SINH=`
float e2x = exp(x);
return (e2x - 1.0 / e2x) / 2.0;
`;const COSH=`
float e2x = exp(-x);
return (e2x + 1.0 / e2x) / 2.0;
`;const TANH=`
float e2x = exp(-2.0 * abs(x));
return sign(x) * (1.0 - e2x) / (1.0 + e2x);
2020-11-17 16:18:15 +01:00
`;const ASINH=CHECK_NAN_SNIPPET3+`return log(x + sqrt(x * x + 1.0));`;const ACOSH=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
if (x < 1.0) return NAN;
2020-11-17 16:18:15 +01:00
return log(x + sqrt(x * x - 1.0));`;const ATANH=CHECK_NAN_SNIPPET3+`
2020-11-16 21:51:46 +01:00
if ((x < -1.0) || (x > 1.0)) return NAN;
return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;const ERF=`
// Error function is calculated approximately with elementary function.
// See "Handbook of Mathematical Functions with Formulas,
// Graphs, and Mathematical Tables", Abramowitz and Stegun.
2020-11-17 16:18:15 +01:00
float p = ${backend_util_exports.ERF_P};
float a1 = ${backend_util_exports.ERF_A1};
float a2 = ${backend_util_exports.ERF_A2};
float a3 = ${backend_util_exports.ERF_A3};
float a4 = ${backend_util_exports.ERF_A4};
float a5 = ${backend_util_exports.ERF_A5};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float sign = sign(x);
x = abs(x);
float t = 1.0 / (1.0 + p * x);
return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
2020-11-17 16:18:15 +01:00
`;const RECIPROCAL=`return 1.0 / x;`;const LOGICAL_NOT=`return float(!(x >= 1.0));`;const CLONE="return x;";const LINEAR2=`return x;`;const LOG2=`
2020-11-16 21:51:46 +01:00
vec4 result = log(x);
vec4 isNaN = vec4(lessThan(x, vec4(0.0)));
result.r = isNaN.r == 1.0 ? NAN : result.r;
result.g = isNaN.g == 1.0 ? NAN : result.g;
result.b = isNaN.b == 1.0 ? NAN : result.b;
result.a = isNaN.a == 1.0 ? NAN : result.a;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const RELU2=`
2020-11-16 21:51:46 +01:00
vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const RELU62=`
2020-11-16 21:51:46 +01:00
vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
2020-11-17 16:18:15 +01:00
`;const ELU3=`
2020-11-16 21:51:46 +01:00
vec4 result;
result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
return result;
`;class UnaryOpPackedProgram{constructor(aShape,opSnippet){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;this.outputShape=aShape;this.userCode=`
vec4 unaryOperation(vec4 x) {
${opSnippet}
}
void main() {
vec4 x = getAAtOutCoords();
vec4 y = unaryOperation(x);
setOutput(y);
}
`}}class UnpackProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=false;this.outputShape=outputShape;const rank=outputShape.length;const channels=getChannels("rc",rank);const dtype=getCoordsDataType(rank);const sourceCoords=getSourceCoords(rank,channels);const innerDims=channels.slice(-2);const coords2=rank<=1?"rc":`vec2(${innerDims.join(",")})`;this.userCode=`
void main() {
${dtype} rc = getOutputCoords();
vec4 packedInput = getA(${sourceCoords});
setOutput(getChannel(packedInput, ${coords2}));
}
2020-11-17 16:18:15 +01:00
`}}const{segment_util:segment_util2}=backend_util_exports;const split11=kernel_impls_exports.split;const tile10=kernel_impls_exports.tile;const topkImpl3=kernel_impls_exports.topkImpl;const whereImpl3=kernel_impls_exports.whereImpl;const EPSILON_FLOAT322=1e-7;const EPSILON_FLOAT162=1e-4;const binaryCaches={};function getBinaryCache(webGLVersion){if(webGLVersion in binaryCaches){return binaryCaches[webGLVersion]}binaryCaches[webGLVersion]={};return binaryCaches[webGLVersion]}function mapActivationToShaderProgram(activation2,packed=false){if(activation2==="linear"){if(packed){return LINEAR2}return LINEAR}else if(activation2==="relu"){if(packed){return RELU2}return RELU}else if(activation2==="elu"){if(packed){return ELU3}return ELU2}else if(activation2==="relu6"){if(packed){return RELU62}return RELU6}else if(activation2==="prelu"){if(packed){return PRELU2}return PRELU}throw new Error(`Activation ${activation2} has not been implemented for the WebGL backend.`)}const CPU_HANDOFF_SIZE_THRESHOLD=128;const BEFORE_PAGING_CONSTANT=600;function numMBBeforeWarning(){if(env().global.screen==null){return 1024}return env().global.screen.height*env().global.screen.width*window.devicePixelRatio*BEFORE_PAGING_CONSTANT/1024/1024}const MATMUL_SHARED_DIM_THRESHOLD=1e3;class MathBackendWebGL extends KernelBackend{constructor(gpgpu){super();this.pendingRead=new WeakMap;this.pendingDisposal=new WeakSet;this.dataRefCount=new WeakMap;this.numBytesInGPU=0;this.uploadWaitMs=0;this.downloadWaitMs=0;this.warnedAboutMemory=false;this.warnedAboutCPUBackend=false;this.pendingDeletes=0;this.disposed=false;if(!env().getBool("HAS_WEBGL")){throw new Error("WebGL is not supported on this device")}if(gpgpu==null){const gl=getWebGLContext(env().getNumber("WEBGL_VERSION"));this.binaryCache=getBinaryCache(env().getNumber("WEBGL_VERSION"));this.gpgpu=new 
GPGPUContext(gl);this.canvas=gl.canvas;this.gpgpuCreatedLocally=true}else{this.gpgpu=gpgpu;this.binaryCache={};this.gpgpuCreatedLocally=false;this.canvas=gpgpu.gl.canvas}this.textureManager=new TextureManager(this.gpgpu);this.numMBBeforeWarning=numMBBeforeWarning();this.texData=new DataStorage(this,engine15())}numDataIds(){return this.texData.numDataIds()+(this.cpuBackend?this.cpuBackend.numDataIds():0)-this.pendingDeletes}write(values,shape,dtype){if(env().getBool("WEBGL_CHECK_NUMERICAL_PROBLEMS")||env().getBool("DEBUG")){this.checkNumericalProblems(values)}if(dtype==="complex64"&&values!=null){throw new Error(`Cannot write to a complex64 dtype. Please use tf.complex(real, imag).`)}const dataId={};this.texData.set(dataId,{shape,dtype,values,usage:TextureUsage.UPLOAD,refCount:1,complexParentRefCount:0});return dataId}incRef(dataId){const texData=this.texData.get(dataId);texData.refCount++}decRef(dataId){if(this.texData.has(dataId)){const texData=this.texData.get(dataId);texData.refCount--}}move(dataId,values,shape,dtype){if(env().getBool("DEBUG")){this.checkNumericalProblems(values)}if(dtype==="complex64"){throw new Error(`Cannot write to a complex64 dtype. 
Please use tf.complex(real, imag).`)}this.texData.set(dataId,{shape,dtype,values,usage:TextureUsage.UPLOAD,refCount:1,complexParentRefCount:0})}disposeIntermediateTensorInfo(tensorInfo){const dataId=tensorInfo.dataId;if(this.texData.has(dataId)){const textureData=this.texData.get(dataId);textureData.refCount--;if(textureData.refCount<1){this.disposeData(dataId)}}}readSync(dataId){const texData=this.texData.get(dataId);const{values,dtype,complexTensorInfos,slice:slice21,shape,isPacked}=texData;if(slice21!=null){let program;if(isPacked){program=new UnaryOpPackedProgram(shape,CLONE)}else{program=new UnaryOpProgram(shape,CLONE)}const res=this.runWebGLProgram(program,[{dataId,shape,dtype}],dtype);const data2=this.readSync(res.dataId);this.disposeIntermediateTensorInfo(res);return data2}if(values!=null){return this.convertAndCacheOnCPU(dataId)}if(dtype==="string"){return values}const shouldTimeProgram=this.activeTimers!=null;let start;if(shouldTimeProgram){start=util_exports.now()}let result;if(dtype==="complex64"){const realValues=this.readSync(complexTensorIn
2020-11-16 21:51:46 +01:00
if (isnan(a)) return a;
if (isnan(b)) return b;
`;const CHECK_NAN_SNIPPET_BINARY_PACKED=`
result.r = isNaN.r > 0. ? NAN : result.r;
result.g = isNaN.g > 0. ? NAN : result.g;
result.b = isNaN.b > 0. ? NAN : result.b;
result.a = isNaN.a > 0. ? NAN : result.a;
2020-11-17 16:18:15 +01:00
`;function unaryKernelFunc2(opSnippet){return({inputs,backend:backend3})=>{const{x}=inputs;const webglBackend=backend3;const program=new UnaryOpProgram(x.shape,opSnippet);return webglBackend.runWebGLProgram(program,[x],x.dtype)}}function binaryKernelFunc2({opSnippet,packedOpSnippet,checkOutOfBounds=false,supportsComplex=false,cpuKernelImpl,dtype}){return({inputs,backend:backend3})=>{const{a,b}=inputs;const webglBackend=backend3;if(supportsComplex&&a.dtype==="complex64"){const aData=webglBackend.texData.get(a.dataId);const bData=webglBackend.texData.get(b.dataId);const[real8,imag8]=[[aData.complexTensorInfos.real,bData.complexTensorInfos.real],[aData.complexTensorInfos.imag,bData.complexTensorInfos.imag]].map(complexParts=>{const[aPart,bPart]=complexParts;const aHandle={dataId:aPart.dataId,dtype:aPart.dtype,shape:a.shape};const bHandle={dataId:bPart.dataId,dtype:bPart.dtype,shape:b.shape};const program2=new BinaryOpProgram(opSnippet,a.shape,b.shape);return webglBackend.runWebGLProgram(program2,[aHandle,bHandle],upcastType(aPart.dtype,bPart.dtype))});const complexOutput=complex10({inputs:{real:real8,imag:imag8},backend:webglBackend});webglBackend.disposeIntermediateTensorInfo(real8);webglBackend.disposeIntermediateTensorInfo(imag8);return complexOutput}const $dtype=dtype||upcastType(a.dtype,b.dtype);if(webglBackend.shouldExecuteOnCPU([a,b])&&cpuKernelImpl!=null){const aData=webglBackend.texData.get(a.dataId);const bData=webglBackend.texData.get(b.dataId);const[outValues,outShape]=cpuKernelImpl(a.shape,b.shape,aData.values,bData.values,$dtype);const out=webglBackend.makeTensorInfo(outShape,$dtype);const outData=webglBackend.texData.get(out.dataId);outData.values=outValues;return out}const shouldUsePackedProgram=env().getBool("WEBGL_PACK_BINARY_OPERATIONS")&&packedOpSnippet!=null;let program;if(shouldUsePackedProgram){program=new BinaryOpPackedProgram(packedOpSnippet,a.shape,b.shape,checkOutOfBounds)}else{program=new BinaryOpProgram(opSnippet,a.shape,b.shape)}return 
webglBackend.runWebGLProgram(program,[a,b],$dtype)}}const ADD="return a + b;";const addKernelFunc=binaryKernelFunc2({opSnippet:ADD,packedOpSnippet:ADD,supportsComplex:true,cpuKernelImpl:addImplCPU});const addConfig2={kernelName:Add,backendName:"webgl",kernelFunc:addKernelFunc};const ATAN2=CHECK_NAN_SNIPPET_BINARY+`
2020-11-16 21:51:46 +01:00
return atan(a, b);
`;const ATAN2_PACKED=`
vec4 result = atan(a, b);
vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
`+CHECK_NAN_SNIPPET_BINARY_PACKED+`
return result;
2020-11-17 16:18:15 +01:00
`;const atan25=binaryKernelFunc2({opSnippet:ATAN2,packedOpSnippet:ATAN2_PACKED});const atan2Config={kernelName:Atan2,backendName:"webgl",kernelFunc:atan25};function avgPool3(args){const{inputs,backend:backend3,attrs}=args;const{x}=inputs;assertNotComplex2(x,"avgPool");const{filterSize,strides,pad:pad11,dimRoundingMode}=attrs;const dilations=1;util_exports.assert(backend_util_exports.eitherStridesOrDilationsAreOne(strides,dilations),()=>`Error in avgPool: Either strides or dilations must be 1. Got strides ${strides} and dilations '${dilations}'`);const convInfo=backend_util_exports.computePool2DInfo(x.shape,filterSize,strides,dilations,pad11,dimRoundingMode);if(convInfo.filterWidth===1&&convInfo.filterHeight===1&&util_exports.arraysEqual(convInfo.inShape,convInfo.outShape)){return identity3({inputs:{x},backend:backend3})}const avgPoolProgram=new Pool2DProgram(convInfo,"avg",false);return backend3.runWebGLProgram(avgPoolProgram,[x],"float32")}const avgPoolConfig2={kernelName:AvgPool,backendName:"webgl",kernelFunc:avgPool3};function avgPoolBackprop3(args){const{inputs,backend:backend3,attrs}=args;const{dy,input:input2}=inputs;const x=input2;assertNotComplex2([dy,input2],"avgPoolBackprop");const{filterSize,strides,pad:pad11}=attrs;const convInfo=backend_util_exports.computePool2DInfo(x.shape,filterSize,strides,1,pad11);const avgPoolBackpropProgram=new AvgPool2DBackpropProgram(convInfo);return backend3.runWebGLProgram(avgPoolBackpropProgram,[dy],x.dtype)}const avgPoolBackpropConfig2={kernelName:AvgPoolBackprop,backendName:"webgl",kernelFunc:avgPoolBackprop3};class BatchNormProgram{constructor(xShape,meanShape,varianceShape,offsetShape,scaleShape,varianceEpsilon){this.outputShape=[];this.variableNames=["x","mean","variance"];backend_util_exports.assertAndGetBroadcastShape(xShape,meanShape);backend_util_exports.assertAndGetBroadcastShape(xShape,varianceShape);let 
offsetSnippet="0.0";if(offsetShape!=null){backend_util_exports.assertAndGetBroadcastShape(xShape,offsetShape);this.variableNames.push("offset");offsetSnippet="getOffsetAtOutCoords()"}let scaleSnippet="1.0";if(scaleShape!=null){backend_util_exports.assertAndGetBroadcastShape(xShape,scaleShape);this.variableNames.push("scale");scaleSnippet="getScaleAtOutCoords()"}this.outputShape=xShape;this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
float x = getXAtOutCoords();
float mean = getMeanAtOutCoords();
float variance = getVarianceAtOutCoords();
float offset = ${offsetSnippet};
float scale = ${scaleSnippet};
float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
}
2020-11-17 16:18:15 +01:00
`}}class BatchNormPackedProgram{constructor(xShape,meanShape,varianceShape,offsetShape,scaleShape,varianceEpsilon){this.packedInputs=true;this.packedOutput=true;this.variableNames=["x","mean","variance"];backend_util_exports.assertAndGetBroadcastShape(xShape,meanShape);backend_util_exports.assertAndGetBroadcastShape(xShape,varianceShape);let offsetSnippet="vec4(0.0)";if(offsetShape!=null){backend_util_exports.assertAndGetBroadcastShape(xShape,offsetShape);this.variableNames.push("offset");offsetSnippet="getOffsetAtOutCoords()"}let scaleSnippet="vec4(1.0)";if(scaleShape!=null){backend_util_exports.assertAndGetBroadcastShape(xShape,scaleShape);this.variableNames.push("scale");scaleSnippet="getScaleAtOutCoords()"}this.outputShape=xShape;this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
vec4 offset = ${offsetSnippet};
vec4 scale = ${scaleSnippet};
vec4 x = getXAtOutCoords();
vec4 mean = getMeanAtOutCoords();
vec4 variance = getVarianceAtOutCoords();
vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
setOutput((x - mean) * inv + offset);
}
2020-11-17 16:18:15 +01:00
`}}const batchNorm3=({inputs,backend:backend3,attrs})=>{const{x,mean:mean7,variance,offset,scale:scale2}=inputs;util_exports.assert(mean7.shape.length===variance.shape.length,()=>"Batch normalization gradient requires mean and variance to have equal ranks.");util_exports.assert(offset==null||mean7.shape.length===offset.shape.length,()=>"Batch normalization gradient requires mean and offset to have equal ranks.");util_exports.assert(scale2==null||mean7.shape.length===scale2.shape.length,()=>"Batch normalization gradient requires mean and scale to have equal ranks.");let{varianceEpsilon}=attrs;if(varianceEpsilon==null){varianceEpsilon=.001}const finalInputs=[x,mean7,variance];let offsetShape=null;if(offset!=null){offsetShape=offset.shape;finalInputs.push(offset)}let scaleShape=null;if(scale2!=null){scaleShape=scale2.shape;finalInputs.push(scale2)}const program=env().getBool("WEBGL_PACK_NORMALIZATION")?new BatchNormPackedProgram(x.shape,mean7.shape,variance.shape,offsetShape,scaleShape,varianceEpsilon):new BatchNormProgram(x.shape,mean7.shape,variance.shape,offsetShape,scaleShape,varianceEpsilon);const output=backend3.runWebGLProgram(program,finalInputs,finalInputs[0].dtype);return output};const batchNormConfig2={kernelName:FusedBatchNorm,backendName:"webgl",kernelFunc:batchNorm3};const NOT_EQUAL=`return float(a != b);`;const notEqual3=binaryKernelFunc2({opSnippet:NOT_EQUAL,dtype:"bool"});const notEqualConfig2={kernelName:NotEqual,backendName:"webgl",kernelFunc:notEqual3};function real7(args){const{inputs,backend:backend3}=args;const{input:input2}=inputs;const inputData=backend3.texData.get(input2.dataId);return identity3({inputs:{x:inputData.complexTensorInfos.real},backend:backend3})}const realConfig2={kernelName:Real,backendName:"webgl",kernelFunc:real7};const TO_INT=`return float(int(x));`;function int(input2,backend3){const program=new UnaryOpProgram(input2.shape,TO_INT);const 
output=backend3.runWebGLProgram(program,[input2],"int32");return{dataId:output.dataId,shape:output.shape,dtype:output.dtype}}function cast50(args){const{inputs,backend:backend3,attrs}=args;const{x}=inputs;const{dtype}=attrs;if(dtype==="complex64"){if(x.dtype==="complex64"){return identity3({inputs:{x},backend:backend3})}const zerosTensor=zeros(x.shape);const floatX=cast50({inputs:{x},backend:backend3,attrs:{dtype:"float32"}});const result=complex10({inputs:{real:floatX,imag:zerosTensor},backend:backend3});zerosTensor.dispose();backend3.disposeIntermediateTensorInfo(floatX);return result}if(x.dtype==="complex64"){const realPart=real7({inputs:{input:x},backend:backend3});const result=cast50({inputs:{x:realPart},backend:backend3,attrs:{dtype}});backend3.disposeIntermediateTensorInfo(realPart);return result}if(!util_exports.hasEncodingLoss(x.dtype,dtype)){const result=identity3({inputs:{x},backend:backend3});return{dataId:result.dataId,shape:result.shape,dtype}}if(dtype==="int32"){return int(x,backend3)}if(dtype==="bool"){const zerosTensorInfo=backend3.makeTensorInfo([],"bool",util_exports.getTypedArrayFromDType("bool",1));const binaryInputs={a:x,b:zerosTensorInfo};const result=notEqual3({inputs:binaryInputs,backend:backend3});backend3.disposeIntermediateTensorInfo(zerosTensorInfo);return result}throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`)}const castConfig2={kernelName:Cast,backendName:"webgl",kernelFunc:cast50};class ConcatProgram{constructor(shapes){this.outputShape=[];this.outputShape=backend_util_exports.computeOutShape(shapes,1);this.variableNames=shapes.map((_,i)=>`T${i}`);const offsets=new Array(shapes.length-1);offsets[0]=shapes[0][1];for(let i=1;i<offsets.length;i++){offsets[i]=offsets[i-1]+shapes[i][1]}const snippets=[`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];for(let i=1;i<offsets.length;i++){const shift=offsets[i-1];snippets.push(`else if (yC < ${offsets[i]}) setOutput(getT${i}(yR, yC-${shift}));`)}const 
lastIndex=offsets.length;const lastShift=offsets[offsets.length-1];snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
ivec2 coords = getOutputCoords();
int yR = coords.x;
int yC = coords.y;
${snippets.join("\n ")}
}
2020-11-17 16:18:15 +01:00
`}}class ConcatPackedProgram{constructor(shapes,axis){this.packedInputs=true;this.packedOutput=true;this.outputShape=[];this.outputShape=backend_util_exports.computeOutShape(shapes,axis);const shape=this.outputShape;const rank=shape.length;const dtype=getCoordsDataType(rank);const coords2=getChannels("coords",rank);const channels=["x","y","z","w","u","v"].slice(0,rank);this.variableNames=shapes.map((_,i)=>`T${i}`);const offsets=new Array(shapes.length-1);offsets[0]=shapes[0][axis];for(let i=1;i<offsets.length;i++){offsets[i]=offsets[i-1]+shapes[i][axis]}const channel=channels[axis];const lastChannels=channels.slice(-2);const allChannels=channels.join();let getValueSnippet=`if (${channel} < ${offsets[0]}) {
2020-11-16 21:51:46 +01:00
return getChannel(
getT0(${allChannels}), vec2(${lastChannels.join()}));
}`;for(let i=1;i<offsets.length;i++){const shift2=offsets[i-1];getValueSnippet+=`
if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i-1]}) {
return getChannel(
getT${i}(${shiftedChannels(channels,channel,shift2)}),
vec2(${shiftedChannels(lastChannels,channel,shift2)}));
}`}const lastIndex=offsets.length;const shift=offsets[offsets.length-1];getValueSnippet+=`
return getChannel(
getT${lastIndex}(${shiftedChannels(channels,channel,shift)}),
vec2(${shiftedChannels(lastChannels,channel,shift)}));`;this.userCode=`
float getValue(${channels.map(x=>"int "+x)}) {
${getValueSnippet}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
${dtype} coords = getOutputCoords();
vec4 result = vec4(getValue(${coords2}), 0., 0., 0.);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${coords2[rank-1]} = ${coords2[rank-1]} + 1;
if (${coords2[rank-1]} < ${shape[rank-1]}) {
result.g = getValue(${coords2});
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${coords2[rank-2]} = ${coords2[rank-2]} + 1;
if (${coords2[rank-2]} < ${shape[rank-2]}) {
result.a = getValue(${coords2});
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${coords2[rank-1]} = ${coords2[rank-1]} - 1;
if (${coords2[rank-2]} < ${shape[rank-2]} &&
${coords2[rank-1]} < ${shape[rank-1]}) {
result.b = getValue(${coords2});
}
setOutput(result);
}
2020-11-17 16:18:15 +01:00
`}}function shiftedChannels(channels,channel,shift){const channelIdx=channels.indexOf(channel);const res=channels.map((c,idx)=>{if(idx===channelIdx){return`${c} - ${shift}`}else{return c}});return res.join()}function imag7(args){const{inputs,backend:backend3}=args;const{input:input2}=inputs;const inputData=backend3.texData.get(input2.dataId);return identity3({inputs:{x:inputData.complexTensorInfos.imag},backend:backend3})}const imagConfig2={kernelName:Imag,backendName:"webgl",kernelFunc:imag7};function packedReshape(input2,afterShape,backend3){const input3DShape=[getBatchDim(input2.shape),...getRowsCols(input2.shape)];const input3D={dtype:input2.dtype,shape:input3DShape,dataId:input2.dataId};const afterShapeAs3D=[getBatchDim(afterShape),...getRowsCols(afterShape)];const program=new ReshapePackedProgram(afterShapeAs3D,input3DShape);const preventEagerUnpackingOfOutput=true;const output=backend3.runWebGLProgram(program,[input3D],input2.dtype,null,preventEagerUnpackingOfOutput);return{dataId:output.dataId,shape:afterShape,dtype:output.dtype}}function reshape90(args){const{inputs,backend:backend3,attrs}=args;const{x}=inputs;const{shape}=attrs;const webglBackend=backend3;const xSize=util_exports.sizeFromShape(x.shape);const $shape=util_exports.inferFromImplicitShape(shape,xSize);const $xSize=util_exports.sizeFromShape($shape);util_exports.assert(xSize===$xSize,()=>`The new shape (${$shape}) has ${$xSize} elements and the old shape (${x.shape}) has ${xSize} elements. 
The new shape and old shape must have the same number of elements.`);const xTexData=webglBackend.texData.get(x.dataId);if(xTexData.isPacked&&!isReshapeFree(x.shape,$shape)&&!(xTexData.texture!==null&&isReshapeFree(xTexData.shape,$shape))){return packedReshape(x,$shape,webglBackend)}webglBackend.incRef(x.dataId);return{dataId:x.dataId,shape:$shape,dtype:x.dtype}}const reshapeConfig2={kernelName:Reshape,backendName:"webgl",kernelFunc:reshape90};function concatImpl(inputs,axis,backend3){const dtype=inputs[0].dtype;if(dtype==="complex64"){const reals=inputs.map(t=>real7({inputs:{input:t},backend:backend3}));const imags=inputs.map(t=>imag7({inputs:{input:t},backend:backend3}));const realConcated=concatImpl(reals,axis,backend3);const imagConcated=concatImpl(imags,axis,backend3);const result2=complex10({inputs:{real:realConcated,imag:imagConcated},backend:backend3});reals.forEach(r=>backend3.disposeIntermediateTensorInfo(r));imags.forEach(i=>backend3.disposeIntermediateTensorInfo(i));backend3.disposeIntermediateTensorInfo(realConcated);backend3.disposeIntermediateTensorInfo(imagConcated);return result2}if(inputs.length>env().getNumber("WEBGL_MAX_TEXTURES_IN_SHADER")){const midIndex=Math.floor(inputs.length/2);const leftSide=concatImpl(inputs.slice(0,midIndex),axis,backend3);const rightSide=concatImpl(inputs.slice(midIndex),axis,backend3);const result2=concatImpl([leftSide,rightSide],axis,backend3);backend3.disposeIntermediateTensorInfo(leftSide);backend3.disposeIntermediateTensorInfo(rightSide);return result2}if(env().getBool("WEBGL_PACK_ARRAY_OPERATIONS")&&inputs[0].shape.length>1){const program2=new ConcatPackedProgram(inputs.map(t=>t.shape),axis);return backend3.runWebGLProgram(program2,inputs,dtype)}const outShape=backend_util_exports.computeOutShape(inputs.map(t=>t.shape),axis);const tensors2D=inputs.map(x=>reshape90({inputs:{x},attrs:{shape:[-1,util_exports.sizeFromShape(x.shape.slice(axis))]},backend:backend3}));const program=new 
ConcatProgram(tensors2D.map(t=>t.shape));const result=backend3.runWebGLProgram(program,tensors2D,dtype);tensors2D.forEach(r=>backend3.disposeIntermediateTensorInfo(r));const reshapedResult=reshape90({inputs:{x:result},attrs:{shape:outShape},backend:backend3});backend3.disposeIntermediateTensorInfo(result);return reshapedResult}function concat18(args){const{inputs,backend:backend3,attrs}=args;const{axis}=attrs;const $axis=util_exports.parseAxisParam(axis,inputs[0].shape)[0];const outShape=backend_util_exports.computeOutShape(inputs.map(t=>t.shape),$axis);if(util_exports.sizeFromShape(outShape)===0){return backend3.makeTensorInfo(outSha
2020-11-16 21:51:46 +01:00
return cos(x);
2020-11-17 16:18:15 +01:00
`;const cos7=unaryKernelFunc2(COS);const cosConfig2={kernelName:Cos,backendName:"webgl",kernelFunc:cos7};const DIV=`
2020-11-16 21:51:46 +01:00
if (a == b) {
return 1.0;
2020-11-10 02:13:38 +01:00
};
2020-11-16 21:51:46 +01:00
return a / b;`;const DIV_PACKED=`
// vec4 one = vec4(equal(a, b));
// return one + (vec4(1.0) - one) * a / b;
vec4 result = a / b;
if(a.x == b.x) {
result.x = 1.;
}
if(a.y == b.y) {
result.y = 1.;
}
if(a.z == b.z) {
result.z = 1.;
}
if(a.w == b.w) {
result.w = 1.;
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
return result;
2020-11-17 16:18:15 +01:00
`;const div36=binaryKernelFunc2({opSnippet:DIV,packedOpSnippet:DIV_PACKED,checkOutOfBounds:true});const divConfig2={kernelName:Div,backendName:"webgl",kernelFunc:div36};class FFTProgram{constructor(component,inputShape,inverse){this.variableNames=["real","imag"];const innerDim=inputShape[1];this.outputShape=inputShape;const exponentMultiplierSnippet=inverse?`2.0 * ${Math.PI}`:`-2.0 * ${Math.PI}`;const resultDenominator=inverse?`${innerDim}.0`:"1.0";let opString;if(component==="real"){opString="return real * expR - imag * expI;"}else if(component==="imag"){opString="return real * expI + imag * expR;"}else{throw new Error(`FFT component must be either "real" or "imag", got ${component}.`)}this.userCode=`
2020-11-16 21:51:46 +01:00
const float exponentMultiplier = ${exponentMultiplierSnippet};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float unaryOpComplex(float real, float expR, float imag, float expI) {
${opString}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float mulMatDFT(int batch, int index) {
float indexRatio = float(index) / float(${innerDim});
float exponentMultiplierTimesIndexRatio =
exponentMultiplier * indexRatio;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float result = 0.0;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
for (int i = 0; i < ${innerDim}; i++) {
// x = (-2|2 * PI / N) * index * i;
float x = exponentMultiplierTimesIndexRatio * float(i);
float expR = cos(x);
float expI = sin(x);
float real = getReal(batch, i);
float imag = getImag(batch, i);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
result +=
unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
}
return result;
}
void main() {
ivec2 coords = getOutputCoords();
setOutput(mulMatDFT(coords[0], coords[1]));
}
2020-11-17 16:18:15 +01:00
`}}function fftImpl2(x,inverse,backend3){const xData=backend3.texData.get(x.dataId);const inputSize=util_exports.sizeFromShape(x.shape);const innerDimensionSize=x.shape[x.shape.length-1];const batch=inputSize/innerDimensionSize;const input2D=reshape90({inputs:{x},backend:backend3,attrs:{shape:[batch,innerDimensionSize]}});const xShape=input2D.shape;const realProgram=new FFTProgram("real",xShape,inverse);const imagProgram=new FFTProgram("imag",xShape,inverse);const inputs=[{dataId:xData.complexTensorInfos.real.dataId,dtype:xData.complexTensorInfos.real.dtype,shape:xShape},{dataId:xData.complexTensorInfos.imag.dataId,dtype:xData.complexTensorInfos.imag.dtype,shape:xShape}];const realPart=backend3.runWebGLProgram(realProgram,inputs,"float32");const imagPart=backend3.runWebGLProgram(imagProgram,inputs,"float32");const complexOutput=complex10({inputs:{real:realPart,imag:imagPart},backend:backend3});backend3.disposeIntermediateTensorInfo(realPart);backend3.disposeIntermediateTensorInfo(imagPart);const complexOutputReshaped=reshape90({inputs:{x:complexOutput},backend:backend3,attrs:{shape:x.shape}});backend3.disposeIntermediateTensorInfo(complexOutputReshaped);return complexOutputReshaped}function fft7(args){const{inputs,backend:backend3}=args;const{input:input2}=inputs;return fftImpl2(input2,false,backend3)}const fftConfig2={kernelName:FFT,backendName:"webgl",kernelFunc:fft7};class FlipLeftRightProgram{constructor(imageShape){this.variableNames=["Image"];this.outputShape=[];const imageWidth=imageShape[2];this.outputShape=imageShape;this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];
int coordX = ${imageWidth} - x;
float outputValue;
if(coordX >= 0 && coordX < ${imageWidth}) {
outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
} else {
outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
}
setOutput(outputValue);
}
2020-11-17 16:18:15 +01:00
`}}const flipLeftRightConfig2={kernelName:FlipLeftRight,backendName:"webgl",kernelFunc:({inputs,backend:backend3})=>{const{image:image4}=inputs;const webglBackend=backend3;const program=new FlipLeftRightProgram(image4.shape);const output=webglBackend.runWebGLProgram(program,[image4],image4.dtype);return output}};class FromPixelsProgram{constructor(outputShape){this.variableNames=["A"];const glsl=getGlslDifferences();const[height,width]=outputShape;this.outputShape=outputShape;this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
ivec3 coords = getOutputCoords();
int texR = coords[0];
int texC = coords[1];
int depth = coords[2];
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 values = ${glsl.texture2D}(A, uv);
float value;
if (depth == 0) {
value = values.r;
} else if (depth == 1) {
value = values.g;
} else if (depth == 2) {
value = values.b;
} else if (depth == 3) {
value = values.a;
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
setOutput(floor(value * 255.0 + 0.5));
}
`}}class FromPixelsPackedProgram{constructor(outputShape){this.variableNames=["A"];this.packedInputs=false;this.packedOutput=true;const glsl=getGlslDifferences();const[height,width]=outputShape;this.outputShape=outputShape;this.userCode=`
void main() {
ivec3 coords = getOutputCoords();
int texR = coords[0];
int texC = coords[1];
int depth = coords[2];
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec4 result = vec4(0.);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
for(int row=0; row<=1; row++) {
for(int col=0; col<=1; col++) {
texC = coords[1] + row;
depth = coords[2] + col;
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${width}.0, ${height}.0);
vec4 values = ${glsl.texture2D}(A, uv);
float value;
if (depth == 0) {
value = values.r;
} else if (depth == 1) {
value = values.g;
} else if (depth == 2) {
value = values.b;
} else if (depth == 3) {
value = values.a;
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
result[row * 2 + col] = floor(value * 255.0 + 0.5);
}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${glsl.output} = result;
}
2020-11-17 16:18:15 +01:00
`}}const fromPixelsConfig={kernelName:FromPixels,backendName:"webgl",kernelFunc:fromPixels2};let fromPixels2DContext2;function fromPixels2(args){const{inputs,backend:backend3,attrs}=args;let{pixels}=inputs;const{numChannels}=attrs;const isVideo=typeof HTMLVideoElement!=="undefined"&&pixels instanceof HTMLVideoElement;const isImage=typeof HTMLImageElement!=="undefined"&&pixels instanceof HTMLImageElement;const[width,height]=isVideo?[pixels.videoWidth,pixels.videoHeight]:[pixels.width,pixels.height];const texShape=[height,width];const outShape=[height,width,numChannels];if(isImage||isVideo){if(fromPixels2DContext2==null){fromPixels2DContext2=document.createElement("canvas").getContext("2d")}fromPixels2DContext2.canvas.width=width;fromPixels2DContext2.canvas.height=height;fromPixels2DContext2.drawImage(pixels,0,0,width,height);pixels=fromPixels2DContext2.canvas}const tempPixelHandle=backend3.makeTensorInfo(texShape,"int32");backend3.texData.get(tempPixelHandle.dataId).usage=TextureUsage.PIXELS;backend3.gpgpu.uploadPixelDataToTexture(backend3.getTexture(tempPixelHandle.dataId),pixels);const program=env().getBool("WEBGL_PACK")?new FromPixelsPackedProgram(outShape):new FromPixelsProgram(outShape);const res=backend3.runWebGLProgram(program,[tempPixelHandle],"int32");backend3.disposeData(tempPixelHandle.dataId);return res}function ifft7(args){const{inputs,backend:backend3}=args;const{input:input2}=inputs;return fftImpl2(input2,true,backend3)}const ifftConfig2={kernelName:IFFT,backendName:"webgl",kernelFunc:ifft7};class MeanProgram{constructor(reduceInfo,divisor){this.variableNames=["x"];const{windowSize,batchSize,inSize,outSize}=reduceInfo;this.outputShape=[batchSize,outSize];const windowSizeNearestVec4=Math.floor(windowSize/4)*4;const windowSizeVec4Remainder=windowSize%4;let updateSnippet=`sumValue += dot(values, ones);`;if(divisor!=null){const denominator=1/divisor;updateSnippet=`sumValue += dot(values * 
${util_exports.isInt(denominator)?denominator.toPrecision(2):denominator}, ones);`}let checkOutOfBounds="";if(inSize%windowSize>0){checkOutOfBounds=`
2020-11-16 21:51:46 +01:00
if (inIdx < 0 || inIdx >= ${inSize}) {
return 0.0;
}
`}this.userCode=`
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float getValue(int batch, int inIdx) {
${checkOutOfBounds}
return getX(batch, inIdx);
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
float sumValue = 0.0;
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
${updateSnippet}
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder===1}) {
vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
${updateSnippet}
} else if (${windowSizeVec4Remainder===2}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1), 0.0, 0.0);
${updateSnippet}
} else if (${windowSizeVec4Remainder===3}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2), 0.0);
${updateSnippet}
}
setOutput(sumValue);
}
2020-11-17 16:18:15 +01:00
`}}function getReductionStages(inShape){const stages=[];while(stages.length===0||stages[stages.length-1].outSize!==1){const outSize=stages.length?stages[stages.length-1].outSize:inShape[1];const windowSize=backend_util_exports.computeOptimalWindowSize(outSize);stages.push({inSize:outSize,windowSize,outSize:Math.ceil(outSize/windowSize)})}return stages}function reduce(x,dtype,reductionType,backend3){const reductionStages=getReductionStages(x.shape);let result=x;for(let i=0;i<reductionStages.length;i++){const{inSize,windowSize,outSize}=reductionStages[i];let program;let previousResult;if(reductionType==="mean"){program=i===0?new MeanProgram({windowSize,inSize,batchSize:x.shape[0],outSize},inSize):new MeanProgram({windowSize,inSize,batchSize:x.shape[0],outSize})}else{program=new ReduceProgram({windowSize,inSize,batchSize:x.shape[0],outSize},reductionType)}previousResult=result;result=backend3.runWebGLProgram(program,[result],dtype);if(previousResult.dataId!==x.dataId){backend3.disposeIntermediateTensorInfo(previousResult)}}return result}function maxImpl2(x,reduceShape,outShape,backend3){const inSize=util_exports.sizeFromShape(reduceShape);const xSize=util_exports.sizeFromShape(x.shape);const batchSize=xSize/inSize;const reshapedInput=reshape90({inputs:{x},attrs:{shape:[batchSize,inSize]},backend:backend3});const reduced=reduce(reshapedInput,x.dtype,"max",backend3);const reshapedOutput=reshape90({inputs:{x:reduced},attrs:{shape:outShape},backend:backend3});backend3.disposeIntermediateTensorInfo(reshapedInput);backend3.disposeIntermediateTensorInfo(reduced);return reshapedOutput}class TransposeProgram{constructor(aShape,newDim){this.variableNames=["A"];const outputShape=new Array(aShape.length);for(let i=0;i<outputShape.length;i++){outputShape[i]=aShape[newDim[i]]}this.outputShape=outputShape;this.rank=outputShape.length;const dtype=getCoordsDataType(this.rank);const switched=getSwitchedCoords(newDim);this.userCode=`
2020-11-16 21:51:46 +01:00
void main() {
${dtype} resRC = getOutputCoords();
setOutput(getA(${switched}));
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
`}}function getSwitchedCoords(newDim){const rank=newDim.length;if(rank>6){throw Error(`Transpose for rank ${rank} is not yet supported`)}const originalOrder=["resRC.x","resRC.y","resRC.z","resRC.w","resRC.u","resRC.v"];const switchedCoords=new Array(rank);for(let i=0;i<newDim.length;i++){switchedCoords[newDim[i]]=originalOrder[i]}return switchedCoords.join()}class TransposePackedProgram{constructor(aShape,newDim){this.variableNames=["A"];this.packedInputs=true;this.packedOutput=true;const outputShape=new Array(aShape.length);for(let i=0;i<outputShape.length;i++){outputShape[i]=aShape[newDim[i]]}this.outputShape=outputShape;this.rank=outputShape.length;if(this.rank>6){throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`)}const dtype=getCoordsDataType(this.rank);const outputOrder=getVecChannels("rc",this.rank);const switchedOrder=new Array(this.rank);for(let i=0;i<newDim.length;i++){switchedOrder[newDim[i]]=outputOrder[i]}const innerDims=`vec2(${switchedOrder.slice(-2).join()})`;const nextColumn=`++${outputOrder[this.rank-1]} < ${outputShape[this.rank-1]}`;const getc=`getChannel(getA(${switchedOrder.join()}), ${innerDims})`;this.userCode=`
void main() {
${dtype} rc = getOutputCoords();
vec4 result = vec4(0.);
result[0] = ${getc};
if(${nextColumn}) {
result[1] = ${getc};
}
--${outputOrder[this.rank-1]};
if(++${outputOrder[this.rank-2]} < ${outputShape[this.rank-2]}) {
result[2] = ${getc};
if(${nextColumn}) {
result[3] = ${getc};
}
2020-11-10 02:13:38 +01:00
}
2020-11-16 21:51:46 +01:00
setOutput(result);
2020-11-10 02:13:38 +01:00
}
2020-11-17 16:18:15 +01:00
`}}function transposeImpl2(x,perm,backend3){const program=env().getBool("WEBGL_PACK_ARRAY_OPERATIONS")?new TransposePackedProgram(x.shape,perm):new TransposeProgram(x.shape,perm);return backend3.runWebGLProgram(program,[x],x.dtype)}const maxConfig2={kernelName:Max,backendName:"webgl",kernelFunc:({inputs,attrs,backend:backend3})=>{const{x}=inputs;const{reductionIndices,keepDims}=attrs;const webglBackend=backend3;const xRank=x.shape.length;const origAxes=util_exports.parseAxisParam(reductionIndices,x.shape);let axes=origAxes;const permutedAxes=backend_util_exports.getAxesPermutation(axes,xRank);const maxInputIsTransposed=permutedAxes!=null;const shouldExecuteOnCPU=webglBackend.shouldExecuteOnCPU([x]);let maxInput=x;if(maxInputIsTransposed){if(shouldExecuteOnCPU){const xTexData=webglBackend.texData.get(maxInput.dataId);const values=xTexData.values;const newShape=new Array(xRank);for(let i=0;i<newShape.length;i++){newShape[i]=x.shape[permutedAxes[i]]}const maxInputValues=transposeImplCPU(values,x.shape,x.dtype,permutedAxes,newShape);maxInput=webglBackend.makeTensorInfo(newShape,x.dtype);const maxInputData=webglBackend.texData.get(maxInput.dataId);maxInputData.values=maxInputValues}else{maxInput=transposeImpl2(x,permutedAxes,webglBackend)}axes=backend_util_exports.getInnerMostAxes(axes.length,xRank)}backend_util_exports.assertAxesAreInnerMostDims("max",axes,xRank);const[maxOutShape,reduceShape]=backend_util_exports.computeOutAndReduceShapes(maxInput.shape,axes);let outShape=maxOutShape;if(keepDims){outShape=backend_util_exports.expandShapeToKeepDim(maxOutShape,origAxes)}let out;if(shouldExecuteOnCPU){const xTexData=webglBackend.texData.get(maxInput.dataId);const values=xTexData.values;const outValues=maxImplCPU(values,util_exports.sizeFromShape(reduceShape),outShape,x.dtype);out=webglBackend.makeTensorInfo(outShape,x.dtype);const 
outData=webglBackend.texData.get(out.dataId);outData.values=outValues}else{out=maxImpl2(maxInput,reduceShape,outShape,webglBackend)}if(maxInputIsTransposed){webglBackend.disposeIntermediateTensorInfo(maxInput)}return out}};function maxPool3(args){const{inputs,backend:backend3,attrs}=args;const{x}=inputs;assertNotComplex2(x,"maxPool");const{filterSize,strides,pad:pad11,dimRoundingMode}=attrs;const dilations=1;util_exports.assert(backend_util_exports.eitherStridesOrDilationsAreOne(strides,dilations),()=>`Error in maxPool: Either strides or dilations must be 1. Got strides ${strides} and dilations '${dilations}'`);const convInfo=backend_util_exports.computePool2DInfo(x.shape,filterSize,strides,dilations,pad11,dimRoundingMode);if(convInfo.filterWidth===1&&convInfo.filterHeight===1&&util_exports.arraysEqual(convInfo.inShape,convInfo.outShape)){return identity3({inputs:{x},backend:backend3})}const maxPoolProgram=new Pool2DProgram(convInfo,"max",false);return backend3.runWebGLProgram(maxPoolProgram,[x],x.dtype)}const maxPoolConfig2={kernelName:MaxPool,backendName:"webgl",kernelFunc:maxPool3};function maxPoolBackprop3(args){const{inputs,backend:backend3,attrs}=args;const{dy,input:input2,output}=inputs;const x=input2;assertNotComplex2([input2,output],"maxPoolBackprop");const{filterSize,strides,pad:pad11,dimRoundingMode}=attrs;const convInfo=backend_util_exports.computePool2DInfo(x.shape,filterSize,strides,1,pad11,dimRoundingMode);const getPositions=true;const maxPoolPositionsProgram=new Pool2DProgram(convInfo,"max",getPositions);const maxPoolPositions2=backend3.runWebGLProgram(maxPoolPositionsProgram,[x],x.dtype);const maxPoolBackPropProgram=new MaxPool2DBackpropProgram(convInfo);const result=backend3.runWebGLProgram(maxPoolBackPropProgram,[dy,maxPoolPositions2],x.dtype);backend3.disposeIntermediateTensorInfo(maxPoolPositions2);return result}const maxPoolBackpropConfig2={kernelName:MaxPoolBackprop,backendName:"webgl",kernelFunc:maxPoolBackprop3};function 
maxPoolWithArgmaxImpl2(x,includeBatchInIndex,convInfo,backend3){let program=new Pool2DProgram(convInfo,"max",false);const poolOutput=backend3.runWebGLProgram(program,[x],"float32");program=new Pool2DProgram(convInfo,"max",true,true,includeBatchInIndex);c
2020-11-16 21:51:46 +01:00
int start = ${start};
int end = ${end};
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
int outC = getOutputCoords();
if (outC < start) {
outC = start * 2 - outC - ${offset};
} else if(outC >= end) {
outC = (end - 1) * 2 - outC + ${offset};
}
setOutput(getX(outC - start));
}
`;return}this.userCode=`
${dtype} start = ${dtype}(${start});
${dtype} end = ${dtype}(${end});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
${dtype} outC = getOutputCoords();
for (int i = 0; i < ${rank}; i++) {
if (outC[i] < start[i]) {
outC[i] = start[i] * 2 - outC[i] - ${offset};
} else if(outC[i] >= end[i]) {
outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
}
}
${dtype} coords = outC - start;
setOutput(getX(${unpackedCoords}));
}
`}}class MirrorPadPackedProgram{constructor(xShape,paddings,mode){this.variableNames=["x"];this.packedInputs=true;this.packedOutput=true;this.outputShape=paddings.map((p2,i)=>p2[0]+xShape[i]+p2[1]);const rank=xShape.length;const dtype=getCoordsDataType(rank);const start=paddings.map(p2=>p2[0]).join(",");const end=paddings.map((p2,i)=>p2[0]+xShape[i]).join(",");const coords2=getChannels("rc",rank);const source=getChannels("source",rank);const cLimit=`${coords2[rank-1]} < ${this.outputShape[rank-1]}`;const innerDims=rank===1?"source":`vec2(${source.slice(-2).join()})`;const offset=mode==="reflect"?0:1;let mainLoop="";if(rank===1){const padSetup=`
${dtype} source = rc;
if (source < start) {
source = start * 2 - source - ${offset};
} else if (source >= end) {
source = (end - 1) * 2 - source + ${offset};
}
source -= start;
`;mainLoop=`
${dtype} rc = outputLoc;
${padSetup}
result[0] = getChannel(getX(${source.join()}), ${innerDims});
${coords2[rank-1]} += 1;
if(${cLimit}) {
${padSetup}
result[1] = getChannel(getX(${source.join()}), ${innerDims});
}
`}else{const padSetup=`
${dtype} source = rc;
${dtype} lt = ${dtype}(lessThan(source, start));
${dtype} gte = ${dtype}(greaterThanEqual(source, end));
${dtype} orig = 1 - (lt + gte);
source = orig * source +
lt * (start * 2 - source - ${offset}) +
gte * ((end - 1) * 2 - source + ${offset});
source -= start;
`;mainLoop=`
${dtype} rc = outputLoc;
${padSetup}
result[0] = getChannel(getX(${source.join()}), ${innerDims});
${coords2[rank-1]} += 1;
if(${cLimit}) {
${padSetup}
result[1] = getChannel(getX(${source.join()}), ${innerDims});
}
rc = outputLoc;
${coords2[rank-2]} += 1;
if(${coords2[rank-2]} < ${this.outputShape[rank-2]}) {
${padSetup}
result[2] = getChannel(getX(${source.join()}), ${innerDims});
${coords2[rank-1]} += 1;
if(${cLimit}) {
${padSetup}
result[3] = getChannel(getX(${source.join()}), ${innerDims});
}
}
`}this.userCode=`
const ${dtype} start = ${dtype}(${start});
const ${dtype} end = ${dtype}(${end});
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
${dtype} outputLoc = getOutputCoords();
vec4 result = vec4(0.);
${mainLoop}
setOutput(result);
}
2020-11-17 16:18:15 +01:00
`}}const mirrorPadKernelFunc=({inputs,backend:backend3,attrs})=>{const{x}=inputs;const{paddings,mode}=attrs;const program=env().getBool("WEBGL_PACK_ARRAY_OPERATIONS")?new MirrorPadPackedProgram(x.shape,paddings,mode):new MirrorPadProgram(x.shape,paddings,mode);const output=backend3.runWebGLProgram(program,[x],x.dtype);return output};const mirrorPadConfig2={kernelName:MirrorPad,backendName:"webgl",kernelFunc:mirrorPadKernelFunc};const COMPLEX_MULTIPLY={REAL:"return areal * breal - aimag * bimag;",IMAG:"return areal * bimag + aimag * breal;"};class BinaryOpComplexProgram{constructor(op2,aShape,bShape){this.variableNames=["AReal","AImag","BReal","BImag"];this.outputShape=backend_util_exports.assertAndGetBroadcastShape(aShape,bShape);this.userCode=`
2020-11-16 21:51:46 +01:00
float binaryOpComplex(
float areal, float aimag, float breal, float bimag) {
2020-11-17 16:18:15 +01:00
${op2}
2020-11-16 21:51:46 +01:00
}
2020-11-10 02:13:38 +01:00
2020-11-16 21:51:46 +01:00
void main() {
float areal = getARealAtOutCoords();
float aimag = getAImagAtOutCoords();
float breal = getBRealAtOutCoords();
float bimag = getBImagAtOutCoords();
setOutput(binaryOpComplex(areal, aimag, breal, bimag));
}
2020-11-17 16:18:15 +01:00
`}}const MUL="return a * b;";function multiply3(args){const{inputs,backend:backend3}=args;const{a,b}=inputs;const dtype=backend_util_exports.upcastType(a.dtype,b.dtype);if(a.dtype==="complex64"){const aData=backend3.texData.get(a.dataId);const bData=backend3.texData.get(b.dataId);const realProgram=new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL,a.shape,b.shape);const imagProgram=new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG,a.shape,b.shape);const inputs2=[{dataId:aData.complexTensorInfos.real.dataId,dtype:aData.complexTensorInfos.real.dtype,shape:a.shape},{dataId:aData.complexTensorInfos.imag.dataId,dtype:aData.complexTensorInfos.imag.dtype,shape:a.shape},{dataId:bData.complexTensorInfos.real.dataId,dtype:bData.complexTensorInfos.real.dtype,shape:b.shape},{dataId:bData.complexTensorInfos.imag.dataId,dtype:bData.complexTensorInfos.imag.dtype,shape:b.shape}];const realPart=backend3.runWebGLProgram(realProgram,inputs2,"float32");const imagPart=backend3.runWebGLProgram(imagProgram,inputs2,"float32");const complexOutput=complex10({inputs:{real:realPart,imag:imagPart},backend:backend3});backend3.disposeIntermediateTensorInfo(realPart);backend3.disposeIntermediateTensorInfo(imagPart);return complexOutput}if(backend3.shouldExecuteOnCPU([a,b])){const aData=backend3.texData.get(a.dataId);const bData=backend3.texData.get(b.dataId);const[outValues,outShape]=multiplyImplCPU(a.shape,b.shape,aData.values,bData.values,dtype);const out=backend3.makeTensorInfo(outShape,dtype);const outData=backend3.texData.get(out.dataId);outData.values=outValues;return out}let program;if(env().getBool("WEBGL_PACK_BINARY_OPERATIONS")){program=new BinaryOpPackedProgram(MUL,a.shape,b.shape)}else{program=new BinaryOpProgram(MUL,a.shape,b.shape)}return backend3.runWebGLProgram(program,[a,b],dtype)}const multiplyConfig2={kernelName:Multiply,backendName:"webgl",kernelFunc:multiply3};const 
nonMaxSuppressionV3Config={kernelName:NonMaxSuppressionV3,backendName:"webgl",kernelFunc:({inputs,backend:backend3,attrs})=>{backend_util_exports.warn("tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead");const{boxes,scores}=inputs;const{maxOutputSize,iouThreshold,scoreThreshold}=attrs;const gpuBackend=backend3;const boxesVals=gpuBackend.readSync(boxes.dataId);const scoresVals=gpuBackend.readSync(scores.dataId);const maxOutputSizeVal=maxOutputSize;const iouThresholdVal=iouThreshold;const scoreThresholdVal=scoreThreshold;return kernel_impls_exports.nonMaxSuppressionV3Impl(boxesVals,scoresVals,maxOutputSizeVal,iouThresholdVal,scoreThresholdVal)}};const nonMaxSuppressionV4Impl3=kernel_impls_exports.nonMaxSuppressionV4Impl;const nonMaxSuppressionV4Config2={kernelName:NonMaxSuppressionV4,backendName:"webgl",kernelFunc:({inputs,backend:backend3,attrs})=>{backend_util_exports.warn("tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead");const{boxes,scores}=inputs;const{maxOutputSize,iouThreshold,scoreThreshold,padToMaxOutputSize}=attrs;const gpuBackend=backend3;const boxesVals=gpuBackend.readSync(boxes.dataId);const scoresVals=gpuBackend.readSync(scores.dataId);const{selectedIndices,validOutputs}=nonMaxSuppressionV4Impl3(boxesVals,scoresVals,maxOutputSize,iouThreshold,scoreThreshold,padToMaxOutputSize);return[selectedIndices,validOutputs]}};const nonMaxSuppressionV5Impl3=kernel_impls_exports.nonMaxSuppressionV5Impl;const nonMaxSuppressionV5Config2={kernelName:NonMaxSuppressionV5,backendName:"webgl",kernelFunc:({inputs,backend:backend3,attrs})=>{backend_util_exports.warn("tf.nonMaxSuppression() in webgl locks the UI thread. 
Call tf.nonMaxSuppressionAsync() instead");const{boxes,scores}=inputs;const{maxOutputSize,iouThreshold,scoreThreshold,softNmsSigma}=attrs;const gpuBackend=backend3;const boxesVals=gpuBackend.readSync(boxes.dataId);const scoresVals=gpuBackend.readSync(scores.dataId);const maxOutputSizeVal=maxOutputSize;const iouThresholdVal=iouThreshold;const scoreThresholdVal=scoreThreshold;const softNmsSigmaVal=softNmsSigma;const{selectedIndices,selectedScores}=nonMaxSuppressionV5Impl3(boxe
2020-11-16 21:51:46 +01:00
vec3 fill = vec3(${fillValue.join(",")});
float outputValue = fill[coords[3]];`}this.userCode=`
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];
int y = coords[1];
float coordXFloat = (float(x) - ${centerXString}) * ${cosFactor} - (float(y) - ${centerYString}) * ${sinFactor};
float coordYFloat = (float(x) - ${centerXString}) * ${sinFactor} + (float(y) - ${centerYString}) * ${cosFactor};
int coordX = int(round(coordXFloat + ${centerXString}));
int coordY = int(round(coordYFloat + ${centerYString}));
${fillSnippet}
if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
outputValue = getImage(coords[0], coordY, coordX, coords[3]);
}
setOutput(outputValue);
}
2020-11-17 16:18:15 +01:00
`}}const rotateWithOffsetConfig2={kernelName:RotateWithOffset,backendName:"webgl",kernelFunc:({inputs,attrs,backend:backend3})=>{const{image:image4}=inputs;const{radians,fillValue,center}=attrs;const webglBackend=backend3;const program=new RotateProgram(image4.shape,radians,fillValue,center);const output=webglBackend.runWebGLProgram(program,[image4],image4.dtype);return output}};const SIN=CHECK_NAN_SNIPPET_UNARY+`
2020-11-16 21:51:46 +01:00
return sin(x);
2020-11-17 16:18:15 +01:00
`;const sin6=unaryKernelFunc2(SIN);const sinConfig2={kernelName:Sin,backendName:"webgl",kernelFunc:sin6};const SQUARE=`return x * x;`;const square25=unaryKernelFunc2(SQUARE);const squareConfig2={kernelName:Square,backendName:"webgl",kernelFunc:square25};const SQUARED_DIFFERENCE="return (a - b) * (a - b);";const squaredDifference3=binaryKernelFunc2({opSnippet:SQUARED_DIFFERENCE,packedOpSnippet:SQUARED_DIFFERENCE});const squaredDifferenceConfig2={kernelName:SquaredDifference,backendName:"webgl",kernelFunc:squaredDifference3};const SUB="return a - b;";const subKernelFunc=binaryKernelFunc2({opSnippet:SUB,packedOpSnippet:SUB,supportsComplex:true,cpuKernelImpl:subImplCPU});const subConfig2={kernelName:Sub,backendName:"webgl",kernelFunc:subKernelFunc};const TAN=`return tan(x);`;const tan5=unaryKernelFunc2(TAN);const tanConfig2={kernelName:Tan,backendName:"webgl",kernelFunc:tan5};const transposeConfig2={kernelName:Transpose,backendName:"webgl",kernelFunc:({inputs,attrs,backend:backend3})=>{const{x}=inputs;const{perm}=attrs;const webglBackend=backend3;const xRank=x.shape.length;const newShape=new Array(xRank);for(let i=0;i<newShape.length;i++){newShape[i]=x.shape[perm[i]]}let out;if(webglBackend.shouldExecuteOnCPU([x])){const xTexData=webglBackend.texData.get(x.dataId);const values=xTexData.values;const outValues=transposeImplCPU(values,x.shape,x.dtype,perm,newShape);out=webglBackend.makeTensorInfo(newShape,x.dtype);const outData=webglBackend.texData.get(out.dataId);outData.values=outValues}else{out=transposeImpl2(x,perm,webglBackend)}return out}};function unique7(args){const{inputs,attrs,backend:backend3}=args;const{axis}=attrs;const{x}=inputs;assertNotComplex2(x,"unique");console.warn("WARNING: ","UI might be locked temporarily as data is being downloaded");const 
values=backend3.readSync(x.dataId);const{outputValues,outputShape,indices}=uniqueImplCPU(values,axis,x.shape,x.dtype);return[backend3.makeTensorInfo(outputShape,x.dtype,outputValues),backend3.makeTensorInfo([indices.length],"int32",indices)]}const uniqueConfig2={kernelName:Unique,backendName:"webgl",kernelFunc:unique7};const kernelConfigs2=[addConfig2,atan2Config,avgPoolConfig2,avgPoolBackpropConfig2,batchNormConfig2,castConfig2,complexConfig2,concatConfig2,cosConfig2,divConfig2,fftConfig2,flipLeftRightConfig2,fromPixelsConfig,identityConfig2,ifftConfig2,imagConfig2,maxConfig2,maxPoolConfig2,maxPoolBackpropConfig2,maxPoolWithArgmaxConfig2,meanConfig,mirrorPadConfig2,multiplyConfig2,nonMaxSuppressionV3Config,nonMaxSuppressionV4Config2,nonMaxSuppressionV5Config2,notEqualConfig2,realConfig2,reshapeConfig2,rotateWithOffsetConfig2,sinConfig2,squareConfig2,subConfig2,squaredDifferenceConfig2,tanConfig2,transposeConfig2,uniqueConfig2];for(const kernelConfig of kernelConfigs2){registerKernel(kernelConfig)}const version14="2.7.0";const version16={"tfjs-core":version,"tfjs-backend-cpu":version10,"tfjs-backend-webgl":version12,"tfjs-data":version8,"tfjs-layers":version2,"tfjs-converter":version6,tfjs:version14};const dist_exports3={};__export(dist_exports3,{BackendWasm:()=>BackendWasm,setWasmPath:()=>setWasmPath,setWasmPaths:()=>setWasmPaths,version_wasm:()=>version17});var CppDType;(function(CppDType2){CppDType2[CppDType2["float32"]=0]="float32";CppDType2[CppDType2["int32"]=1]="int32";CppDType2[CppDType2["bool"]=2]="bool";CppDType2[CppDType2["string"]=3]="string";CppDType2[CppDType2["complex64"]=4]="complex64"})(CppDType||(CppDType={}));var FusableActivation;(function(FusableActivation2){FusableActivation2[FusableActivation2["linear"]=0]="linear";FusableActivation2[FusableActivation2["relu"]=1]="relu";FusableActivation2[FusableActivation2["relu6"]=2]="relu6";FusableActivation2[FusableActivation2["prelu"]=3]="prelu"})(FusableActivation||(FusableActivation={}));let 
wasmFusedMatMul;function setup(backend3){wasmFusedMatMul=backend3.wasm.cwrap(_FusedMatMul,null,["number","array","number","number","array","number","number","number","number","number","number","number"])}function fusedBatchMatMul(args){const{inputs,backend:backend3,attrs}=args;const{a,b,bias,preluActivationWeights}=inp
2020-11-08 18:32:31 +01:00
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2017 Google LLC. All Rights Reserved.
2020-11-08 18:32:31 +01:00
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2018 Google LLC
2020-11-10 02:13:38 +01:00
*
2020-11-16 21:51:46 +01:00
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
2020-11-08 18:32:31 +01:00
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2018 Google LLC. All Rights Reserved.
2020-11-08 18:32:31 +01:00
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
2020-11-16 21:51:46 +01:00
*
2020-11-08 18:32:31 +01:00
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2018 Google LLC. All Rights Reserved.
2020-11-08 18:32:31 +01:00
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2019 Google LLC
2020-11-08 18:32:31 +01:00
*
2020-11-16 21:51:46 +01:00
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
2020-11-08 18:32:31 +01:00
* =============================================================================
*/
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2019 Google LLC. All Rights Reserved.
2020-11-08 18:32:31 +01:00
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2020 Google Inc. All Rights Reserved.
2020-11-08 18:32:31 +01:00
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
2020-11-16 21:51:46 +01:00
* Copyright 2020 Google LLC
2020-11-10 02:13:38 +01:00
*
2020-11-16 21:51:46 +01:00
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
2020-11-08 18:32:31 +01:00
* =============================================================================
*/
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
2020-11-16 21:51:46 +01:00
* Licensed under the Apache License, Version 2.0 (the License);
2020-11-08 18:32:31 +01:00
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
2020-11-16 21:51:46 +01:00
* http://www.apache.org/licenses/LICENSE-2.0
2020-11-08 18:32:31 +01:00
*
* Unless required by applicable law or agreed to in writing, software
2020-11-16 21:51:46 +01:00
* distributed under the License is distributed on an AS IS BASIS,
2020-11-08 18:32:31 +01:00
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
2020-11-16 21:51:46 +01:00
/** @license See the LICENSE file. */
2020-10-12 01:22:43 +02:00
//# sourceMappingURL=human.esm.js.map