mirror of https://github.com/vladmandic/human
2230 lines
5.2 MiB
JavaScript
2230 lines
5.2 MiB
JavaScript
![]() |
(function (global, factory) {
|
||
|
typeof exports === 'object' && typeof module !== 'undefined' ? factory(require('fs'), require('path'), require('@tensorflow/tfjs-core')) :
|
||
|
typeof define === 'function' && define.amd ? define(['fs', 'path', '@tensorflow/tfjs-core'], factory) :
|
||
|
(factory(null,null,global.tf));
|
||
|
}(this, (function (fs,path,tf) { 'use strict';
|
||
|
|
||
|
fs = fs && fs.hasOwnProperty('default') ? fs['default'] : fs;
|
||
|
path = path && path.hasOwnProperty('default') ? path['default'] : path;
|
||
|
|
||
|
// Rollup's CommonJS interop shim: evaluates `fn` against a fresh module
// object and returns whatever it exported. The second parameter exists only
// for signature compatibility and is always replaced internally.
function createCommonjsModule(fn, module) {
  const freshModule = { exports: {} };
  fn(freshModule, freshModule.exports);
  return freshModule.exports;
}
|
||
|
|
||
|
var dist = createCommonjsModule(function (module) {
|
||
|
module.exports = {};
|
||
|
module.exports.instantiate = () => {
|
||
|
const wasmData = 'data:application/wasm;base64,\
|
||
|
AGFzbQEAAAABpQRBYAJ/fwBgAX8AYAABf2ABfwF/YAd/f39/f39/AGAFf39/f38Bf2AFf39+f38AYAJ/fwF/YAd/f39/f39/AX9gA39/fwF/YAV/f39/fwBgBn9/f39/fwBgBH9/f38Bf2AEf39/fwBgA39/fwBgAABgBn9/f39/fwF/YAV/f39/fgF/YAV/f39/fAF/YAh/f39/f39/fwF/YAZ/f39/f3wBf2ANf39/f39/f39/f39/fwBgCH9/f39/f39/AGABfAF8YA9/f39/f39/f39/f39/f38Bf2AMf39/f39/f39/f39/AX9gDH9/f39/f39/f39/fwBgCX9/f39/f39/fwBgA39+fwF/YAV/fH9/fwF/YAR/f35/AX9gA399fwF/YAN/fH8Bf2AJf39/f39/f39/AX9gAX8BfWABfwF8YAF/AX5gBX9+fn9/AX5gBH9+fn4BfmAEf39/fgF+YAN/f38BfGAFf39/f38BfGAGf39/f39/AXxgAn9/AX5gAnx/AXxgAnx8AXxgA35/fwF/YAJ+fwF/YAZ/fH9/f38Bf2ADfHx/AXxgAXwBf2ACfH8Bf2ABfQF/YAR/f39/AX5gA39/fgBgAn9+AGACf34Bf2ACf30AYAJ/fABgCn9/f39/f39/f38Bf2ADf39/AX1gC39/f39/f39/f39/AX9gCn9/f39/f39/f38AYA9/f39/f39/f39/f39/f38AYAd/f39/f398AX8C4QkzA2VudgVhYm9ydAABA2VudhlfX19jeGFfYWxsb2NhdGVfZXhjZXB0aW9uAAMDZW52E19fX2N4YV9wdXJlX3ZpcnR1YWwADwNlbnYMX19fY3hhX3Rocm93AA4DZW52GV9fX2N4YV91bmNhdWdodF9leGNlcHRpb24AAgNlbnYHX19fbG9jawABA2VudgtfX19tYXBfZmlsZQAHA2VudgtfX19zZXRFcnJObwABA2Vudg1fX19zeXNjYWxsMTQwAAcDZW52DV9fX3N5c2NhbGwxNDUABwNlbnYNX19fc3lzY2FsbDE0NgAHA2VudgxfX19zeXNjYWxsNTQABwNlbnYLX19fc3lzY2FsbDYABwNlbnYMX19fc3lzY2FsbDkxAAcDZW52CV9fX3VubG9jawABA2VudhZfX2VtYmluZF9yZWdpc3Rlcl9ib29sAAoDZW52F19fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzABUDZW52I19fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzX2NvbnN0cnVjdG9yAAsDZW52IF9fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzX2Z1bmN0aW9uABYDZW52F19fZW1iaW5kX3JlZ2lzdGVyX2VtdmFsAAADZW52Fl9fZW1iaW5kX3JlZ2lzdGVyX2VudW0ADQNlbnYcX19lbWJpbmRfcmVnaXN0ZXJfZW51bV92YWx1ZQAOA2VudhdfX2VtYmluZF9yZWdpc3Rlcl9mbG9hdAAOA2VudhlfX2VtYmluZF9yZWdpc3Rlcl9pbnRlZ2VyAAoDZW52HV9fZW1iaW5kX3JlZ2lzdGVyX21lbW9yeV92aWV3AA4DZW52HF9fZW1iaW5kX3JlZ2lzdGVyX3N0ZF9zdHJpbmcAAANlbnYdX19lbWJpbmRfcmVnaXN0ZXJfc3RkX3dzdHJpbmcADgNlbnYWX19lbWJpbmRfcmVnaXN0ZXJfdm9pZAAAA2Vudg5fX2VtdmFsX2RlY3JlZgABA2Vudg5fX2VtdmFsX2luY3JlZgABA2VudhJfX2VtdmFsX3Rha2VfdmFsdWUABwNlbnYGX2Fib3J0AA8DZW52GV9lbXNjcmlwdGVuX2dldF9oZWFwX3NpemUAAgNlbnYWX2Vtc2NyaXB0ZW5fbWVtY3B5X2JpZwAJA2VudhdfZW1zY3JpcHRlbl9yZXNpemVfaGVhcAADA2VudgdfZ2V0ZW52AAMDZW52Dl9sbHZtX2xvZzJf
ZjY0ABcDZW52El9sbHZtX3N0YWNrcmVzdG9yZQABA2Vudg9fbGx2bV9zdGFja3NhdmUAAgNlbnYKX2xsdm1fdHJhcAAPA2VudhJfcHRocmVhZF9jb25kX3dhaXQABwNlbnYXX3B0aHJlYWRfbXV0ZXhhdHRyX2luaXQAAwNlbnYaX3B0aHJlYWRfbXV0ZXhhdHRyX3NldHR5cGUABwNlbnYLX3N0cmZ0aW1lX2wABQNlbnYXYWJvcnRPbkNhbm5vdEdyb3dNZW1vcnkAAwNlbnYMX190YWJsZV9iYXNlA38AA2Vudg5EWU5BTUlDVE9QX1BUUgN/AAZnbG9iYWwDTmFOA3wABmdsb2JhbAhJbmZpbml0eQN8AANlbnYGbWVtb3J5AgCAAgNlbnYFdGFibGUBcAGIJ4gnA/8h/SEDDwMCAQAPAAADAwEHAAcDAQIDDgMBAgcECAIBAAIIAAEBAwAJBwMMDAABAQUAAQEBBg0DBwcBAQEBAQEBAQMBAwEJCAABDQkADAcJAA0AAAAQAQEBAQABAQABAAMACgMMAw4MAwMDAwMDAwMDAwMDAwMDAwUDAAMDBwMDAAADAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAMABwMABwcHAwMDAwMDAAMBAAEBAQAAAAABAQ8NGAcTAA4ADQAHFhkLAAEBDAAAAAEBAwMDAwwDDQAAAAAAAAAAAAAAAAAOCBMDAwAOBwMDAA4HAwMADQcDAwALBwMDAA0HAwMAAwcDAwEDBwMDAQAAAwMADgMDAAMBAA4AAwADAQMBAwEDAQMDAwMHDgEBAwADAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMBAAMDAwMDAwMDAwMHAwMDAwMDAwMDAwMDAwMDAwcDBwAJAAAOAAkAAAMOAA4BAwMBAwADAwAAAwADDQMHBwAAAA0DDgMDAAMDAwEADg0BAwkJBw0JAAEBDQQLDQ0KCg4KCgcHCQ0ODg0NAw0NAw0NAw0NDQ0NDQ0NDg4ODgUODg0AAA4AAAADAQAACQkJCQkJCQ4AAAMKAwAJAA4ODg4NDQ0ABQ4NBwcHBw0AAA0ADg0NAA0AAAAAAAMAAQwNBwANDAAHAAADBwUABwUJDAUJCQAHAAAAAA4BDgEBAQEAAQEBDQ4HAwEBAwwHAAcJBxoKDgsACw0NDgANAAAQCAMDAAoHAwUDAAoHAwkAAw0ADAkDDAABAA4MBQwCAQAAAgAHAAABBQcIBw0ADQ4OBwAJDAUABwAHBwADAAoAAAAAAQwDBw0ODQoNAAcHBwMACQcDBQwQBQUFDAUMBQUMBQUFBQUFBQUFDgUFBRAFBQAFBQUFAAQOAAEBAwABAQcBAwAHAwAHAwAHAwAHCAEBAwABAQcAAQEDAAEBAwADAAcDAAcDAAcDAAcDAAcBAQMAAQEDAAEBAwADAAcDAAcDAAcDAAMAAQEDAAMABwMAAwABAQMAAwAHAwAHAQEDAAEBAwADAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAMABwMDBwcHAAAKEAUFAAkAAAcHDgAHBwMKDgMADg4DAAAHDgAMBRAOCQMADgMAAwAHBwcJBwcHBwcDBwcJDQAABwwIAAAAAAsHBwUMDgAKFg4NBwMACQcHBwAHBwcHAwcHCQkFAwAFBwcAAA4FAAADAAAHBwcHCQcJCQcJBwcHCQMACQkHBwAAAAAHAwAHBwMABwoDAAMOCgcJCQcOBwAAAAcDBw4HAQQNDQADAAMAAwAAAAMABwMABwMDAwMJDg4AAwAHAwAHAAMABwcOAQEDAAEBCQkADgoHAAADAA4AAwAHAAcHBw4ABwcDAQcNAwMHAwMAAwAAAAcHBwADAwUKBwAAAA4DBw0HBwcHBwcHBwcHBwcHBwcHEAkHDAcHBwcMDgkJAQEAAAABAAAAAAAAAAAHDgEAAQAHAAEBDgEBAQEBAQEBAQEBAQEBAQUBAQEAAAAAAQEAAAEBAAEBAQEBAQEA
AAEBAQEBAQcHAQcNAQMDAwEHDgAABwAAAwADAAABBwMABwMABwMAAAEBAwMDAwAABQAAAAwAAAMABQMABQEBAwMDAwAABwc
|
||
|
return fetch(wasmData).then(res => res.arrayBuffer()).then(fromBinary);
|
||
|
};
|
||
|
|
||
|
const fromBinary = (wasmBinary) => {
|
||
|
const Module = { wasmBinary };
|
||
|
const promise = new Promise(res => {
|
||
|
Module.onRuntimeInitialized = () => {
|
||
|
res(Module);
|
||
|
};
|
||
|
});
|
||
|
{
|
||
|
var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}Module["arguments"]=[];Module["thisProgram"]="./this.program";Module["quit"]=function(status,toThrow){throw toThrow};Module["preRun"]=[];Module["postRun"]=[];var ENVIRONMENT_IS_WEB=false;var ENVIRONMENT_IS_WORKER=false;var ENVIRONMENT_IS_NODE=false;var ENVIRONMENT_IS_SHELL=false;ENVIRONMENT_IS_WEB=typeof window==="object";ENVIRONMENT_IS_WORKER=typeof importScripts==="function";ENVIRONMENT_IS_NODE=typeof process==="object"&&typeof require==="function"&&!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_WORKER;ENVIRONMENT_IS_SHELL=!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_NODE&&!ENVIRONMENT_IS_WORKER;var scriptDirectory="";function locateFile(path$$1){if(Module["locateFile"]){return Module["locateFile"](path$$1,scriptDirectory)}else{return scriptDirectory+path$$1}}if(ENVIRONMENT_IS_NODE){scriptDirectory=__dirname+"/";var nodeFS;var nodePath;Module["read"]=function shell_read(filename,binary){var ret;if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);ret=nodeFS["readFileSync"](filename);return binary?ret:ret.toString()};Module["readBinary"]=function readBinary(filename){var ret=Module["read"](filename,true);if(!ret.buffer){ret=new Uint8Array(ret);}assert(ret.buffer);return ret};if(process["argv"].length>1){Module["thisProgram"]=process["argv"][1].replace(/\\/g,"/");}Module["arguments"]=process["argv"].slice(2);{module["exports"]=Module;}process["on"]("uncaughtException",function(ex){if(!(ex instanceof ExitStatus)){throw ex}});process["on"]("unhandledRejection",abort);Module["quit"]=function(status){process["exit"](status);};Module["inspect"]=function(){return"[Emscripten Module object]"};}else if(ENVIRONMENT_IS_SHELL){if(typeof read!="undefined"){Module["read"]=function shell_read(f){return read(f)};}Module["readBinary"]=function readBinary(f){var data;if(typeof readbuffer==="function"){return new 
Uint8Array(readbuffer(f))}data=read(f,"binary");assert(typeof data==="object");return data};if(typeof scriptArgs!="undefined"){Module["arguments"]=scriptArgs;}else if(typeof arguments!="undefined"){Module["arguments"]=arguments;}if(typeof quit==="function"){Module["quit"]=function(status){quit(status);};}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href;}else if(document.currentScript){scriptDirectory=document.currentScript.src;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}Module["read"]=function shell_read(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.send(null);return xhr.responseText};if(ENVIRONMENT_IS_WORKER){Module["readBinary"]=function readBinary(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.responseType="arraybuffer";xhr.send(null);return new Uint8Array(xhr.response)};}Module["readAsync"]=function readAsync(url,onload,onerror){var xhr=new XMLHttpRequest;xhr.open("GET",url,true);xhr.responseType="arraybuffer";xhr.onload=function xhr_onload(){if(xhr.status==200||xhr.status==0&&xhr.response){onload(xhr.response);return}onerror();};xhr.onerror=onerror;xhr.send(null);};Module["setWindowTitle"]=function(title){document.title=title;};}var out=Module["print"]||(typeof console!=="undefined"?console.log.bind(console):typeof print!=="undefined"?print:null);var err=Module["printErr"]||(typeof printErr!=="undefined"?printErr:typeof console!=="undefined"&&console.warn.bind(console)||out);for(key in moduleOverrides){if(moduleOverrides.hasOwnProperty(key)){Module[key]=moduleOverrides[key];}}moduleOverrides=undefined;var asm2wasmImports={"f64-rem":function(x,y){return x%y},"debugger":function(){debugger}};var tempRet0=0;var setTempRet0=function(value){tempRet0=value;};var getTempRet0=function(){return tempRet0};if(typeof WebAssembly!=="object"){err("no native wasm support 
detected");}var wasmMemory;var wasmTable;var ABORT=false;function assert(condition,text){if(!condition){abort("Assertion failed: "+text);}}var UT
|
||
|
}
|
||
|
return promise;
|
||
|
};
|
||
|
});
|
||
|
var dist_1 = dist.instantiate; // convenience alias for the WASM module instantiator
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google Inc. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/** Whether we submit commands to the device queue immediately. */
tf.ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
/**
 * Thread register block size for matmul kernel. If 0, we use the version of
 * matMul without register blocking.
 */
tf.ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);
/**
 * Work-per-thread for the conv2d kernel:
 * -1: conv2d_naive
 * 0: conv2d_mm with matmul without register blocking
 * >0: conv2d_mm with matmul_packed with WPT=this
 */
tf.ENV.registerFlag('WEBGPU_CONV2D_WORK_PER_THREAD', () => 2);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Pools WebGPU buffers keyed by (byteSize, usage) so tensors can reuse
 * previously released buffers instead of allocating a new device buffer on
 * every request.
 */
class BufferManager {
    constructor(device) {
        this.device = device;
        this.numUsedBuffers = 0;
        this.numFreeBuffers = 0;
        // `${byteSize}_${usage}` key -> list of buffers with that size/usage.
        this.freeBuffers = new Map();
        this.usedBuffers = new Map();
        this.numBytesUsed = 0;
    }
    // Pool key: buffers are interchangeable only when both size and usage match.
    bufferKey(byteSize, usage) {
        return `${byteSize}_${usage}`;
    }
    /**
     * Returns a buffer of `byteSize` bytes with the given usage flags.
     * Reuses a pooled free buffer when one with a matching key exists,
     * otherwise allocates a new one from the device.
     */
    acquireBuffer(byteSize, usage) {
        const key = this.bufferKey(byteSize, usage);
        if (!this.freeBuffers.has(key)) {
            this.freeBuffers.set(key, []);
        }
        if (!this.usedBuffers.has(key)) {
            this.usedBuffers.set(key, []);
        }
        this.numBytesUsed += byteSize;
        this.numUsedBuffers++;
        const freeList = this.freeBuffers.get(key);
        if (freeList.length > 0) {
            this.numFreeBuffers--;
            const reused = freeList.shift();
            this.usedBuffers.get(key).push(reused);
            return reused;
        }
        const newBuffer = this.device.createBuffer({ size: byteSize, usage });
        this.usedBuffers.get(key).push(newBuffer);
        return newBuffer;
    }
    /**
     * Returns `buffer` to the free pool. No-op after dispose().
     * @throws {Error} if the buffer is not currently tracked as in-use.
     */
    releaseBuffer(buffer, byteSize, usage) {
        if (this.freeBuffers == null) {
            return; // already disposed
        }
        const key = this.bufferKey(byteSize, usage);
        const bufferList = this.usedBuffers.get(key);
        // BUG FIX: validate before mutating any state — the original pushed the
        // buffer onto the free pool and bumped the counters before checking,
        // so a failed release left the manager in an inconsistent state.
        const bufferIndex = bufferList ? bufferList.indexOf(buffer) : -1;
        if (bufferIndex < 0) {
            throw new Error('Cannot release a buffer that was never provided by this ' +
                'buffer manager');
        }
        bufferList.splice(bufferIndex, 1);
        if (!this.freeBuffers.has(key)) {
            this.freeBuffers.set(key, []);
        }
        this.freeBuffers.get(key).push(buffer);
        this.numFreeBuffers++;
        this.numUsedBuffers--;
        this.numBytesUsed -= byteSize;
    }
    getNumUsedBuffers() {
        return this.numUsedBuffers;
    }
    getNumFreeBuffers() {
        return this.numFreeBuffers;
    }
    /**
     * Drops all pooled buffers without destroying them.
     * NOTE(review): buffers are only unreferenced here, not destroyed —
     * callers that want the GPU memory reclaimed must use dispose().
     */
    reset() {
        this.freeBuffers = new Map();
        this.usedBuffers = new Map();
        this.numUsedBuffers = 0;
        this.numFreeBuffers = 0;
        this.numBytesUsed = 0; // BUG FIX: previously left stale after reset
    }
    /** Destroys every tracked buffer and renders this manager unusable. */
    dispose() {
        if (this.freeBuffers == null) {
            return;
        }
        // BUG FIX: the original iterated the Maps with `for...in`, which only
        // enumerates own properties (a Map has none), so destroy() was never
        // called and every GPU buffer leaked. Iterate the Map values instead.
        this.freeBuffers.forEach(buffers => {
            buffers.forEach(buff => {
                buff.destroy();
            });
        });
        this.usedBuffers.forEach(buffers => {
            buffers.forEach(buff => {
                buff.destroy();
            });
        });
        this.freeBuffers = null;
        this.usedBuffers = null;
        this.numUsedBuffers = 0;
        this.numFreeBuffers = 0;
        this.numBytesUsed = 0;
    }
}
|
||
|
// Pool key for a buffer: size and usage jointly determine whether two
// buffers are interchangeable.
function getBufferKey(byteSize, usage) {
  return [byteSize, usage].join('_');
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Emits GLSL expressions for the row-major strides of a tensor whose shape
// lives in the GLSL array `variableName`, selecting the dimensions listed in
// `indicesArr`. Returns rank-1 stride expressions (innermost stride omitted,
// it is always 1).
function symbolicallyComputeStrides(indicesArr, variableName) {
  if (Math.max(...indicesArr) > 3) {
    throw new Error('Cannot symbolically compute strides for rank > 4 tensor.');
  }
  const rank = indicesArr.length;
  const dimExprs = indicesArr.map(d => `${variableName}[${d}]`);
  const strides = new Array(rank - 1);
  // Innermost stride is the size of the last dimension; each outer stride
  // folds in one more dimension.
  strides[rank - 2] = dimExprs[rank - 1];
  let i = rank - 3;
  while (i >= 0) {
    strides[i] = `(${strides[i + 1]} * ${dimExprs[i + 1]})`;
    i -= 1;
  }
  return strides;
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// GLSL coordinate type for a tensor of the given rank: a plain int for
// rank <= 1, ivecN for ranks 2-4. Higher ranks are unsupported on GPU.
function getCoordsDataType(rank) {
  if (rank <= 1) {
    return 'int';
  }
  const vecTypeByRank = { 2: 'ivec2', 3: 'ivec3', 4: 'ivec4' };
  const vecType = vecTypeByRank[rank];
  if (vecType === undefined) {
    throw Error(`GPU for rank ${rank} is not yet supported`);
  }
  return vecType;
}
|
||
|
// Maps a TensorFlow.js dtype onto the corresponding GLSL scalar type;
// dtypes without a dedicated GLSL mapping are passed through unchanged.
function mapToGlslTypes(type) {
  if (type === 'float32') {
    return 'float';
  }
  return type === 'int32' ? 'int' : type;
}
|
||
|
// Assembles the complete GLSL compute-shader source for a program:
// optional workgroup layout, output/input SSBO declarations, a uniform
// block carrying shape data, shared sampling helpers, and finally the
// program's own user code.
function makeShader(inputInfo, outputData, program) {
    const prefixSnippets = [];
    // Optional explicit workgroup dimensions.
    if (program.workGroupSize != null) {
        prefixSnippets.push(`
      layout (local_size_x = ${program.workGroupSize[0]},
              local_size_y = ${program.workGroupSize[1]},
              local_size_z = ${program.workGroupSize[2]}) in;
    `);
    }
    // Output buffer (binding 0).
    prefixSnippets.push(`
    layout(std430, set = 0, binding = 0) writeonly buffer ssbOut {
      ${mapToGlslTypes(outputData.dtype)} result[];
    };
  `);
    let uniformDeclaration = '';
    // One read-only SSBO per input (bindings 1..N); each input also
    // contributes a `<name>Shape` entry to the uniform block.
    program.variableNames.forEach((x, i) => {
        uniformDeclaration += `${getCoordsDataType(inputInfo[i].shape.length)} ${x.charAt(0).toLowerCase() + x.slice(1)}Shape; `;
        prefixSnippets.push(`
      layout(std430, set = 0, binding = ${1 + i}) readonly buffer ssb${x} {
        ${mapToGlslTypes(inputInfo[i].dtype)} ${x}[];
      };
    `);
    });
    uniformDeclaration +=
        `${getCoordsDataType(outputData.shape.length)} outShape; `;
    if (program.uniforms) {
        uniformDeclaration += program.uniforms;
    }
    // Uniform block is bound after all input buffers.
    prefixSnippets.push(`
    layout(std140, set = 0, binding = ${1 + program.variableNames.length}) uniform Uniforms {
      ${uniformDeclaration}
    };
  `);
    const [getOutputCoords, dispatchLayoutRank] = generateGetOutputCoords(program.dispatchLayout);
    const getCoords = generateGetCoordsFromFlatIndex(outputData.shape);
    const sources = [
        SHADER_PREFIX, prefixSnippets.join('\n'), SAMPLING_SNIPPETS,
        getOutputCoords, getCoords,
        getSetOutputSnippet(outputData.shape.length, outputData.dtype)
    ];
    if (dispatchLayoutRank === outputData.shape.length) {
        // Input sampling snippet is only meaningful when the output isn't getting
        // implicitly reshaped (like it does in conv2d_matmul).
        const inputSamplingSnippet = inputInfo.map(x => getInputSamplingSnippet(x, outputData.shape))
            .join('\n');
        sources.push(inputSamplingSnippet);
    }
    sources.push(program.userCode);
    const source = sources.join('\n');
    return source;
}
|
||
|
// Common prologue for every generated shader: the GLSL version directive
// plus floor-division and bounds-check helpers used by generated kernels.
const SHADER_PREFIX = `#version 450

  int idiv(int a, int b, float sign) {
    int res = a / b;
    int mod = a % b;
    if (sign < 0. && mod != 0) {
      res -= 1;
    }
    return res;
  }

  bool coordIsValid(ivec4 coord, ivec4 shape) {
    return all(greaterThanEqual(coord, ivec4(0))) &&
        all(lessThan(coord, shape));
  }
`;
// getFlatIndex overloads: map an n-D coordinate to a flat row-major index
// for ranks 1 through 4.
const SAMPLING_SNIPPETS = `
  uint getFlatIndex(uint coord, uint shape) {
    return coord;
  }

  uint getFlatIndex(ivec2 coords, ivec2 shape) {
    return uint(dot(coords, ivec2(shape.y, 1.)));
  }

  uint getFlatIndex(ivec3 coords, ivec3 shape) {
    return uint(dot(coords, ivec3(shape.y * shape.z, shape.z, 1.)));
  }

  uint getFlatIndex(ivec4 coords, ivec4 shape) {
    return uint(dot(coords, ivec4(
      shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1.)));
  }
`;
|
||
|
// Generates the setOutput(...) overloads for a shader: flat-index setters
// for float and int values (casting to the output buffer's GLSL type when
// it differs) and, for rank >= 2, coordinate-based setters that first
// flatten the coordinates against outShape.
function getSetOutputSnippet(outRank, outBufferType) {
    let snippet = `void setOutput(uint flatIndex, float value) {
      result[flatIndex] = ${mapToGlslTypes(outBufferType) === 'int' ? 'int(value)' : 'value'};
    }
    void setOutput(uint flatIndex, int value) {
      result[flatIndex] = ${mapToGlslTypes(outBufferType) === 'float' ? 'float(value)' : 'value'};
    }`;
    if (outRank >= 2) {
        // d0..d{rank-1} coordinate parameters.
        const dims = ['d0', 'd1', 'd2', 'd3'].slice(0, outRank);
        const type = getCoordsDataType(outRank);
        snippet += `
      void setOutput(${dims.map(d => `int ${d}`).join(', ')}, float value) {
        uint flatIndex = getFlatIndex(${type}(${dims.join(', ')}), outShape);
        setOutput(flatIndex, value);
      }
      void setOutput(${dims.map(d => `int ${d}`).join(', ')}, int value) {
        uint flatIndex = getFlatIndex(${type}(${dims.join(', ')}), outShape);
        setOutput(flatIndex, value);
      }
    `;
    }
    return snippet;
}
|
||
|
// Combines the direct coordinate sampler for an input with, when the input's
// rank does not exceed the output's, a sampler that reads the input at the
// current output coordinates (used for broadcasting).
function getInputSamplingSnippet(inInfo, outShape) {
  const samplers = [getSamplerFromInInfo(inInfo)];
  if (inInfo.shape.length <= outShape.length) {
    samplers.push(getSamplerAtOutputCoords(inInfo, outShape));
  }
  return samplers.join('');
}
|
||
|
// Generates a get<X>(d0, ...) helper that reads one element of input `x`
// by its n-D coordinates; a rank-0 input gets a zero-argument getter that
// reads element 0 directly.
function getSamplerFromInInfo(inInfo) {
    const texName = inInfo.name;
    const rank = inInfo.shape.length;
    const type = getCoordsDataType(rank);
    // e.g. input named 'x' -> function 'getX'.
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const dims = ['d0', 'd1', 'd2', 'd3'].slice(0, rank);
    const inputs = dims.map(d => `int ${d}`).join(', ');
    if (rank < 1) {
        return `
      float ${funcName}() {
        return ${texName}[0];
      }
    `;
    }
    // Flatten the coordinates against the input's shape uniform
    // (`<name>Shape`, declared by makeShader).
    return `
    float ${funcName}(${inputs}) {
      return ${texName}[getFlatIndex(${type}(${dims.join(',')}),
        ${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape)];
    }
  `;
}
|
||
|
// Generates a get<X>AtOutCoords() helper that reads input `x` at the
// current output coordinates, zeroing out any broadcast dimensions so a
// smaller input can be sampled against a larger output shape. Two
// overloads are emitted: one that derives the coordinates itself and one
// that accepts them.
function getSamplerAtOutputCoords(inInfo, outShape) {
    const texName = inInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    const inRank = inInfo.shape.length;
    const outRank = outShape.length;
    const type = getCoordsDataType(outRank);
    // Dimensions along which the input is broadcast against the output.
    const broadcastDims = tf.backend_util.getBroadcastDims(inInfo.shape, outShape);
    const rankDiff = outRank - inRank;
    let coordsSnippet = '';
    if (inRank === 0) {
        // Scalar input: both overloads ignore the coordinates entirely.
        return `
      float ${funcName}() {
        return get${texFuncSnippet}();
      }

      float ${funcName}(${type} coords) {
        return get${texFuncSnippet}();
      }
    `;
    }
    else {
        if (outRank < 2 && broadcastDims.length >= 1) {
            // Rank-1 output with broadcasting: the single coordinate collapses to 0.
            coordsSnippet = 'coords = 0;';
        }
        else {
            // Zero out each broadcast dimension (offset by the rank difference).
            coordsSnippet =
                broadcastDims.map(d => `coords[${d + rankDiff}] = 0;`).join('\n');
        }
    }
    // Expression converting the output coordinates into input coordinates of
    // rank `inRank` (keeping only the trailing dimensions).
    let unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        if (outRank > 1) {
            const coordsType = getCoordsDataType(inRank);
            const coordsValues = inInfo.shape.map((s, i) => `coords[${i + rankDiff}]`).join(', ');
            unpackedCoordsSnippet = `${coordsType}(${coordsValues})`;
        }
        else {
            unpackedCoordsSnippet = 'coords';
        }
    }
    return `
    float ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      return ${texName}[getFlatIndex(${unpackedCoordsSnippet}, ${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape)];
    }

    float ${funcName}(${type} coords) {
      ${coordsSnippet}
      return ${texName}[getFlatIndex(${unpackedCoordsSnippet}, ${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape)];
    }
  `;
}
|
||
|
/**
 * Generates getOutputCoords() function that computes output coordinates from
 * dispatch geometry to reduce arithmetic.
 *
 * Returns [snippet, rank] where `rank` is the total number of tensor
 * dimensions covered by the dispatch layout; callers compare it against the
 * output rank to detect implicit reshaping.
 */
function generateGetOutputCoords(dispatchLayout) {
    const { x, y = [], z = [] } = dispatchLayout;
    let gatherDimensionsStr = '';
    const dims = [x, y, z];
    let rank = 0;
    for (let i = 0; i < dims.length; i++) {
        const arr = dims[i];
        if (arr.length === 0) {
            continue;
        }
        rank += arr.length;
        if (arr.length === 1) {
            // One tensor dimension maps directly onto this dispatch axis.
            gatherDimensionsStr += `uint d${arr[0]} = gl_GlobalInvocationID[${i}];`;
        }
        else {
            // Several tensor dimensions share this dispatch axis: decompose the
            // invocation index with symbolic strides over outShape.
            const strides = symbolicallyComputeStrides(arr, 'outShape');
            gatherDimensionsStr += `uint index${i} =
        gl_GlobalInvocationID[${i}];`;
            for (let j = 0; j < strides.length; j++) {
                gatherDimensionsStr += `uint d${arr[j]} = index${i} / ${strides[j]};`;
                if (j === strides.length - 1) {
                    // The innermost dimension is whatever remains.
                    gatherDimensionsStr += `uint d${arr[j + 1]} = ` +
                        `index${i} - d${arr[j]} * ${strides[j]};`;
                }
                else {
                    gatherDimensionsStr += `index${i} -= d${arr[j]} * ${strides[j]};`;
                }
            }
        }
    }
    const dimensions = [];
    for (let i = 0; i < rank; i++) {
        dimensions.push(`d${i}`);
    }
    const dtype = getCoordsDataType(rank);
    const snippet = `${dtype} getOutputCoords() {
    ${gatherDimensionsStr}

    return ${dtype}(${dimensions.join(',')});
  }`;
    return [snippet, rank];
}
|
||
|
/**
 * Derives logical coordinates from a flat index. Performs integer division with
 * each stride and decrements the index until the index equals the final
 * dimension coordinate.
 */
function generateGetCoordsFromFlatIndex(shape) {
    const rank = shape.length;
    if (rank <= 1) {
        // Rank 0/1: the flat index already is the coordinate.
        return `int getCoordsFromFlatIndex(int index) {return index; }`;
    }
    const strides = tf.util.computeStrides(shape);
    const dtype = getCoordsDataType(rank);
    const coords = [];
    for (let i = 0; i < rank; i++) {
        coords.push(`d${i}`);
    }
    // For each stride: divide out one coordinate, then either subtract it
    // from the running index or — on the last stride — emit the final
    // coordinate as the remainder.
    const snippet = strides
        .map((stride, i) => {
        const line1 = `uint ${coords[i]} = index / ${stride}`;
        const line2 = i === strides.length - 1 ?
            `uint ${coords[i + 1]} = index - ${coords[i]} * ${stride}` :
            `index -= ${coords[i]} * ${stride}`;
        return `${line1}; ${line2};`;
    })
        .join('');
    return `
    ${dtype} getCoordsFromFlatIndex(uint index) {
      ${snippet}
      return ${dtype}(${coords.join(',')});
    }
  `;
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Product of all entries of `arr`; 1 for an empty array.
const arrayProduct = (arr) => arr.reduce((acc, value) => acc * value, 1);
|
||
|
// Computes the [x, y, z] workgroup-dispatch counts from the layout of
// output dimensions, the workgroup size, and the number of output elements
// each thread produces. Axes absent from the layout dispatch a single group.
function computeDispatch(layout, outputShape, workGroupSize = [1, 1, 1], elementsPerThread = [1, 1, 1]) {
  const sizeAlong = (dims) => dims.reduce((acc, d) => acc * outputShape[d], 1);
  const groupsAlong = (axis, dims) =>
      Math.ceil(sizeAlong(dims) / (workGroupSize[axis] * elementsPerThread[axis]));
  return [
    groupsAlong(0, layout.x),
    layout.y ? groupsAlong(1, layout.y) : 1,
    layout.z ? groupsAlong(2, layout.z) : 1,
  ];
}
|
||
|
// Flattens every output dimension onto the x dispatch axis.
function flatDispatchLayout(shape) {
  const axes = shape.map((unusedDim, axis) => axis);
  return { x: axes };
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
class ArgMinMaxProgram {
|
||
|
constructor(inputShape, axis, reduceType) {
|
||
|
this.variableNames = ['x'];
|
||
|
this.uniforms = 'uint axis;';
|
||
|
const axes = [axis];
|
||
|
tf.backend_util.assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, inputShape.length);
|
||
|
const op = reduceType === 'min' ? '<' : '>';
|
||
|
// |outShape| is the shape with the removed axis
|
||
|
// |reduceShape| is the shape we are reducing. i.e. [ inputShape[axis] ]
|
||
|
const [outputShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(inputShape, axes);
|
||
|
this.outputShape = outputShape.length === 0 ? [1] : outputShape;
|
||
|
// Length of the axis we're reducing on.
|
||
|
const reduceSize = tf.util.sizeFromShape(reduceShape);
|
||
|
// The number of comparisons each thread will do
|
||
|
const reductionFactor = 2;
|
||
|
const xMaxThreads = 1024; // gl_MaxComputeWorkGroupSize
|
||
|
const xThreads = Math.min(Math.ceil(reduceSize / reductionFactor), xMaxThreads);
|
||
|
this.workGroupSize = [xThreads, 1, 1];
|
||
|
this.dispatchLayout = { x: [], y: this.outputShape.map((d, i) => i) };
|
||
|
this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
|
||
|
// When xThreads > 1, each thread reduces Length / xThreads values.
|
||
|
// Thes results are stored in shared memory and iteratively reduced.
|
||
|
const reduceInSharedMemory = xThreads > 1;
|
||
|
const sharedMemorySnippet = `
|
||
|
shared uint xBestIndices[WorkGroupSize];
|
||
|
shared float xBestValues[WorkGroupSize];
|
||
|
`;
|
||
|
const sharedMemoryReduceSnippet = `
|
||
|
xBestIndices[gl_LocalInvocationID.x] = bestIndex;
|
||
|
xBestValues[gl_LocalInvocationID.x] = bestValue;
|
||
|
|
||
|
uint currentSize = WorkGroupSize;
|
||
|
while (currentSize > 1) {
|
||
|
barrier();
|
||
|
|
||
|
for (uint w = 0; w < ${reductionFactor}; ++w) {
|
||
|
uint i = gl_LocalInvocationID.x * ${reductionFactor} + w;
|
||
|
if (i < currentSize) {
|
||
|
uint candidateIndex = xBestIndices[i];
|
||
|
float candidate = xBestValues[i];
|
||
|
if (candidate ${op} bestValue && !isnan(candidate)) {
|
||
|
bestValue = candidate;
|
||
|
bestIndex = candidateIndex;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
xBestIndices[gl_LocalInvocationID.x] = bestIndex;
|
||
|
xBestValues[gl_LocalInvocationID.x] = bestValue;
|
||
|
|
||
|
currentSize = DIV_CEIL(currentSize, ${reductionFactor});
|
||
|
}
|
||
|
|
||
|
if (gl_LocalInvocationID.x == 0) {
|
||
|
setOutput(flatOutputIndex, int(bestIndex));
|
||
|
}
|
||
|
`;
|
||
|
const outputCoordsType = getCoordsDataType(this.outputShape.length);
|
||
|
const indexOutputCoords = (outputCoords, index) => {
|
||
|
if (this.outputShape.length === 1) {
|
||
|
return outputCoords;
|
||
|
}
|
||
|
else {
|
||
|
return `${outputCoords}[${index}]`;
|
||
|
}
|
||
|
};
|
||
|
const indexInputShape = (index) => {
|
||
|
if (inputShape.length === 1) {
|
||
|
return 'xShape';
|
||
|
}
|
||
|
else {
|
||
|
return `xShape[${index}]`;
|
||
|
}
|
||
|
};
|
||
|
this.userCode = `
|
||
|
#define DIV_CEIL(x, y) (((x) - 1) / (y) + 1)
|
||
|
|
||
|
const uint WorkGroupSize = gl_WorkGroupSize.x;
|
||
|
|
||
|
${reduceInSharedMemory ? sharedMemorySnippet : ''}
|
||
|
|
||
|
// In order to get a flattened index into the input tensor, we need to
|
||
|
// add back the index along the reduced dimension to |outputCoords|.
|
||
|
// This function outputs the offset to the first value along
|
||
|
// |axis| and the stride to get the next value of the input along |axis|.
|
||
|
uvec2 getInputCoordInfo() {
|
||
|
const ${outputCoordsType} outputCoords = getOutputCoords();
|
||
|
uint i = ${this.outputShape.length - 1};
|
||
|
|
||
|
uint stride = 1;
|
||
|
uint inputStride = 1;
|
||
|
uint offset = 0;
|
||
|
|
||
|
for (uint r = 1; r <= ${inputShape.length}; ++r) {
|
||
|
uint length = ${indexInputShape(`${inputShape.length} - r`)};
|
||
|
if (${inputShape.length} - r == axis) {
|
||
|
inputStride = stride;
|
||
|
} else {
|
||
|
offset += ${indexOutputCoords('outputCoords', 'i--')} * stride;
|
||
|
}
|
||
|
stride *= length;
|
||
|
}
|
||
|
|
||
|
return uvec2(offset, inputStride);
|
||
|
}
|
||
|
|
||
|
uint getInputIndex(uvec2 coordInfo, uint index) {
|
||
|
return coordInfo[0] + coordInfo[1] * index;
|
||
|
}
|
||
|
|
||
|
void main() {
|
||
|
const uvec2 coordInfo = getInputCoordInfo();
|
||
|
|
||
|
uint bestIndex = 0;
|
||
|
float bestValue = x[getInputIndex(coordInfo, bestIndex)];
|
||
|
|
||
|
const uint Length = ${indexInputShape('axis')};
|
||
|
const uint WorkPerThread = DIV_CEIL(Length, WorkGroupSize);
|
||
|
|
||
|
for (uint w = 0; w < WorkPerThread; ++w) {
|
||
|
uint i = gl_GlobalInvocationID.x * WorkPerThread + w;
|
||
|
if (i < Length) {
|
||
|
float candidate = x[getInputIndex(coordInfo, i)];
|
||
|
if (candidate ${op} bestValue && !isnan(candidate)) {
|
||
|
bestValue = candidate;
|
||
|
bestIndex = i;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
const uint flatOutputIndex = gl_GlobalInvocationID.y;
|
||
|
${reduceInSharedMemory ? sharedMemoryReduceSnippet :
|
||
|
'setOutput(flatOutputIndex, int(bestIndex));'}
|
||
|
}
|
||
|
`;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// GLSL snippets for elementwise binary ops. Each snippet becomes the body of
// `float binaryOperation(float a, float b)` in BinaryOpProgram below.
const MUL = 'return a * b;';
const ADD = 'return a + b;';
const SUB = 'return a - b;';
// Integer division: rounds both operands to ints and delegates to an `idiv`
// helper, passing the sign of the exact quotient so the helper can pick the
// rounding direction.
// NOTE(review): `idiv` is not defined in this chunk — presumably supplied by
// the shared shader preamble; confirm before reuse.
const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = int(round(a));
  int ib = int(round(b));
  return float(idiv(ia, ib, s));
`;
|
||
|
// Elementwise binary operation over the broadcast shape of inputs A and B.
// `op` is a GLSL statement (e.g. MUL/ADD/SUB above) returning the combined
// value of floats `a` and `b`; each thread handles `workPerThread` flat
// output indices.
class BinaryOpProgram {
  constructor(op, aShape, bShape) {
    this.variableNames = ['A', 'B'];
    this.workPerThread = 4;
    this.workGroupSize = [1, 1, 1];
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workGroupSize,
        [this.workPerThread, 1, 1]);

    const flatSize = tf.util.sizeFromShape(this.outputShape);
    const coordsType = getCoordsDataType(this.outputShape.length);

    this.userCode = `
      float binaryOperation(float a, float b) {
        ${op}
      }

      void main() {
        int index = int(gl_GlobalInvocationID.x);

        for(int i = 0; i < ${this.workPerThread}; i++) {
          int flatIndex = index * ${this.workPerThread} + i;

          if(flatIndex < ${flatSize}) {
            ${coordsType} coords = getCoordsFromFlatIndex(flatIndex);

            float a = getAAtOutCoords(coords);
            float b = getBAtOutCoords(coords);
            setOutput(flatIndex, binaryOperation(a, b));
          }
        }
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Concatenates a list of rank-2 tensors along axis 1 (columns). `shapes`
// holds the [rows, cols] shape of each input.
class ConcatProgram {
  constructor(shapes) {
    this.outputShape =
        tf.backend_util.computeOutShape(shapes, 1 /* axis */);
    // One shader input per tensor: T0, T1, ...
    this.variableNames = shapes.map((_, i) => `T${i}`);
    this.dispatchLayout = { x: [0], y: [1] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
    // offsets[i] is the first output column PAST input i, i.e. the cumulative
    // column count of inputs 0..i.
    // NOTE(review): assumes shapes.length >= 2 — with a single input,
    // `offsets` starts empty and the assignment below grows it, and the
    // final `else` snippet indexes offsets[-1]; confirm callers never pass a
    // one-element list.
    const offsets = new Array(shapes.length - 1);
    offsets[0] = shapes[0][1];
    for (let i = 1; i < offsets.length; i++) {
      offsets[i] = offsets[i - 1] + shapes[i][1];
    }
    // Build an if / else-if chain routing each output column to the input
    // tensor covering it, shifting the column index by the columns of all
    // preceding inputs.
    const snippets = [
      `if (yC < ${offsets[0]}) setOutput(coords.x, coords.y, getT0(yR, yC));`
    ];
    for (let i = 1; i < offsets.length; i++) {
      const shift = offsets[i - 1];
      snippets.push(`else if (yC < ${offsets[i]}) ` +
          `setOutput(coords.x, coords.y, getT${i}(yR, yC-${shift}));`);
    }
    const lastIndex = offsets.length;
    const lastShift = offsets[offsets.length - 1];
    snippets.push(`else setOutput(coords.x, coords.y, getT${lastIndex}(yR, yC-${lastShift}));`);
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;

        ${snippets.join('\n ')}
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
const matMulHeader = `
|
||
|
float mm_readA(uint row, uint col);
|
||
|
float mm_readB(uint row, uint col);
|
||
|
void mm_write(uint row, uint col, float value);
|
||
|
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter);`;
|
||
|
// Returns the GLSL source of a tiled matrix-multiply kernel: each workgroup
// cooperatively loads square MatTileSize tiles of A and B into shared memory,
// accumulates partial dot products, and writes one output element per thread.
// Requires a square workgroup (gl_WorkGroupSize.x == .y).
function makeMatMulSource() {
  return `
    ${matMulHeader}

    const uint MatTileSize = gl_WorkGroupSize.x;  // .x == .y
    shared float mm_Asub[MatTileSize][MatTileSize];
    shared float mm_Bsub[MatTileSize][MatTileSize];

    void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
      uint localRow = gl_LocalInvocationID.y;  // 0..MatTileSize
      uint localCol = gl_LocalInvocationID.x;  // 0..MatTileSize
      uint globalRow = gl_GlobalInvocationID.y;  // AOuter
      uint globalCol = gl_GlobalInvocationID.x;  // Inner

      float acc = 0.0;

      uint numTiles = (dimInner - 1) / MatTileSize + 1;

      for (uint t = 0; t < numTiles; t++) {
        // Load one tile of A and B into local memory
        uint tiledACol = MatTileSize * t + localCol;
        uint tiledBRow = MatTileSize * t + localRow;
        mm_Asub[localRow][localCol] = mm_readA(globalRow, tiledACol);
        mm_Bsub[localRow][localCol] = mm_readB(tiledBRow, globalCol);

        // Synchronise to make sure the tile is loaded
        barrier();

        for (uint k = 0; k < MatTileSize; k++) {
          acc += mm_Asub[localRow][k] * mm_Bsub[k][localCol];
        }

        // Synchronise before loading the next tile
        barrier();
      }

      if (globalCol < dimBOuter && globalRow < dimAOuter) {
        mm_write(globalRow, globalCol, acc);
      }
    }
  `;
}
|
||
|
// Batched matmul program using the simple tiled kernel from
// makeMatMulSource(). Inputs A and B are read as flat buffers with bounds
// checks; out-of-range reads return 0 so tiles may overhang the matrices.
class MatMulProgram {
  constructor(outputShape) {
    this.variableNames = ['A', 'B'];
    this.workGroupSize = [16, 16, 1]; // Must be square.
    this.outputShape = outputShape;
    // x walks columns (dim 1), y walks rows (dim 2), z walks the batch.
    this.dispatchLayout = { x: [1], y: [2], z: [0] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
    this.userCode = `
      uint dimAOuter = aShape[1];
      uint dimInner = aShape[2];
      uint dimBOuter = bShape[2];

      ${makeMatMulSource()}

      float mm_readA(uint row, uint col) {
        if (row < dimAOuter && col < dimInner) {
          return A[row * dimInner + col];
        } else {
          return 0.0;
        }
      }

      float mm_readB(uint row, uint col) {
        if (row < dimInner && col < dimBOuter) {
          return B[row * dimBOuter + col];
        } else {
          return 0.0;
        }
      }

      void mm_write(uint row, uint col, float value) {
        setOutput(row * dimBOuter + col, value);
      }

      void main() {
        mm_matMul(dimAOuter, dimInner, dimBOuter);
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Returns the GLSL source of a "packed" tiled matmul kernel: like
// makeMatMulSource(), but each thread computes a WorkPerThread x
// WorkPerThread patch of the output, so the shared-memory tile is
// WorkGroupSize * WorkPerThread on a side. Requires a square workgroup.
function makeMatMulPackedSource(workPerThread) {
  return `
    ${matMulHeader}

    const uint WorkGroupSize = gl_WorkGroupSize.x;  // .x == .y
    const uint WorkPerThread = ${workPerThread};
    const uint MatTileSize = WorkGroupSize * WorkPerThread;

    shared float mm_Asub[MatTileSize][MatTileSize];
    shared float mm_Bsub[MatTileSize][MatTileSize];

    void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
      // These are 0..MatTileSize, in increments of WorkPerThread.
      uint tileRow = gl_LocalInvocationID.y * WorkPerThread;
      uint tileCol = gl_LocalInvocationID.x * WorkPerThread;

      // These are 0..AOuter, in increments of WorkPerThread.
      uint globalRow = gl_GlobalInvocationID.y * WorkPerThread;
      uint globalCol = gl_GlobalInvocationID.x * WorkPerThread;

      uint numTiles = (dimInner - 1) / MatTileSize + 1;

      float acc[WorkPerThread][WorkPerThread];
      float ACached;
      float BCached[WorkPerThread];

      // Without this initialization strange values show up in acc.
      for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
        for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
          acc[innerRow][innerCol] = 0.0;
        }
      }

      // Loop over shared dimension.
      for (uint t = 0; t < numTiles; t++) {
        // Load one tile of A and B into local memory.
        for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
          for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
            uint inputRow = tileRow + innerRow;
            uint inputCol = tileCol + innerCol;

            mm_Asub[inputRow][inputCol] = mm_readA(
                globalRow + innerRow,
                t * MatTileSize + tileCol + innerCol);
            mm_Bsub[inputRow][inputCol] = mm_readB(
                t * MatTileSize + tileRow + innerRow,
                globalCol + innerCol);
          }
        }

        barrier();

        // Compute acc values for a single thread.
        for (uint k = 0; k < MatTileSize; k++) {
          for (uint inner = 0; inner < WorkPerThread; inner++) {
            BCached[inner] = mm_Bsub[k][tileCol + inner];
          }

          for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
            ACached = mm_Asub[tileRow + innerRow][k];
            for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
              acc[innerRow][innerCol] += ACached * BCached[innerCol];
            }
          }
        }

        barrier();
      }

      for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
        for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
          uint globalFlatIndex =
            (globalRow + innerRow) * dimBOuter + (globalCol + innerCol);

          if ((globalCol + innerCol) < dimBOuter &&
              (globalRow + innerRow) < dimAOuter) {
            mm_write(globalRow + innerRow,
                     globalCol + innerCol,
                     acc[innerRow][innerCol]);
          }
        }
      }
    }
  `;
}
|
||
|
// Batched matmul program using the packed kernel from
// makeMatMulPackedSource(): each thread produces a workPerThread x
// workPerThread patch of the output. Reads are bounds-checked (0 outside).
class MatMulPackedProgram {
  constructor(outputShape, workPerThread) {
    this.variableNames = ['A', 'B'];
    this.workGroupSize = [16, 16, 1];
    this.outputShape = outputShape;
    this.workPerThread = workPerThread;
    // x walks columns (dim 1), y walks rows (dim 2), z walks the batch.
    this.dispatchLayout = { x: [1], y: [2], z: [0] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize, [workPerThread, workPerThread, 1]);
    // Consider compiling a different version of the shader that doesn't care
    // about boundary conditions when loading from Asub / Bsub when tiles fit
    // neatly inside of output. May slightly improve performance.
    this.userCode = `
      uint dimAOuter = aShape[1];
      uint dimInner = aShape[2];
      uint dimBOuter = bShape[2];

      ${makeMatMulPackedSource(workPerThread)}

      float mm_readA(uint row, uint col) {
        if (row < dimAOuter && col < dimInner) {
          return A[row * dimInner + col];
        } else {
          return 0.0;
        }
      }

      float mm_readB(uint row, uint col) {
        if (row < dimInner && col < dimBOuter) {
          return B[row * dimBOuter + col];
        } else {
          return 0.0;
        }
      }

      void mm_write(uint row, uint col, float value) {
        setOutput(row * dimBOuter + col, value);
      }

      void main() {
        mm_matMul(dimAOuter, dimInner, dimBOuter);
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Conv2D implemented as an implicit-im2col matrix multiply (NHWC only,
// dilation 1 only). workPerThread === 0 selects the simple tiled kernel,
// otherwise the packed kernel with that work-per-thread factor.
class Conv2DMMProgram {
  constructor(convInfo, workPerThread) {
    this.variableNames = ['x', 'W'];
    this.uniforms = 'ivec2 filterDims, pad, stride;';
    this.workGroupSize = [
      16, 16,
      1
    ];
    this.outputShape = convInfo.outShape;
    tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
    tf.util.assert(convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
    let elementsPerThread;
    let matMulSource;
    if (workPerThread === 0) {
      elementsPerThread = [1, 1, 1];
      matMulSource = makeMatMulSource();
    }
    else {
      elementsPerThread = [workPerThread, workPerThread, 1];
      matMulSource = makeMatMulPackedSource(workPerThread);
    }
    this.dispatchLayout = { x: [1], y: [2], z: [0] };
    // Dispatch over the virtual [batch, outH*outW, outChannels] matmul shape
    // rather than the 4-D conv output shape.
    const matMulOutShape = [
      convInfo.outShape[0], convInfo.outShape[1] * convInfo.outShape[2],
      convInfo.outShape[3]
    ];
    this.dispatch = computeDispatch(this.dispatchLayout, matMulOutShape, this.workGroupSize, elementsPerThread);
    // In the shader below the matmul is laid out transposed: "rows" of the
    // virtual matrices are output channels (dimAOuter = outShape[3]) and
    // "columns" are spatial positions (dimBOuter = outH*outW); mm_write maps
    // (row, col) back to an NHWC coordinate accordingly.
    // NOTE(review): main() passes int dims into mm_matMul's uint parameters —
    // strict GLSL ES has no implicit int->uint conversion; confirm the
    // shader compiler in use accepts this.
    this.userCode = `
      ${matMulSource}

      int batch;

      float mm_readA(uint row, uint col) {
        int r = int(row), c = int(col);
        ivec4 coord = ivec4(
            (c / filterDims[1]) % filterDims[0],
            c % filterDims[1],
            c / (filterDims[0] * filterDims[1]),
            r);

        ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
        return coordIsValid(coord, shape) ? W[getFlatIndex(coord, shape)] : 0;
      }

      float mm_readB(uint row, uint col) {
        int r = int(row), c = int(col);
        int outRow = c / outShape[2];
        int outCol = c % outShape[2];

        int WRow = (r / filterDims[1]) % filterDims[0];
        int WCol = r % filterDims[1];

        ivec4 coord = ivec4(
            batch,
            pad[0] + outRow * stride[0] + WRow,
            pad[1] + outCol * stride[1] + WCol,
            r / (filterDims[0] * filterDims[1]));
        return coordIsValid(coord, xShape) ?
            x[getFlatIndex(coord, xShape)] : 0;
      }

      void mm_write(uint row, uint col, float value) {
        ivec4 outCoord = ivec4(
            batch,
            col / outShape[2],
            col % outShape[2],
            row);
        if (coordIsValid(outCoord, outShape)) {
          result[getFlatIndex(outCoord, outShape)] = value;
        }
      }

      void main() {
        batch = int(gl_GlobalInvocationID.z);

        int dimAOuter = outShape[3];
        int dimBOuter = outShape[1] * outShape[2];
        int dimInner = filterDims[0] * filterDims[1] * xShape[3];
        mm_matMul(dimAOuter, dimInner, dimBOuter);
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Naive direct Conv2D (NHWC only, dilation 1 only): one thread per output
// element, accumulating over filter rows/cols and input channels. All reads
// and the final write are bounds-checked against the relevant shapes.
class Conv2DNaiveProgram {
  constructor(convInfo) {
    this.variableNames = ['x', 'W'];
    this.uniforms = 'ivec2 filterDims, pad, stride;';
    this.workGroupSize = [4, 8, 1];
    this.outputShape = convInfo.outShape;
    this.dispatchLayout = { x: [2], y: [1], z: [0, 3] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
    tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
    tf.util.assert(convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
    // BUG FIX: readFilt previously tested `coordIsValid(coord, shape)` with an
    // undeclared `coord`, which fails shader compilation. It now builds the
    // filter coordinate explicitly, mirroring readInp/writeResult.
    this.userCode = `
      float readInp(uint batch, uint row, uint col, uint chan) {
        ivec4 coord = ivec4(batch, row, col, chan);
        return coordIsValid(coord, xShape) ? getX(coord) : 0;
      }

      float readFilt(uint row, uint col, uint xChannel, uint outChannel) {
        ivec4 coord = ivec4(row, col, xChannel, outChannel);
        ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
        return coordIsValid(coord, shape) ?
            getW(row, col, xChannel, outChannel) : 0;
      }

      void writeResult(uint batch, uint row, uint col, uint chan, float value) {
        ivec4 coord = ivec4(batch, row, col, chan);
        if (coordIsValid(coord, outShape)) {
          setOutput(batch, row, col, chan, value);
        }
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int outChannel = coords[3];

        float acc = 0.0;

        for (int row = 0; row < filterDims[0]; ++row) {
          for (int col = 0; col < filterDims[1]; ++col) {
            for (int xChannel = 0; xChannel < xShape[3]; ++xChannel) {
              float v = readInp(batch,
                  pad[0] + coords[1] * stride[0] + row,
                  pad[1] + coords[2] * stride[1] + col, xChannel);
              float f = readFilt(row, col, xChannel, outChannel);
              acc += v * f;
            }
          }
        }

        writeResult(batch, coords[1], coords[2], outChannel, acc);
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Depthwise Conv2D (NHWC only): each output channel d2 maps to input channel
// d1 = d2 / channelMul and filter multiplier q = d2 % channelMul. Dilation is
// asserted to be 1, so the interpolated dilation factors below are always 1.
class DepthwiseConv2DProgram {
  constructor(convInfo) {
    this.variableNames = ['x', 'W'];
    this.uniforms = 'ivec2 filterDims, pad, stride;';
    this.workGroupSize = [4, 8, 1];
    this.outputShape = convInfo.outShape;
    this.dispatchLayout = { x: [2], y: [1], z: [0, 3] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
    // Sizes baked into the shader as compile-time constants.
    const xNumRows = convInfo.inHeight;
    const xNumCols = convInfo.inWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    // Number of output channels produced per input channel.
    const channelMul = convInfo.outChannels / convInfo.inChannels;
    tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
    tf.util.assert(convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
    this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void writeResult(int batch, int row, int col, int chan, float value) {
        ivec4 coord = ivec4(batch, row, col, chan);
        if (coordIsValid(coord, outShape)) {
          setOutput(batch, row, col, chan, value);
        }
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords[3];
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};

        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        // TODO(xing.xu): Flatten the two for loops and vec4 the operations.
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * ${dilationHeight};

          if (xR < 0 || xR >= ${xNumRows}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * ${dilationWidth};

            if (xC < 0 || xC >= ${xNumCols}) {
              continue;
            }

            float xVal = getX(batch, xR, xC, d1);
            float wVal = getW(wR, wC, d1, q);
            dotProd += xVal * wVal;
          }
        }
        writeResult(batch, coords[1], coords[2], d2, dotProd);
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Max pooling over NHWC input: one thread per output element, taking the max
// over the (dilated) pooling window. Out-of-range columns read as 0.0 via
// getValue; out-of-range rows are skipped in the loop.
class MaxPoolProgram {
  constructor(convInfo) {
    this.variableNames = ['x'];
    this.uniforms = 'ivec2 pad, stride, dilation, convDims, filterDims;';
    this.workGroupSize = [4, 4, 1];
    this.outputShape = convInfo.outShape;
    this.dispatchLayout = { x: [1], y: [2], z: [0, 3] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
    // TODO: Parallelize max computation by thread and merge result.
    // BUG FIX: the column loop already advances wC in steps of dilation.x, so
    // the old `xC = xCCorner + wC * dilation.x` applied the dilation twice.
    // The column index now adds wC directly, matching the row loop's
    // `xR = xRCorner + wR` handling.
    this.userCode = `
      float getValue(int batch, int xR, int xC, int d) {
        if (xC < 0 || xC >= convDims.x) {
          return 0.0;
        }
        return getX(batch, xR, xC, d);
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d = coords[3];

        if (all(lessThan(coords, outShape))) {
          ivec2 xRCCorner = coords.yz * stride - pad;
          int xRCorner = xRCCorner.x;
          int xCCorner = xRCCorner.y;

          float minMaxValue = 0.0;

          for (int wR = 0; wR < filterDims.y; wR += dilation.y) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= convDims.y) {
              continue;
            }

            for (int wC = 0; wC < filterDims.x; wC += dilation.x) {
              int xC = xCCorner + wC;
              float value = getValue(batch, xR, xC, d);
              minMaxValue = max(value, minMaxValue);
            }
          }
          setOutput(batch, coords[1], coords[2], d, minMaxValue);
        }
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Pads input `x` with `constantValue` according to per-dimension
// [before, after] `paddings`; output cells inside the original extent copy
// the corresponding input value, all others receive the constant.
class PadProgram {
  constructor(xShape, paddings, constantValue) {
    this.variableNames = ['x'];
    this.workPerThread = 8;
    this.workGroupSize = [1, 1, 1];
    this.outputShape = paddings.map(
        ([beforePad, afterPad], i) => beforePad + xShape[i] + afterPad);
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workGroupSize,
        [this.workPerThread, 1, 1]);

    const rank = xShape.length;
    const outputSize = tf.util.sizeFromShape(this.outputShape);
    const coordsType = getCoordsDataType(rank);
    const isMultiDim = rank > 1;

    // `start` / `end` delimit the unpadded region in output coordinates.
    const startCoords = paddings.map((p) => p[0]).join(',');
    const endCoords = paddings.map((p, i) => p[0] + xShape[i]).join(',');
    const startValue = isMultiDim ? `${coordsType}(${startCoords})` : `${startCoords}`;
    const endValue = isMultiDim ? `${coordsType}(${endCoords})` : `${endCoords}`;
    const leftPadCondition = isMultiDim ? `any(lessThan(outC, start))` : `outC < start`;
    const rightPadCondition = isMultiDim ? `any(greaterThanEqual(outC, end))` : `outC >= end`;
    const unpackedCoords = isMultiDim ?
        ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank) :
        'coords';

    this.userCode = `
      ${coordsType} start = ${startValue};
      ${coordsType} end = ${endValue};

      void main() {
        int index = int(gl_GlobalInvocationID.x);

        for (int i = 0; i < ${this.workPerThread}; i++) {
          int flatIndex = index * ${this.workPerThread} + i;

          if (flatIndex < ${outputSize}) {
            ${coordsType} outC = getCoordsFromFlatIndex(flatIndex);

            if (${leftPadCondition} || ${rightPadCondition}) {
              setOutput(flatIndex, ${constantValue});
            } else {
              ${coordsType} coords = outC - start;
              setOutput(flatIndex, getX(${unpackedCoords}));
            }
          }
        }
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Bilinear image resize over NHWC input to [batch, newHeight, newWidth,
// channels]. With alignCorners, the effective in/out sizes are reduced by 1
// along any output axis longer than 1 so corner pixels map exactly.
class ResizeBilinearProgram {
  constructor(inputShape, newHeight, newWidth, alignCorners) {
    this.variableNames = ['x'];
    this.outputShape = [inputShape[0], newHeight, newWidth, inputShape[3]];
    this.dispatchLayout = { x: [1], y: [2], z: [0, 3] };
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
    const adjustHeight = alignCorners && newHeight > 1;
    const adjustWidth = alignCorners && newWidth > 1;
    // NOTE(review): the shader mixes xShape/outShape components with float
    // literals (e.g. `xShape.yz - 1.0`, division of the resulting vectors);
    // whether this compiles depends on xShape's declared GLSL type, which is
    // generated elsewhere — confirm it is a float vector or that the
    // compiler accepts the implicit conversion.
    this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();

        if (all(lessThan(coords, outShape))) {
          int b = coords[0];
          int d = coords[3];
          ivec2 rc = coords.yz;

          vec2 effectiveInSize = vec2(
            ${adjustHeight ? 'xShape.y - 1.0' : 'xShape.y'},
            ${adjustWidth ? 'xShape.z - 1.0' : 'xShape.z'});

          vec2 effectiveOutSize = vec2(
            ${adjustHeight ? 'outShape.y - 1.0' : 'outShape.y'},
            ${adjustWidth ? 'outShape.z - 1.0' : 'outShape.z'});

          vec2 effectiveInputOverOutputRatioRC =
              effectiveInSize / effectiveOutSize;

          // Fractional source index
          vec2 sourceFracIndexRC = vec2(rc) * effectiveInputOverOutputRatioRC;

          // Compute the four integer indices.
          ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);
          ivec2 sourceCeilRC = ivec2(
            min(xShape.yz - 1.0, ceil(sourceFracIndexRC)));

          float topLeft = getX(b, sourceFloorRC.x, sourceFloorRC.y, d);
          float bottomLeft = getX(b, sourceCeilRC.x, sourceFloorRC.y, d);
          float topRight = getX(b, sourceFloorRC.x, sourceCeilRC.y, d);
          float bottomRight = getX(b, sourceCeilRC.x, sourceCeilRC.y, d);

          vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);

          float top = topLeft + (topRight - topLeft) * fracRC.y;
          float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
          float newValue = top + (bottom - top) * fracRC.x;

          setOutput(b, coords[1], coords[2], d, newValue);
        }
      }
    `;
  }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * WebGPU program that permutes tensor dimensions: output dimension i takes
 * its size (and data) from input dimension newDim[i].
 */
class TransposeProgram {
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        // Re-order the input shape by the permutation to get the output shape.
        const outputShape = Array.from({ length: aShape.length }, (_, i) => aShape[newDim[i]]);
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        this.dispatchLayout = flatDispatchLayout(this.outputShape);
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
        const dtype = getCoordsDataType(this.rank);
        const switched = getSwitchedCoords(newDim);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        setOutput(getFlatIndex(resRC, outShape), A[getFlatIndex(
          ${dtype}(${switched}), aShape)]);
      }
    `;
    }
}
|
||
|
/**
 * Builds a comma-separated list of output-coordinate component expressions,
 * re-ordered so that input dimension newDim[i] reads output component i.
 * Only ranks up to 4 are supported; higher ranks throw.
 */
function getSwitchedCoords(newDim) {
    const rank = newDim.length;
    if (rank > 4) {
        throw Error(`Transpose for rank ${rank} is not yet supported`);
    }
    const components = new Array(rank);
    newDim.forEach((dim, i) => {
        components[dim] = `resRC[${i}]`;
    });
    return components.join();
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// GLSL snippet bodies for elementwise activations; each sees `a` in scope.
const RELU = 'return max(a, 0.0);';
const SIGMOID = 'return 1.0 / (1.0 + exp(-1.0 * a));';
/**
 * WebGPU program applying an elementwise unary op. Each invocation handles
 * `workPerThread` consecutive flat indices, guarded against the tensor size.
 */
class UnaryOpProgram {
    constructor(outputShape, op) {
        this.variableNames = ['A'];
        this.workPerThread = 4;
        this.workGroupSize = [1, 1, 1];
        this.outputShape = outputShape;
        this.dispatchLayout = flatDispatchLayout(this.outputShape);
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize, [this.workPerThread, 1, 1]);
        const elementCount = tf.util.sizeFromShape(this.outputShape);
        const coordsType = getCoordsDataType(this.outputShape.length);
        this.userCode = `
      float unaryOperation(float a) {
        ${op}
      }

      void main() {
        int index = int(gl_GlobalInvocationID.x);

        for(int i = 0; i < ${this.workPerThread}; i++) {
          int flatIndex = index * ${this.workPerThread} + i;

          if(flatIndex < ${elementCount}) {
            ${coordsType} coords = getCoordsFromFlatIndex(flatIndex);

            float a = getAAtOutCoords(coords);
            setOutput(flatIndex, unaryOperation(a));
          }
        }
      }
    `;
    }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Creates a GPUBindGroup binding the output buffer at slot 0, the input
 * buffers at the following slots, and (when present) the uniform buffer last.
 */
const makeBindGroup = (device, bindGroupLayout, inputs, output, uniforms) => {
    const resources = uniforms ? [output, ...inputs, uniforms] : [output, ...inputs];
    const entries = resources.map((item, slot) => ({ binding: slot, resource: item.resource }));
    return device.createBindGroup({
        layout: bindGroupLayout,
        bindings: entries,
    });
};
|
||
|
/**
 * Creates the bind-group layout matching makeBindGroup: one storage-buffer
 * slot for the output plus one per input, all visible to the compute stage,
 * followed by an optional uniform-buffer slot.
 */
const makeBindGroupLayout = (device, inputs, output, uniforms) => {
    const slotTypes = new Array(1 + inputs.length).fill('storage-buffer');
    if (uniforms) {
        slotTypes.push('uniform-buffer');
    }
    return device.createBindGroupLayout({
        bindings: slotTypes.map((type, i) => ({
            binding: i,
            visibility: GPUShaderStage.COMPUTE,
            type,
        })),
    });
};
|
||
|
/**
 * Compiles a program's generated GLSL compute shader to SPIR-V (via shaderc)
 * and builds the WebGPU pipeline plus its bind-group layout.
 * On a compilation error, dumps the numbered shader source to the console
 * and throws.
 */
const compileProgram = (shaderCompiler, shaderKind, compileOptions, device, program, inputsData, output, uniforms) => {
    const outputData = { dtype: output.dtype, shape: output.shape };
    const source = makeShader(inputsData, outputData, program);
    const result = shaderCompiler.CompileGlslToSpv(source, shaderKind, 'file', 'main', compileOptions);
    const error = result.GetErrorMessage();
    if (error.length) {
        // Print the shader with 1-based line numbers to aid debugging.
        const numbered = source
            .split('\n')
            .map((line, idx) => `${(idx + 1).toString().padStart(5, ' ')} ${line}`)
            .join('\n');
        console.error(numbered);
        throw new Error(`Shader compilation failed: ${error}`);
    }
    const bindGroupLayout = makeBindGroupLayout(device, inputsData, output, uniforms);
    const layout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });
    const module = device.createShaderModule({ code: result.GetBinary() });
    const pipeline = device.createComputePipeline({ layout, computeStage: { module, entryPoint: 'main' } });
    return { bindGroupLayout, pipeline };
};
|
||
|
// TODO: Consider allowing each program to specify its own shader key. E.g. some
|
||
|
// kernels account for different work group sizes, but some don't.
|
||
|
// TODO: Consider uploading shape info as vec4s regardless of rank to reduce
|
||
|
// recompilation.
|
||
|
/**
 * Derives the pipeline-cache key for a program: its work-group size (when
 * declared), the ranks of all bound buffers, and the generated user code.
 */
function makeShaderKey(program, ranks) {
    const workGroupPart = program.workGroupSize ? program.workGroupSize.join(',') : '';
    return workGroupPart + ranks.join(',') + program.userCode;
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Default usage flags for tensor storage buffers: usable as compute-shader
// storage and as both source and destination of buffer-to-buffer copies.
const DEFAULT_GPUBUFFER_USAGE = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
/**
 * tf.js kernel backend that runs ops as WebGPU compute shaders.
 * GLSL is compiled to SPIR-V through the supplied shaderc module; tensor
 * data lives in GPUBuffers tracked in `tensorMap`, and encoded command
 * buffers are batched in `commandQueue` until `submitQueue()` flushes them.
 */
class WebGPUBackend extends tf.KernelBackend {
    constructor(device, shaderc) {
        super();
        // dataIds referenced by not-yet-submitted GPU work; their buffers
        // must not be recycled until the command queue is flushed.
        this.commandQueueOwnedIds = new WeakSet();
        // dataId -> { values, id, dtype, bufferInfo }.
        this.tensorMap = new WeakMap();
        // Buffer infos of tensors disposed while still queue-owned; released
        // on the next submitQueue().
        this.disposalQueue = [];
        this.disposed = false;
        this.uploadWaitMs = 0;
        this.downloadWaitMs = 0;
        // Shader key -> { bindGroupLayout, pipeline } cache.
        this.binaryCache = {};
        this.device = device;
        this.queue = device.getQueue();
        // Command encoders awaiting submission.
        this.commandQueue = [];
        this.shaderc = shaderc;
        this.compiler = new shaderc.Compiler();
        const opts = new shaderc.CompileOptions();
        opts.SetOptimizationLevel(shaderc.optimization_level.performance);
        this.compileOpts = opts;
        this.bufferManager = new BufferManager(this.device);
    }
    // All computation is in 32-bit floats.
    floatPrecision() {
        return 32;
    }
    setDataMover(dataMover) {
        // TODO: tfjs team to implement this. Call GPUBuffer.destroy()
    }
    // Releases every buffer deferred in disposalQueue back to the pool.
    flushDisposalQueue() {
        this.disposalQueue.forEach(d => {
            this.releaseBuffer(d.buffer, d.byteSize, d.usage);
        });
        this.disposalQueue = [];
    }
    // Removes a tensor. If pending GPU work still references its buffer, the
    // release is deferred until the queue is submitted.
    disposeData(dataId) {
        if (!this.tensorMap.has(dataId)) {
            throw new Error(`Tensor ${dataId} was not registered!`);
        }
        const info = this.tensorMap.get(dataId);
        if (this.commandQueueOwnedIds.has(dataId)) {
            this.disposalQueue.push(info.bufferInfo);
        }
        else {
            this.releaseBuffer(info.bufferInfo.buffer, info.bufferInfo.byteSize, info.bufferInfo.usage);
        }
        this.tensorMap.delete(dataId);
    }
    memory() {
        return {
            numBytesInGPU: this.bufferManager.numBytesUsed,
            unreliable: false
        };
    }
    getBufferManager() {
        return this.bufferManager;
    }
    acquireBuffer(byteSize, usage = DEFAULT_GPUBUFFER_USAGE) {
        return this.bufferManager.acquireBuffer(byteSize, usage);
    }
    releaseBuffer(buffer, byteSize, usage) {
        this.bufferManager.releaseBuffer(buffer, byteSize, usage);
    }
    // Allocates a GPU buffer for a new tensor; no-op if already registered.
    register(dataId, shape, dtype) {
        if (!this.tensorMap.has(dataId)) {
            const byteSize = tf.util.sizeFromShape(shape) * tf.util.bytesPerElement(dtype);
            const buffer = this.acquireBuffer(byteSize);
            this.tensorMap.set(dataId, {
                values: null,
                id: -1,
                dtype,
                bufferInfo: { byteSize, usage: DEFAULT_GPUBUFFER_USAGE, buffer }
            });
        }
    }
    // Uploads CPU values into the tensor's GPU buffer and caches them.
    write(dataId, values) {
        if (!this.tensorMap.has(dataId)) {
            throw new Error(`Tensor ${dataId} was not registered!`);
        }
        const info = this.tensorMap.get(dataId);
        info.values = values;
        info.bufferInfo.buffer.setSubData(0, values);
        this.tensorMap.set(dataId, info);
    }
    // Finishes and submits all pending command encoders, then clears queue
    // ownership and releases deferred buffers.
    submitQueue() {
        this.queue.submit(this.commandQueue.map(enc => enc.finish()));
        this.commandQueue = [];
        this.commandQueueOwnedIds = new WeakSet();
        this.flushDisposalQueue();
    }
    // Copies a tensor's buffer into a temporary mappable staging buffer and
    // reads it back to the CPU as an ArrayBuffer.
    async getBufferData(info) {
        const staging = this.acquireBuffer(info.bufferInfo.byteSize, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
        const encoder = this.device.createCommandEncoder({});
        encoder.copyBufferToBuffer(info.bufferInfo.buffer, 0, staging, 0, info.bufferInfo.byteSize);
        this.commandQueue.push(encoder);
        // Flush so the copy actually executes before mapping.
        this.submitQueue();
        const mapped = await staging.mapReadAsync();
        // Copy out of the mapped range before unmapping invalidates it.
        const values = mapped.slice(0);
        staging.unmap();
        this.releaseBuffer(staging, info.bufferInfo.byteSize, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
        return values;
    }
    convertAndCacheOnCPU(dataId, data) {
        const info = this.tensorMap.get(dataId);
        info.values = data;
        return info.values;
    }
    // TODO: Remove once this is fixed:
    // https://github.com/tensorflow/tfjs/issues/1595
    // Synchronous read only works for values already cached on the CPU.
    readSync(dataId) {
        const texData = this.tensorMap.get(dataId);
        const { values } = texData;
        if (values == null) {
            throw new Error('WebGPU readSync is only available for CPU-resident tensors.');
        }
        return values;
    }
    // Downloads a tensor to the CPU, converting by the registered dtype.
    async read(dataId) {
        if (!this.tensorMap.has(dataId)) {
            throw new Error(`Tensor ${dataId} was not registered!`);
        }
        const info = this.tensorMap.get(dataId);
        const data = await this.getBufferData(info);
        const dataAsTypedArray = info.dtype === 'int32' ? new Int32Array(data) : new Float32Array(data);
        this.convertAndCacheOnCPU(dataId, dataAsTypedArray);
        return dataAsTypedArray;
    }
    // Times kernels launched by `f` using the CPU-side timer entries that
    // compileAndRun pushes onto activeTimers.
    async time(f) {
        const oldActiveTimers = this.activeTimers;
        const newActiveTimers = [];
        let outerMostTime = false;
        if (this.programTimersStack == null) {
            this.programTimersStack = newActiveTimers;
            outerMostTime = true;
        }
        else {
            this.activeTimers.push(newActiveTimers);
        }
        this.activeTimers = newActiveTimers;
        // NOTE(review): f() is not awaited here — presumably f synchronously
        // enqueues all kernels; confirm if async callbacks are expected.
        f();
        const flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map((d) => d.query))
            .filter(d => d != null);
        const flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map((d) => d.name))
            .filter(d => d != null);
        this.activeTimers = oldActiveTimers;
        if (outerMostTime) {
            this.programTimersStack = null;
        }
        const kernelMs = await Promise.all(flattenedActiveTimerQueries);
        const res = {
            uploadWaitMs: this.uploadWaitMs,
            downloadWaitMs: this.downloadWaitMs,
            kernelMs: tf.util.sum(kernelMs),
            getExtraProfileInfo: () => kernelMs.map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
                .map(d => `${d.name}: ${d.ms}`)
                .join(', '),
            wallMs: null
        };
        this.uploadWaitMs = 0;
        this.downloadWaitMs = 0;
        return res;
    }
    // Memoizes compiled pipelines by shader key.
    getAndSavePipeline(key, getBinary) {
        if (!(key in this.binaryCache)) {
            this.binaryCache[key] = getBinary();
        }
        return this.binaryCache[key];
    }
    makeOutputArray(shape, dtype) {
        return tf.Tensor.make(shape, {}, dtype, this);
    }
    // Wraps a tensor's GPU buffer in the {resource} shape that bind-group
    // creation expects; returns null for a missing tensor.
    tensorToBinding(tensor) {
        if (!tensor) {
            return null;
        }
        const tensorData = this.tensorMap.get(tensor.dataId);
        return {
            resource: {
                offset: 0,
                size: tensor.size * tf.util.bytesPerElement(tensor.dtype),
                buffer: tensorData.bufferInfo.buffer
            }
        };
    }
    // CPU wall-clock timers (not GPU timestamps).
    startTimer() {
        return { startMs: tf.util.now(), endMs: 0 };
    }
    endTimer(query) {
        query.endMs = tf.util.now();
        return query;
    }
    async getQueryTime(query) {
        const timerQuery = query;
        return timerQuery.endMs - timerQuery.startMs;
    }
    // Core execution path: packs shape uniforms (std140 layout), fetches or
    // compiles the pipeline, encodes a dispatch, and optionally submits.
    compileAndRun(program, inputs, output, programUniforms) {
        if (output == null) {
            output = this.makeOutputArray(program.outputShape, inputs[0].dtype);
        }
        let dimUniforms = [];
        const bufferShapes = inputs.concat(output).map(d => d.shape);
        let currentOffset = 0;
        bufferShapes.forEach((d, i) => {
            // Uniforms.
            if (d.length === 0) {
                d = [1];
            }
            // Complete std140 layout rules are documented here:
            // tslint:disable-next-line:max-line-length
            // https://www.khronos.org/registry/OpenGL/specs/gl/glspec45.core.pdf#page=159
            let baseAlignment;
            switch (d.length) {
                case 0:
                    baseAlignment = 1;
                    break;
                case 1:
                    baseAlignment = 1;
                    break;
                case 2:
                    baseAlignment = 2;
                    break;
                case 3:
                    baseAlignment = 4;
                    break;
                case 4:
                    baseAlignment = 4;
                    break;
                default:
                    tf.util.assert(false, () => `Unsupported ${d.length}D shape`);
            }
            // Pad to the member's base alignment before appending its elements.
            const padding = Math.ceil(currentOffset / baseAlignment) * baseAlignment -
                currentOffset;
            for (let p = 0; p < padding; ++p) {
                dimUniforms.push(0);
            }
            dimUniforms.push(...d);
            currentOffset += d.length + padding;
        });
        // TODO: handle padding of program-specific uniforms
        if (programUniforms) {
            dimUniforms = dimUniforms.concat(programUniforms);
        }
        const uniformData = new Int32Array(dimUniforms);
        const uniforms = this.makeUniforms(uniformData);
        const key = makeShaderKey(program, bufferShapes.map(d => d.length));
        const inputsData = inputs.map((input, i) => ({
            // Returning dtype from tensorMap because it reflects dtype
            // of underlying buffer, rather than abstract dtype.
            dtype: this.tensorMap.get(input.dataId).dtype,
            shape: input.shape,
            name: program.variableNames[i]
        }));
        const { bindGroupLayout, pipeline } = this.getAndSavePipeline(key, () => {
            return compileProgram(this.compiler, this.shaderc.shader_kind.compute, this.compileOpts, this.device, program, inputsData, output, uniforms);
        });
        const shouldTimeProgram = this.activeTimers != null;
        let query;
        if (shouldTimeProgram) {
            query = this.startTimer();
        }
        // Creating bind groups on the fly should never be a bottleneck.
        const bg = makeBindGroup(this.device, bindGroupLayout, inputs.map(t => this.tensorToBinding(t)), this.tensorToBinding(output), uniforms);
        const encoder = this.device.createCommandEncoder({});
        const pass = encoder.beginComputePass();
        pass.setPipeline(pipeline);
        pass.setBindGroup(0, bg);
        pass.dispatch(program.dispatch[0], program.dispatch[1], program.dispatch[2]);
        pass.endPass();
        this.commandQueue.push(encoder);
        // Mark all participating tensors as owned by pending GPU work.
        inputs.forEach(input => {
            this.commandQueueOwnedIds.add(input.dataId);
        });
        this.commandQueueOwnedIds.add(output.dataId);
        if (tf.ENV.get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
            this.submitQueue();
        }
        // NOTE(review): the uniform buffer is released right after encoding —
        // presumably safe because the pool won't reuse it before submission;
        // confirm against BufferManager semantics.
        this.releaseBuffer(uniforms.resource.buffer, uniformData.byteLength, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
        if (shouldTimeProgram) {
            query = this.endTimer(query);
            this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
        }
        return output;
    }
    // Uploads uniform data into a pooled uniform buffer and wraps it as a
    // bind-group resource.
    makeUniforms(data) {
        const dimensionsBuffer = this.acquireBuffer(data.byteLength, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
        dimensionsBuffer.setSubData(0, data);
        return {
            resource: { offset: 0, size: data.byteLength, buffer: dimensionsBuffer }
        };
    }
    pad(x, paddings, constantValue) {
        const program = new PadProgram(x.shape, paddings, constantValue);
        const output = this.makeOutputArray(program.outputShape, x.dtype);
        return this.compileAndRun(program, [x], output);
    }
    maxPool(x, convInfo) {
        const program = new MaxPoolProgram(convInfo);
        const output = this.makeOutputArray(program.outputShape, x.dtype);
        // Program-specific uniforms, in the order the shader expects.
        const dimensions = [
            convInfo.padInfo.left, convInfo.padInfo.top,
            convInfo.strideWidth, convInfo.strideHeight,
            convInfo.dilationWidth, convInfo.dilationHeight,
            convInfo.inWidth, convInfo.inHeight,
            convInfo.effectiveFilterWidth,
            convInfo.effectiveFilterHeight // Filter dims.
        ];
        return this.compileAndRun(program, [x], output, dimensions);
    }
    binaryOp(a, b, op) {
        const dtype = tf.backend_util.upcastType(a.dtype, b.dtype);
        const program = new BinaryOpProgram(op, a.shape, b.shape);
        // NOTE(review): unlike conv2d/batchMatMul, no backend (`this`) is
        // passed to Tensor.make here — confirm this is intentional.
        const output = tf.Tensor.make(program.outputShape, {}, dtype);
        return this.compileAndRun(program, [a, b], output);
    }
    add(a, b) {
        return this.binaryOp(a, b, ADD);
    }
    subtract(a, b) {
        return this.binaryOp(a, b, SUB);
    }
    conv2d(x, filter, convInfo) {
        const output = tf.Tensor.make(convInfo.outShape, {}, x.dtype, this);
        let program;
        const workPerThread = tf.ENV.get('WEBGPU_CONV2D_WORK_PER_THREAD');
        if (workPerThread === -1) {
            // TODO(kainino0x): This may be obsolete, but is kept for reference.
            program = new Conv2DNaiveProgram(convInfo);
        }
        else {
            program = new Conv2DMMProgram(convInfo, workPerThread);
        }
        // SAME padding centers the filter; VALID uses no padding; otherwise
        // use the explicit top/left padding from convInfo.
        const pad = convInfo.padInfo.type === 'VALID' ?
            [0, 0] :
            convInfo.padInfo.type === 'SAME' ?
                [
                    -Math.floor((convInfo.filterShape[0] - 1) / 2),
                    -Math.floor((convInfo.filterShape[1] - 1) / 2)
                ] :
                [convInfo.padInfo.top, convInfo.padInfo.left];
        const dimensions = [
            convInfo.filterHeight, convInfo.filterWidth, ...pad,
            convInfo.strideHeight, convInfo.strideWidth
        ];
        return this.compileAndRun(program, [x, filter], output, dimensions);
    }
    depthwiseConv2D(x, filter, convInfo) {
        const program = new DepthwiseConv2DProgram(convInfo);
        return this.compileAndRun(program, [x, filter]);
    }
    // Shared implementation for argMin/argMax; output dtype is int32 indices.
    argMinMaxReduce(x, axis, reduceType) {
        const program = new ArgMinMaxProgram(x.shape, axis, reduceType);
        const output = this.makeOutputArray(program.outputShape, 'int32');
        return this.compileAndRun(program, [x], output, [axis]);
    }
    argMin(x, axis) {
        return this.argMinMaxReduce(x, axis, 'min');
    }
    argMax(x, axis) {
        return this.argMinMaxReduce(x, axis, 'max');
    }
    // Concatenates by collapsing each tensor to 2D around the concat axis,
    // running a 2D concat program, then reshaping back.
    concat(tensors, axis) {
        if (tensors.length === 1) {
            return tensors[0];
        }
        // Is there a maximum number of buffers that can be uploaded to a WebGPU
        // program?
        // if (tensors.length > MAX_SSBOS_FOR_WEBGPU_PROGRAM) {
        //   const midIndex = Math.floor(tensors.length / 2);
        //   const leftSide = this.concat(tensors.slice(0, midIndex), axis);
        //   const rightSide = this.concat(tensors.slice(midIndex), axis);
        //   return this.concat([leftSide, rightSide], axis);
        // }
        const outShape = tf.backend_util.computeOutShape(tensors.map(t => t.shape), axis);
        const tensors2D = tensors.map(t => t.reshape([
            tf.util.sizeFromShape(t.shape.slice(0, axis)),
            tf.util.sizeFromShape(t.shape.slice(axis))
        ]));
        const program = new ConcatProgram(tensors2D.map(t => t.shape));
        const res = this.compileAndRun(program, tensors2D);
        return res.reshape(outShape);
    }
    multiply(a, b) {
        return this.binaryOp(a, b, MUL);
    }
    floorDiv(a, b) {
        return this.binaryOp(a, b, INT_DIV);
    }
    sigmoid(x) {
        const program = new UnaryOpProgram(x.shape, SIGMOID);
        return this.compileAndRun(program, [x]);
    }
    relu(x) {
        const program = new UnaryOpProgram(x.shape, RELU);
        return this.compileAndRun(program, [x]);
    }
    resizeBilinear(x, newHeight, newWidth, alignCorners) {
        const program = new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);
        const output = this.makeOutputArray(program.outputShape, x.dtype);
        return this.compileAndRun(program, [x], output);
    }
    // Reshape reuses the same dataId — no data is moved.
    reshape(x, shape) {
        return tf.Tensor.make(shape, { dataId: x.dataId }, x.dtype);
    }
    cast(x, dtype) {
        return tf.backend_util.castTensor(x, dtype, this);
    }
    transpose(x, perm) {
        const program = new TransposeProgram(x.shape, perm);
        return this.compileAndRun(program, [x]);
    }
    batchMatMul(a, b, transposeA, transposeB) {
        // TODO: Support transposed inputs.
        // const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
        // const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
        const outerShapeA = a.shape[1];
        const outerShapeB = b.shape[2];
        const [batch, ,] = a.shape;
        const output = tf.Tensor.make([batch, outerShapeA, outerShapeB], {}, a.dtype, this);
        let program;
        // TODO: We should eventually use the blocked version, but keeping around
        // the old version while we try to understand conditions under which blocked
        // is faster.
        if (tf.ENV.get('WEBGPU_MATMUL_WORK_PER_THREAD') === 0) {
            program = new MatMulProgram(output.shape);
        }
        else {
            program = new MatMulPackedProgram(output.shape, tf.ENV.get('WEBGPU_MATMUL_WORK_PER_THREAD'));
        }
        return this.compileAndRun(program, [a, b], output);
    }
    // Converts browser image sources to an int32 tensor of shape
    // [height, width, numChannels], drawing DOM elements through a shared
    // 2D canvas to extract their pixel data.
    fromPixels(pixels, numChannels) {
        if (pixels == null) {
            throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
        }
        const outShape = [pixels.height, pixels.width, numChannels];
        let imageData = pixels.data;
        if (tf.ENV.getBool('IS_BROWSER')) {
            if (!(pixels instanceof HTMLVideoElement) &&
                !(pixels instanceof HTMLImageElement) &&
                !(pixels instanceof HTMLCanvasElement) &&
                !(pixels instanceof ImageData) &&
                !(pixels.data instanceof Uint8Array)) {
                throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
                    `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData` +
                    ` or {data: Uint32Array, width: number, height: number}, ` +
                    `but was ${pixels.constructor.name}`);
            }
            if (pixels instanceof HTMLVideoElement ||
                pixels instanceof HTMLImageElement ||
                pixels instanceof HTMLCanvasElement) {
                // Lazily create a 2D canvas context sized on first use.
                // NOTE(review): the canvas is sized only on creation — presumably
                // later calls reuse it regardless of pixel dimensions; confirm.
                if (this.fromPixels2DContext == null) {
                    this.fromPixels2DContext =
                        document.createElement('canvas').getContext('2d');
                    this.fromPixels2DContext.canvas.width = pixels.width;
                    this.fromPixels2DContext.canvas.height = pixels.height;
                }
                this.fromPixels2DContext.drawImage(pixels, 0, 0, pixels.width, pixels.height);
                pixels = this.fromPixels2DContext.canvas;
            }
            // TODO: Remove this once we figure out how to upload textures directly to
            // WebGPU.
            const imageDataLivesOnGPU = pixels instanceof HTMLVideoElement ||
                pixels instanceof HTMLImageElement ||
                pixels instanceof HTMLCanvasElement;
            if (imageDataLivesOnGPU) {
                imageData = this.fromPixels2DContext
                    .getImageData(0, 0, pixels.width, pixels.height)
                    .data;
            }
        }
        // TODO: Encoding should happen on GPU once we no longer have to download
        // image data to the CPU.
        let pixelArray = imageData;
        if (numChannels != null && numChannels !== 4) {
            // Repack RGBA data, keeping only the first numChannels of each pixel.
            pixelArray = new Uint8Array(pixels.width * pixels.height * numChannels);
            for (let i = 0; i < imageData.length; i++) {
                if (i % 4 < numChannels) {
                    const pixelIndex = Math.floor(i / 4);
                    pixelArray[pixelIndex * numChannels + i % 4] = imageData[i];
                }
            }
        }
        const output = this.makeOutputArray(outShape, 'int32');
        this.write(output.dataId, Int32Array.from(pixelArray));
        return output;
    }
    dispose() {
        if (this.disposed) {
            return;
        }
        this.bufferManager.dispose();
        this.disposed = true;
    }
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Registers the WebGPU backend with tf.js at priority 3. The async factory
 * first loads the shaderc WASM module, then requests a GPU adapter and
 * device from the browser.
 */
const createWebGPUBackend = async () => {
    const shaderc = await dist_1();
    // @ts-ignore navigator.gpu is required
    const adapter = await navigator.gpu.requestAdapter({});
    const device = await adapter.requestDevice({});
    return new WebGPUBackend(device, shaderc);
};
tf.registerBackend('webgpu', createWebGPUBackend, 3 /*priority*/);
|
||
|
|
||
|
})));
|
||
|
//# sourceMappingURL=tf-webgpu.js.map
|