human/assets/tf-backend-webgpu.js

2230 lines
5.2 MiB
JavaScript
Raw Normal View History

2020-10-30 15:23:49 +01:00
(function (global, factory) {
typeof exports === 'object' && typeof module !== 'undefined' ? factory(require('fs'), require('path'), require('@tensorflow/tfjs-core')) :
typeof define === 'function' && define.amd ? define(['fs', 'path', '@tensorflow/tfjs-core'], factory) :
(factory(null,null,global.tf));
}(this, (function (fs,path,tf) { 'use strict';
fs = fs && fs.hasOwnProperty('default') ? fs['default'] : fs;
path = path && path.hasOwnProperty('default') ? path['default'] : path;
function createCommonjsModule(fn, module) {
return module = { exports: {} }, fn(module, module.exports), module.exports;
}
var dist = createCommonjsModule(function (module) {
module.exports = {};
module.exports.instantiate = () => {
const wasmData = 'data:application/wasm;base64,\
AGFzbQEAAAABpQRBYAJ/fwBgAX8AYAABf2ABfwF/YAd/f39/f39/AGAFf39/f38Bf2AFf39+f38AYAJ/fwF/YAd/f39/f39/AX9gA39/fwF/YAV/f39/fwBgBn9/f39/fwBgBH9/f38Bf2AEf39/fwBgA39/fwBgAABgBn9/f39/fwF/YAV/f39/fgF/YAV/f39/fAF/YAh/f39/f39/fwF/YAZ/f39/f3wBf2ANf39/f39/f39/f39/fwBgCH9/f39/f39/AGABfAF8YA9/f39/f39/f39/f39/f38Bf2AMf39/f39/f39/f39/AX9gDH9/f39/f39/f39/fwBgCX9/f39/f39/fwBgA39+fwF/YAV/fH9/fwF/YAR/f35/AX9gA399fwF/YAN/fH8Bf2AJf39/f39/f39/AX9gAX8BfWABfwF8YAF/AX5gBX9+fn9/AX5gBH9+fn4BfmAEf39/fgF+YAN/f38BfGAFf39/f38BfGAGf39/f39/AXxgAn9/AX5gAnx/AXxgAnx8AXxgA35/fwF/YAJ+fwF/YAZ/fH9/f38Bf2ADfHx/AXxgAXwBf2ACfH8Bf2ABfQF/YAR/f39/AX5gA39/fgBgAn9+AGACf34Bf2ACf30AYAJ/fABgCn9/f39/f39/f38Bf2ADf39/AX1gC39/f39/f39/f39/AX9gCn9/f39/f39/f38AYA9/f39/f39/f39/f39/f38AYAd/f39/f398AX8C4QkzA2VudgVhYm9ydAABA2VudhlfX19jeGFfYWxsb2NhdGVfZXhjZXB0aW9uAAMDZW52E19fX2N4YV9wdXJlX3ZpcnR1YWwADwNlbnYMX19fY3hhX3Rocm93AA4DZW52GV9fX2N4YV91bmNhdWdodF9leGNlcHRpb24AAgNlbnYHX19fbG9jawABA2VudgtfX19tYXBfZmlsZQAHA2VudgtfX19zZXRFcnJObwABA2Vudg1fX19zeXNjYWxsMTQwAAcDZW52DV9fX3N5c2NhbGwxNDUABwNlbnYNX19fc3lzY2FsbDE0NgAHA2VudgxfX19zeXNjYWxsNTQABwNlbnYLX19fc3lzY2FsbDYABwNlbnYMX19fc3lzY2FsbDkxAAcDZW52CV9fX3VubG9jawABA2VudhZfX2VtYmluZF9yZWdpc3Rlcl9ib29sAAoDZW52F19fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzABUDZW52I19fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzX2NvbnN0cnVjdG9yAAsDZW52IF9fZW1iaW5kX3JlZ2lzdGVyX2NsYXNzX2Z1bmN0aW9uABYDZW52F19fZW1iaW5kX3JlZ2lzdGVyX2VtdmFsAAADZW52Fl9fZW1iaW5kX3JlZ2lzdGVyX2VudW0ADQNlbnYcX19lbWJpbmRfcmVnaXN0ZXJfZW51bV92YWx1ZQAOA2VudhdfX2VtYmluZF9yZWdpc3Rlcl9mbG9hdAAOA2VudhlfX2VtYmluZF9yZWdpc3Rlcl9pbnRlZ2VyAAoDZW52HV9fZW1iaW5kX3JlZ2lzdGVyX21lbW9yeV92aWV3AA4DZW52HF9fZW1iaW5kX3JlZ2lzdGVyX3N0ZF9zdHJpbmcAAANlbnYdX19lbWJpbmRfcmVnaXN0ZXJfc3RkX3dzdHJpbmcADgNlbnYWX19lbWJpbmRfcmVnaXN0ZXJfdm9pZAAAA2Vudg5fX2VtdmFsX2RlY3JlZgABA2Vudg5fX2VtdmFsX2luY3JlZgABA2VudhJfX2VtdmFsX3Rha2VfdmFsdWUABwNlbnYGX2Fib3J0AA8DZW52GV9lbXNjcmlwdGVuX2dldF9oZWFwX3NpemUAAgNlbnYWX2Vtc2NyaXB0ZW5fbWVtY3B5X2JpZwAJA2VudhdfZW1zY3JpcHRlbl9yZXNpemVfaGVhcAADA2VudgdfZ2V0ZW52AAMDZW52Dl9sbHZtX2xvZzJf
ZjY0ABcDZW52El9sbHZtX3N0YWNrcmVzdG9yZQABA2Vudg9fbGx2bV9zdGFja3NhdmUAAgNlbnYKX2xsdm1fdHJhcAAPA2VudhJfcHRocmVhZF9jb25kX3dhaXQABwNlbnYXX3B0aHJlYWRfbXV0ZXhhdHRyX2luaXQAAwNlbnYaX3B0aHJlYWRfbXV0ZXhhdHRyX3NldHR5cGUABwNlbnYLX3N0cmZ0aW1lX2wABQNlbnYXYWJvcnRPbkNhbm5vdEdyb3dNZW1vcnkAAwNlbnYMX190YWJsZV9iYXNlA38AA2Vudg5EWU5BTUlDVE9QX1BUUgN/AAZnbG9iYWwDTmFOA3wABmdsb2JhbAhJbmZpbml0eQN8AANlbnYGbWVtb3J5AgCAAgNlbnYFdGFibGUBcAGIJ4gnA/8h/SEDDwMCAQAPAAADAwEHAAcDAQIDDgMBAgcECAIBAAIIAAEBAwAJBwMMDAABAQUAAQEBBg0DBwcBAQEBAQEBAQMBAwEJCAABDQkADAcJAA0AAAAQAQEBAQABAQABAAMACgMMAw4MAwMDAwMDAwMDAwMDAwMDAwUDAAMDBwMDAAADAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAMABwMABwcHAwMDAwMDAAMBAAEBAQAAAAABAQ8NGAcTAA4ADQAHFhkLAAEBDAAAAAEBAwMDAwwDDQAAAAAAAAAAAAAAAAAOCBMDAwAOBwMDAA4HAwMADQcDAwALBwMDAA0HAwMAAwcDAwEDBwMDAQAAAwMADgMDAAMBAA4AAwADAQMBAwEDAQMDAwMHDgEBAwADAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMBAAMDAwMDAwMDAwMHAwMDAwMDAwMDAwMDAwMDAwcDBwAJAAAOAAkAAAMOAA4BAwMBAwADAwAAAwADDQMHBwAAAA0DDgMDAAMDAwEADg0BAwkJBw0JAAEBDQQLDQ0KCg4KCgcHCQ0ODg0NAw0NAw0NAw0NDQ0NDQ0NDg4ODgUODg0AAA4AAAADAQAACQkJCQkJCQ4AAAMKAwAJAA4ODg4NDQ0ABQ4NBwcHBw0AAA0ADg0NAA0AAAAAAAMAAQwNBwANDAAHAAADBwUABwUJDAUJCQAHAAAAAA4BDgEBAQEAAQEBDQ4HAwEBAwwHAAcJBxoKDgsACw0NDgANAAAQCAMDAAoHAwUDAAoHAwkAAw0ADAkDDAABAA4MBQwCAQAAAgAHAAABBQcIBw0ADQ4OBwAJDAUABwAHBwADAAoAAAAAAQwDBw0ODQoNAAcHBwMACQcDBQwQBQUFDAUMBQUMBQUFBQUFBQUFDgUFBRAFBQAFBQUFAAQOAAEBAwABAQcBAwAHAwAHAwAHAwAHCAEBAwABAQcAAQEDAAEBAwADAAcDAAcDAAcDAAcDAAcBAQMAAQEDAAEBAwADAAcDAAcDAAcDAAMAAQEDAAMABwMAAwABAQMAAwAHAwAHAQEDAAEBAwADAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAcDAAMABwMDBwcHAAAKEAUFAAkAAAcHDgAHBwMKDgMADg4DAAAHDgAMBRAOCQMADgMAAwAHBwcJBwcHBwcDBwcJDQAABwwIAAAAAAsHBwUMDgAKFg4NBwMACQcHBwAHBwcHAwcHCQkFAwAFBwcAAA4FAAADAAAHBwcHCQcJCQcJBwcHCQMACQkHBwAAAAAHAwAHBwMABwoDAAMOCgcJCQcOBwAAAAcDBw4HAQQNDQADAAMAAwAAAAMABwMABwMDAwMJDg4AAwAHAwAHAAMABwcOAQEDAAEBCQkADgoHAAADAA4AAwAHAAcHBw4ABwcDAQcNAwMHAwMAAwAAAAcHBwADAwUKBwAAAA4DBw0HBwcHBwcHBwcHBwcHBwcHEAkHDAcHBwcMDgkJAQEAAAABAAAAAAAAAAAHDgEAAQAHAAEBDgEBAQEBAQEBAQEBAQEBAQUBAQEAAAAAAQEAAAEBAAEBAQEBAQEA
AAEBAQEBAQcHAQcNAQMDAwEHDgAABwAAAwADAAABBwMABwMABwMAAAEBAwMDAwAABQAAAAwAAAMABQMABQEBAwMDAwAABwc
return fetch(wasmData).then(res => res.arrayBuffer()).then(fromBinary);
};
const fromBinary = (wasmBinary) => {
const Module = { wasmBinary };
const promise = new Promise(res => {
Module.onRuntimeInitialized = () => {
res(Module);
};
});
{
var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}Module["arguments"]=[];Module["thisProgram"]="./this.program";Module["quit"]=function(status,toThrow){throw toThrow};Module["preRun"]=[];Module["postRun"]=[];var ENVIRONMENT_IS_WEB=false;var ENVIRONMENT_IS_WORKER=false;var ENVIRONMENT_IS_NODE=false;var ENVIRONMENT_IS_SHELL=false;ENVIRONMENT_IS_WEB=typeof window==="object";ENVIRONMENT_IS_WORKER=typeof importScripts==="function";ENVIRONMENT_IS_NODE=typeof process==="object"&&typeof require==="function"&&!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_WORKER;ENVIRONMENT_IS_SHELL=!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_NODE&&!ENVIRONMENT_IS_WORKER;var scriptDirectory="";function locateFile(path$$1){if(Module["locateFile"]){return Module["locateFile"](path$$1,scriptDirectory)}else{return scriptDirectory+path$$1}}if(ENVIRONMENT_IS_NODE){scriptDirectory=__dirname+"/";var nodeFS;var nodePath;Module["read"]=function shell_read(filename,binary){var ret;if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);ret=nodeFS["readFileSync"](filename);return binary?ret:ret.toString()};Module["readBinary"]=function readBinary(filename){var ret=Module["read"](filename,true);if(!ret.buffer){ret=new Uint8Array(ret);}assert(ret.buffer);return ret};if(process["argv"].length>1){Module["thisProgram"]=process["argv"][1].replace(/\\/g,"/");}Module["arguments"]=process["argv"].slice(2);{module["exports"]=Module;}process["on"]("uncaughtException",function(ex){if(!(ex instanceof ExitStatus)){throw ex}});process["on"]("unhandledRejection",abort);Module["quit"]=function(status){process["exit"](status);};Module["inspect"]=function(){return"[Emscripten Module object]"};}else if(ENVIRONMENT_IS_SHELL){if(typeof read!="undefined"){Module["read"]=function shell_read(f){return read(f)};}Module["readBinary"]=function readBinary(f){var data;if(typeof readbuffer==="function"){return new 
Uint8Array(readbuffer(f))}data=read(f,"binary");assert(typeof data==="object");return data};if(typeof scriptArgs!="undefined"){Module["arguments"]=scriptArgs;}else if(typeof arguments!="undefined"){Module["arguments"]=arguments;}if(typeof quit==="function"){Module["quit"]=function(status){quit(status);};}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href;}else if(document.currentScript){scriptDirectory=document.currentScript.src;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}Module["read"]=function shell_read(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.send(null);return xhr.responseText};if(ENVIRONMENT_IS_WORKER){Module["readBinary"]=function readBinary(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.responseType="arraybuffer";xhr.send(null);return new Uint8Array(xhr.response)};}Module["readAsync"]=function readAsync(url,onload,onerror){var xhr=new XMLHttpRequest;xhr.open("GET",url,true);xhr.responseType="arraybuffer";xhr.onload=function xhr_onload(){if(xhr.status==200||xhr.status==0&&xhr.response){onload(xhr.response);return}onerror();};xhr.onerror=onerror;xhr.send(null);};Module["setWindowTitle"]=function(title){document.title=title;};}var out=Module["print"]||(typeof console!=="undefined"?console.log.bind(console):typeof print!=="undefined"?print:null);var err=Module["printErr"]||(typeof printErr!=="undefined"?printErr:typeof console!=="undefined"&&console.warn.bind(console)||out);for(key in moduleOverrides){if(moduleOverrides.hasOwnProperty(key)){Module[key]=moduleOverrides[key];}}moduleOverrides=undefined;var asm2wasmImports={"f64-rem":function(x,y){return x%y},"debugger":function(){debugger}};var tempRet0=0;var setTempRet0=function(value){tempRet0=value;};var getTempRet0=function(){return tempRet0};if(typeof WebAssembly!=="object"){err("no native wasm support 
detected");}var wasmMemory;var wasmTable;var ABORT=false;function assert(condition,text){if(!condition){abort("Assertion failed: "+text);}}var UT
}
return promise;
};
});
var dist_1 = dist.instantiate;
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Whether we submit commands to the device queue immediately.
* The registered evaluator returns true.
*/
tf.ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
/**
* Thread register block size for matmul kernel. If 0, we use the version of
* matMul without register blocking. The registered evaluator returns 4.
*/
tf.ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);
/**
* Selects the conv2d implementation and its work per thread (WPT):
* -1: conv2d_naive
* 0: conv2d_mm with matmul without register blocking
* >0: conv2d_mm with matmul_packed with WPT=this
* The registered evaluator returns 2.
*/
tf.ENV.registerFlag('WEBGPU_CONV2D_WORK_PER_THREAD', () => 2);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Pools device buffers so that a released buffer can be handed out again to a
 * later request with the same byte size and usage.
 *
 * Buffers are grouped under the key returned by
 * `getBufferKey(byteSize, usage)`. `freeBuffers` holds buffers available for
 * reuse, `usedBuffers` holds buffers currently handed out.
 */
class BufferManager {
  /**
   * @param device Object exposing `createBuffer({ size, usage })`. The
   *     buffers it returns must expose `destroy()`.
   */
  constructor(device) {
    this.device = device;
    this.numUsedBuffers = 0;
    this.numFreeBuffers = 0;
    this.freeBuffers = new Map();
    this.usedBuffers = new Map();
    this.numBytesUsed = 0;
  }

  /**
   * Returns a buffer of the requested size and usage, reusing a free one when
   * available and creating a new one on the device otherwise.
   *
   * @param byteSize Size of the buffer in bytes.
   * @param usage Usage value forwarded to `device.createBuffer`.
   * @returns The acquired buffer.
   */
  acquireBuffer(byteSize, usage) {
    const key = getBufferKey(byteSize, usage);
    if (!this.freeBuffers.has(key)) {
      this.freeBuffers.set(key, []);
    }
    if (!this.usedBuffers.has(key)) {
      this.usedBuffers.set(key, []);
    }
    const freeList = this.freeBuffers.get(key);
    const reuse = freeList.length > 0;
    // Obtain the buffer before touching any counter so that a failing
    // createBuffer() call leaves the bookkeeping untouched.
    const buffer = reuse ?
      freeList.shift() :
      this.device.createBuffer({ size: byteSize, usage });
    if (reuse) {
      this.numFreeBuffers--;
    }
    this.usedBuffers.get(key).push(buffer);
    this.numBytesUsed += byteSize;
    this.numUsedBuffers++;
    return buffer;
  }

  /**
   * Moves a buffer previously returned by `acquireBuffer` back to the free
   * pool. Does nothing once the manager has been disposed.
   *
   * @param buffer The buffer to release.
   * @param byteSize Size the buffer was acquired with.
   * @param usage Usage the buffer was acquired with.
   * @throws {Error} If the buffer is not currently handed out under the given
   *     size/usage. The manager state is left unchanged in that case.
   */
  releaseBuffer(buffer, byteSize, usage) {
    if (this.freeBuffers == null) {
      return;
    }
    const key = getBufferKey(byteSize, usage);
    // Validate first: an unknown buffer must not end up in the free pool or
    // skew the counters.
    const usedList = this.usedBuffers.get(key);
    const bufferIndex = usedList == null ? -1 : usedList.indexOf(buffer);
    if (bufferIndex < 0) {
      throw new Error('Cannot release a buffer that was never provided by this ' +
        'buffer manager');
    }
    usedList.splice(bufferIndex, 1);
    if (!this.freeBuffers.has(key)) {
      this.freeBuffers.set(key, []);
    }
    this.freeBuffers.get(key).push(buffer);
    this.numFreeBuffers++;
    this.numUsedBuffers--;
    this.numBytesUsed -= byteSize;
  }

  /** @returns Number of buffers currently handed out. */
  getNumUsedBuffers() {
    return this.numUsedBuffers;
  }

  /** @returns Number of buffers waiting in the free pool. */
  getNumFreeBuffers() {
    return this.numFreeBuffers;
  }

  /**
   * Forgets every tracked buffer (without destroying any) and zeroes all
   * counters, including the byte counter.
   */
  reset() {
    this.freeBuffers = new Map();
    this.usedBuffers = new Map();
    this.numUsedBuffers = 0;
    this.numFreeBuffers = 0;
    this.numBytesUsed = 0;
  }

  /**
   * Destroys every tracked buffer, free or in use, and drops both pools.
   * Calling it again afterwards is a no-op.
   */
  dispose() {
    if (this.freeBuffers == null) {
      return;
    }
    // Maps must be walked with their iterators; `for...in` visits none of
    // their entries.
    for (const buffers of this.freeBuffers.values()) {
      buffers.forEach(buff => {
        buff.destroy();
      });
    }
    for (const buffers of this.usedBuffers.values()) {
      buffers.forEach(buff => {
        buff.destroy();
      });
    }
    this.freeBuffers = null;
    this.usedBuffers = null;
    this.numUsedBuffers = 0;
    this.numFreeBuffers = 0;
    this.numBytesUsed = 0;
  }
}
/**
 * Builds the pool key for a buffer: its byte size and usage, each converted
 * to a string and joined by an underscore.
 */
function getBufferKey(byteSize, usage) {
  return ''.concat(byteSize, '_', usage);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Generates GLSL expressions that compute row-major strides.
 *
 * Each entry of `indicesArr` is a dimension index into the GLSL shape
 * variable named `variableName`. The result has one expression per dimension
 * except the last, e.g. for `[0, 1, 2]` and `'outShape'`:
 * `['(outShape[2] * outShape[1])', 'outShape[2]']`.
 *
 * @param indicesArr Dimension indices, each at most 3.
 * @param variableName Name of the GLSL shape variable to index.
 * @returns Array of GLSL stride expressions; empty when fewer than two
 *     indices are given, since there is then no stride to compute.
 * @throws {Error} If any index is greater than 3.
 */
function symbolicallyComputeStrides(indicesArr, variableName) {
  if (Math.max(...indicesArr) > 3) {
    throw new Error('Cannot symbolically compute strides for rank > 4 tensor.');
  }
  const numCoords = indicesArr.length;
  // Guard the degenerate cases: `new Array(-1)` throws and writing index -1
  // would attach a stray property to an otherwise empty array.
  if (numCoords < 2) {
    return [];
  }
  const shape = indicesArr.map(d => `${variableName}[${d}]`);
  const strides = new Array(numCoords - 1);
  // The innermost stride is the size of the last dimension; every outer
  // stride multiplies the next-inner stride by the next-inner dimension.
  strides[numCoords - 2] = shape[numCoords - 1];
  for (let i = numCoords - 3; i >= 0; --i) {
    strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
  }
  return strides;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Maps a tensor rank to the GLSL type used to hold its coordinates:
 * `int` for rank <= 1 and `ivec2`/`ivec3`/`ivec4` for ranks 2 to 4.
 *
 * @param rank Number of dimensions.
 * @returns The GLSL type name.
 * @throws {Error} For any other rank.
 */
function getCoordsDataType(rank) {
  if (rank <= 1) {
    return 'int';
  }
  switch (rank) {
    case 2:
      return 'ivec2';
    case 3:
      return 'ivec3';
    case 4:
      return 'ivec4';
    default:
      throw Error(`GPU for rank ${rank} is not yet supported`);
  }
}
/**
 * Translates a dtype name to its GLSL scalar type: `float32` becomes `float`
 * and `int32` becomes `int`. Any other value is returned unchanged.
 */
function mapToGlslTypes(type) {
  switch (type) {
    case 'float32':
      return 'float';
    case 'int32':
      return 'int';
    default:
      return type;
  }
}
/**
 * Assembles the complete GLSL compute shader source for a program.
 *
 * The source is built, in order, from: the shared prefix, the layout and
 * buffer/uniform declarations, the flat-index helpers, the generated
 * `getOutputCoords` / `getCoordsFromFlatIndex` / `setOutput` functions, the
 * per-input samplers (only when the dispatch layout rank equals the output
 * rank) and finally the program's own code.
 *
 * @param inputInfo Per-input descriptors (`name`, `dtype`, `shape`), in the
 *     same order as `program.variableNames`.
 * @param outputData Output descriptor (`dtype`, `shape`).
 * @param program Program providing `variableNames`, `dispatchLayout`,
 *     `userCode` and optionally `workGroupSize` and `uniforms`.
 * @returns The shader source string.
 */
function makeShader(inputInfo, outputData, program) {
  const { variableNames, workGroupSize } = program;
  const declarations = [];
  if (workGroupSize != null) {
    declarations.push(`
layout (local_size_x = ${workGroupSize[0]},
local_size_y = ${workGroupSize[1]},
local_size_z = ${workGroupSize[2]}) in;
`);
  }
  // The output buffer always occupies binding 0.
  const outputType = mapToGlslTypes(outputData.dtype);
  declarations.push(`
layout(std430, set = 0, binding = 0) writeonly buffer ssbOut {
${outputType} result[];
};
`);
  // Inputs take bindings 1..N; each one also contributes a `<name>Shape`
  // field (first letter lower-cased) to the uniform block.
  let uniformFields = '';
  for (const [i, name] of variableNames.entries()) {
    const info = inputInfo[i];
    const shapeType = getCoordsDataType(info.shape.length);
    const shapeField = name.charAt(0).toLowerCase() + name.slice(1);
    uniformFields += `${shapeType} ${shapeField}Shape; `;
    declarations.push(`
layout(std430, set = 0, binding = ${1 + i}) readonly buffer ssb${name} {
${mapToGlslTypes(info.dtype)} ${name}[];
};
`);
  }
  uniformFields += `${getCoordsDataType(outputData.shape.length)} outShape; `;
  if (program.uniforms) {
    uniformFields += program.uniforms;
  }
  // The uniform block takes the binding right after the last input.
  declarations.push(`
layout(std140, set = 0, binding = ${1 + variableNames.length}) uniform Uniforms {
${uniformFields}
};
`);
  const [outputCoordsFn, layoutRank] = generateGetOutputCoords(program.dispatchLayout);
  const flatIndexToCoordsFn = generateGetCoordsFromFlatIndex(outputData.shape);
  const parts = [
    SHADER_PREFIX,
    declarations.join('\n'),
    SAMPLING_SNIPPETS,
    outputCoordsFn,
    flatIndexToCoordsFn,
    getSetOutputSnippet(outputData.shape.length, outputData.dtype),
  ];
  if (layoutRank === outputData.shape.length) {
    // Input sampling snippet is only meaningful when the output isn't getting
    // implicitly reshaped (like it does in conv2d_matmul).
    const samplers = inputInfo.map(info => getInputSamplingSnippet(info, outputData.shape));
    parts.push(samplers.join('\n'));
  }
  parts.push(program.userCode);
  return parts.join('\n');
}
// GLSL preamble that makeShader() places first in every generated shader.
// It declares the GLSL version and two helpers:
// - idiv(a, b, sign): integer division of a by b; when `sign` is negative
//   and the division has a remainder, the quotient is decremented by one.
// - coordIsValid(coord, shape): true when every component of `coord` is
//   >= 0 and < the matching component of `shape`.
const SHADER_PREFIX = `#version 450
int idiv(int a, int b, float sign) {
int res = a / b;
int mod = a % b;
if (sign < 0. && mod != 0) {
res -= 1;
}
return res;
}
bool coordIsValid(ivec4 coord, ivec4 shape) {
return all(greaterThanEqual(coord, ivec4(0))) &&
all(lessThan(coord, shape));
}
`;
// GLSL getFlatIndex() overloads that makeShader() includes in every shader.
// Each converts coordinates plus a shape into a flat buffer index: the
// scalar overload returns the coordinate itself, and the ivec2/ivec3/ivec4
// overloads take the dot product of the coordinates with strides built from
// the trailing shape components (the last component has stride 1).
const SAMPLING_SNIPPETS = `
uint getFlatIndex(uint coord, uint shape) {
return coord;
}
uint getFlatIndex(ivec2 coords, ivec2 shape) {
return uint(dot(coords, ivec2(shape.y, 1.)));
}
uint getFlatIndex(ivec3 coords, ivec3 shape) {
return uint(dot(coords, ivec3(shape.y * shape.z, shape.z, 1.)));
}
uint getFlatIndex(ivec4 coords, ivec4 shape) {
return uint(dot(coords, ivec4(
shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1.)));
}
`;
/**
 * Generates the GLSL `setOutput` overloads that write into `result`.
 *
 * Two overloads taking a flat index (one for a float value, one for an int
 * value) are always emitted; the value is cast when its type differs from
 * the output buffer type. For rank >= 2, two more overloads taking one int
 * per dimension are appended; they flatten the coordinates with
 * `getFlatIndex` against `outShape` and forward to the flat overloads.
 *
 * @param outRank Rank of the output tensor.
 * @param outBufferType dtype of the output buffer.
 * @returns The GLSL source.
 */
function getSetOutputSnippet(outRank, outBufferType) {
  const glslType = mapToGlslTypes(outBufferType);
  const storedFloat = glslType === 'int' ? 'int(value)' : 'value';
  const storedInt = glslType === 'float' ? 'float(value)' : 'value';
  const flatOverloads = `void setOutput(uint flatIndex, float value) {
result[flatIndex] = ${storedFloat};
}
void setOutput(uint flatIndex, int value) {
result[flatIndex] = ${storedInt};
}`;
  if (!(outRank >= 2)) {
    return flatOverloads;
  }
  const dimNames = ['d0', 'd1', 'd2', 'd3'].slice(0, outRank);
  const coordsType = getCoordsDataType(outRank);
  const params = dimNames.map(d => `int ${d}`).join(', ');
  const args = dimNames.join(', ');
  return flatOverloads + `
void setOutput(${params}, float value) {
uint flatIndex = getFlatIndex(${coordsType}(${args}), outShape);
setOutput(flatIndex, value);
}
void setOutput(${params}, int value) {
uint flatIndex = getFlatIndex(${coordsType}(${args}), outShape);
setOutput(flatIndex, value);
}
`;
}
/**
 * Generates the GLSL sampler functions for one input.
 *
 * The coordinate-based getter is always produced. The `...AtOutCoords`
 * getter is appended only when the input's rank does not exceed the output's
 * rank.
 *
 * @param inInfo Input descriptor (`name`, `shape`).
 * @param outShape Shape of the output tensor.
 * @returns The GLSL source for the input's samplers.
 */
function getInputSamplingSnippet(inInfo, outShape) {
  const sampler = getSamplerFromInInfo(inInfo);
  const fitsInOutput = inInfo.shape.length <= outShape.length;
  return fitsInOutput ?
    sampler + getSamplerAtOutputCoords(inInfo, outShape) :
    sampler;
}
/**
 * Generates the GLSL getter `get<Name>(...)` that reads one value of an
 * input buffer from explicit coordinates.
 *
 * For a rank-0 input the getter takes no arguments and reads element 0.
 * Otherwise it takes one int per dimension and indexes the buffer through
 * `getFlatIndex` using the input's `<name>Shape` uniform.
 *
 * @param inInfo Input descriptor (`name`, `shape`).
 * @returns The GLSL source of the getter.
 */
function getSamplerFromInInfo(inInfo) {
  const { name, shape } = inInfo;
  const rank = shape.length;
  const coordsType = getCoordsDataType(rank);
  const getter = `get${name.charAt(0).toUpperCase()}${name.slice(1)}`;
  if (rank < 1) {
    return `
float ${getter}() {
return ${name}[0];
}
`;
  }
  const dimNames = ['d0', 'd1', 'd2', 'd3'].slice(0, rank);
  const params = dimNames.map(d => `int ${d}`).join(', ');
  const shapeUniform = `${name.charAt(0).toLowerCase()}${name.slice(1)}Shape`;
  return `
float ${getter}(${params}) {
return ${name}[getFlatIndex(${coordsType}(${dimNames.join(',')}),
${shapeUniform})];
}
`;
}
/**
 * Generates the GLSL getters `get<Name>AtOutCoords()` and
 * `get<Name>AtOutCoords(coords)` that read the input value matching a set of
 * output coordinates.
 *
 * Coordinates along broadcast dimensions (as reported by
 * `tf.backend_util.getBroadcastDims`) are zeroed, and when the output rank
 * is above 1 the trailing output coordinates are repacked into a vector of
 * the input's rank before indexing.
 *
 * @param inInfo Input descriptor (`name`, `shape`).
 * @param outShape Shape of the output tensor.
 * @returns The GLSL source of both getters.
 */
function getSamplerAtOutputCoords(inInfo, outShape) {
  const name = inInfo.name;
  const capitalized = name.charAt(0).toUpperCase() + name.slice(1);
  const getter = 'get' + capitalized + 'AtOutCoords';
  const inRank = inInfo.shape.length;
  const outRank = outShape.length;
  const outCoordsType = getCoordsDataType(outRank);
  const broadcastDims = tf.backend_util.getBroadcastDims(inInfo.shape, outShape);
  const rankDiff = outRank - inRank;
  if (inRank === 0) {
    // A scalar input has a single value regardless of the coordinates.
    return `
float ${getter}() {
return get${capitalized}();
}
float ${getter}(${outCoordsType} coords) {
return get${capitalized}();
}
`;
  }
  // Zero the coordinates along broadcast dimensions. With a scalar `coords`
  // (output rank below 2) the whole coordinate is reset instead.
  const zeroBroadcast = outRank < 2 && broadcastDims.length >= 1 ?
    'coords = 0;' :
    broadcastDims.map(d => `coords[${d + rankDiff}] = 0;`).join('\n');
  // Select the trailing `inRank` output coordinates as the input coordinates.
  let inputCoords = 'coords';
  if (outRank > 1) {
    const inCoordsType = getCoordsDataType(inRank);
    const components = inInfo.shape.map((_, i) => `coords[${i + rankDiff}]`).join(', ');
    inputCoords = `${inCoordsType}(${components})`;
  }
  const shapeUniform = `${name.charAt(0).toLowerCase() + name.slice(1)}Shape`;
  const readValue = `return ${name}[getFlatIndex(${inputCoords}, ${shapeUniform})];`;
  return `
float ${getter}() {
${outCoordsType} coords = getOutputCoords();
${zeroBroadcast}
${readValue}
}
float ${getter}(${outCoordsType} coords) {
${zeroBroadcast}
${readValue}
}
`;
}
/**
 * Generates the GLSL `getOutputCoords()` function, which derives the output
 * coordinates of the current invocation from `gl_GlobalInvocationID`.
 *
 * Each of the x/y/z dispatch dimensions covers a list of output dimensions.
 * A single output dimension maps straight to the invocation id; several are
 * unpacked from the id using strides computed symbolically from `outShape`.
 *
 * @param dispatchLayout Object with `x` and optionally `y` and `z`, each a
 *     list of output dimension indices.
 * @returns A pair `[glslSource, rank]`, where `rank` is the total number of
 *     output dimensions covered by the layout.
 */
function generateGetOutputCoords(dispatchLayout) {
  const { x, y = [], z = [] } = dispatchLayout;
  const statements = [];
  let rank = 0;
  [x, y, z].forEach((axes, i) => {
    if (axes.length === 0) {
      return;
    }
    rank += axes.length;
    if (axes.length === 1) {
      statements.push(`uint d${axes[0]} = gl_GlobalInvocationID[${i}];`);
      return;
    }
    const strides = symbolicallyComputeStrides(axes, 'outShape');
    statements.push(`uint index${i} =
gl_GlobalInvocationID[${i}];`);
    for (const [j, stride] of strides.entries()) {
      statements.push(`uint d${axes[j]} = index${i} / ${stride};`);
      const isLastStride = j === strides.length - 1;
      // The remainder after the last division is the innermost coordinate;
      // before that, the running index is reduced for the next division.
      statements.push(isLastStride ?
        `uint d${axes[j + 1]} = ` + `index${i} - d${axes[j]} * ${stride};` :
        `index${i} -= d${axes[j]} * ${stride};`);
    }
  });
  const dimensions = Array.from({ length: rank }, (_, i) => `d${i}`);
  const dtype = getCoordsDataType(rank);
  const snippet = `${dtype} getOutputCoords() {
${statements.join('')}
return ${dtype}(${dimensions.join(',')});
}`;
  return [snippet, rank];
}
/**
 * Generates the GLSL `getCoordsFromFlatIndex` function, which turns a flat
 * index into logical coordinates for the given shape.
 *
 * For rank <= 1 the index is returned as is. Otherwise the index is divided
 * by each stride in turn and reduced by the consumed amount, and the final
 * remainder becomes the last coordinate.
 *
 * @param shape Shape of the tensor being indexed.
 * @returns The GLSL source of the function.
 */
function generateGetCoordsFromFlatIndex(shape) {
  const rank = shape.length;
  if (rank <= 1) {
    return `int getCoordsFromFlatIndex(int index) {return index; }`;
  }
  const strides = tf.util.computeStrides(shape);
  const dtype = getCoordsDataType(rank);
  const coords = Array.from({ length: rank }, (_, i) => `d${i}`);
  const lastStride = strides.length - 1;
  const statements = strides.map((stride, i) => {
    const divide = `uint ${coords[i]} = index / ${stride}`;
    const remainder = i === lastStride ?
      `uint ${coords[i + 1]} = index - ${coords[i]} * ${stride}` :
      `index -= ${coords[i]} * ${stride}`;
    return `${divide}; ${remainder};`;
  });
  return `
${dtype} getCoordsFromFlatIndex(uint index) {
${statements.join('')}
return ${dtype}(${coords.join(',')});
}
`;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Multiplies all elements of `arr` together; returns 1 for an empty array. */
const arrayProduct = (arr) => {
  let result = 1;
  for (const factor of arr) {
    result *= factor;
  }
  return result;
};
/**
 * Computes the dispatch geometry (number of work groups along x, y and z).
 *
 * For each dispatch dimension, the sizes of the output dimensions assigned
 * to it are multiplied together, and the product is divided (rounding up) by
 * the number of elements one work group covers along that dimension.
 *
 * @param layout Object with `x` and optionally `y` and `z`, each a list of
 *     output dimension indices. A missing `y` or `z` yields 1.
 * @param outputShape Shape of the output tensor.
 * @param workGroupSize Work group size per dispatch dimension.
 * @param elementsPerThread Elements handled per thread per dispatch dimension.
 * @returns `[x, y, z]` work group counts.
 */
function computeDispatch(layout, outputShape, workGroupSize = [1, 1, 1], elementsPerThread = [1, 1, 1]) {
  const groupsAlong = (axes, i) => {
    const elements = arrayProduct(axes.map(d => outputShape[d]));
    return Math.ceil(elements / (workGroupSize[i] * elementsPerThread[i]));
  };
  const groupsX = groupsAlong(layout.x, 0);
  const groupsY = layout.y ? groupsAlong(layout.y, 1) : 1;
  const groupsZ = layout.z ? groupsAlong(layout.z, 2) : 1;
  return [groupsX, groupsY, groupsZ];
}
/**
 * Builds a dispatch layout that assigns every dimension of `shape` to the x
 * dispatch dimension.
 *
 * @param shape Shape of the output tensor.
 * @returns `{ x: [0, 1, ..., rank - 1] }`.
 */
function flatDispatchLayout(shape) {
  const allAxes = shape.map((_dim, axis) => axis);
  return { x: allAxes };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Program description for argMin / argMax along a single axis.
*
* The constructor fills in the fields a shader program exposes in this file
* (variableNames, uniforms, outputShape, workGroupSize, dispatchLayout,
* dispatch, userCode). The generated GLSL scans the values along the reduced
* axis, keeps the index of the best value (smallest for 'min', largest
* otherwise) and writes that index with setOutput.
*/
class ArgMinMaxProgram {
/**
* @param inputShape Shape of the input tensor.
* @param axis Axis to reduce; asserted to be an innermost dimension through
*     tf.backend_util.assertAxesAreInnerMostDims.
* @param reduceType 'min' selects the `<` comparison; any other value
*     selects `>`.
*/
constructor(inputShape, axis, reduceType) {
this.variableNames = ['x'];
// Extra uniform declared in the shader next to the shape uniforms.
this.uniforms = 'uint axis;';
const axes = [axis];
tf.backend_util.assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, inputShape.length);
const op = reduceType === 'min' ? '<' : '>';
// |outShape| is the shape with the removed axis
// |reduceShape| is the shape we are reducing. i.e. [ inputShape[axis] ]
const [outputShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(inputShape, axes);
// A scalar result is represented as a one-element output.
this.outputShape = outputShape.length === 0 ? [1] : outputShape;
// Length of the axis we're reducing on.
const reduceSize = tf.util.sizeFromShape(reduceShape);
// The number of comparisons each thread will do
const reductionFactor = 2;
const xMaxThreads = 1024; // gl_MaxComputeWorkGroupSize
const xThreads = Math.min(Math.ceil(reduceSize / reductionFactor), xMaxThreads);
this.workGroupSize = [xThreads, 1, 1];
// No output dimension is mapped to x; all of them are mapped to y.
this.dispatchLayout = { x: [], y: this.outputShape.map((d, i) => i) };
this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
// When xThreads > 1, each thread reduces Length / xThreads values.
// These results are stored in shared memory and iteratively reduced.
const reduceInSharedMemory = xThreads > 1;
// GLSL declarations of the shared arrays, emitted only when
// reduceInSharedMemory is true.
const sharedMemorySnippet = `
shared uint xBestIndices[WorkGroupSize];
shared float xBestValues[WorkGroupSize];
`;
// GLSL for the shared-memory reduction: each pass has every thread compare
// up to reductionFactor candidates, then shrinks currentSize by that
// factor (rounding up) until one entry is left; local thread 0 writes it.
const sharedMemoryReduceSnippet = `
xBestIndices[gl_LocalInvocationID.x] = bestIndex;
xBestValues[gl_LocalInvocationID.x] = bestValue;
uint currentSize = WorkGroupSize;
while (currentSize > 1) {
barrier();
for (uint w = 0; w < ${reductionFactor}; ++w) {
uint i = gl_LocalInvocationID.x * ${reductionFactor} + w;
if (i < currentSize) {
uint candidateIndex = xBestIndices[i];
float candidate = xBestValues[i];
if (candidate ${op} bestValue && !isnan(candidate)) {
bestValue = candidate;
bestIndex = candidateIndex;
}
}
}
xBestIndices[gl_LocalInvocationID.x] = bestIndex;
xBestValues[gl_LocalInvocationID.x] = bestValue;
currentSize = DIV_CEIL(currentSize, ${reductionFactor});
}
if (gl_LocalInvocationID.x == 0) {
setOutput(flatOutputIndex, int(bestIndex));
}
`;
const outputCoordsType = getCoordsDataType(this.outputShape.length);
// Builds the GLSL expression reading one output coordinate: the variable
// itself for a rank-1 output (a scalar in GLSL), an indexed access otherwise.
const indexOutputCoords = (outputCoords, index) => {
if (this.outputShape.length === 1) {
return outputCoords;
}
else {
return `${outputCoords}[${index}]`;
}
};
// Same idea for the input shape uniform `xShape`.
const indexInputShape = (index) => {
if (inputShape.length === 1) {
return 'xShape';
}
else {
return `xShape[${index}]`;
}
};
this.userCode = `
#define DIV_CEIL(x, y) (((x) - 1) / (y) + 1)
const uint WorkGroupSize = gl_WorkGroupSize.x;
${reduceInSharedMemory ? sharedMemorySnippet : ''}
// In order to get a flattened index into the input tensor, we need to
// add back the index along the reduced dimension to |outputCoords|.
// This function outputs the offset to the first value along
// |axis| and the stride to get the next value of the input along |axis|.
uvec2 getInputCoordInfo() {
const ${outputCoordsType} outputCoords = getOutputCoords();
uint i = ${this.outputShape.length - 1};
uint stride = 1;
uint inputStride = 1;
uint offset = 0;
for (uint r = 1; r <= ${inputShape.length}; ++r) {
uint length = ${indexInputShape(`${inputShape.length} - r`)};
if (${inputShape.length} - r == axis) {
inputStride = stride;
} else {
offset += ${indexOutputCoords('outputCoords', 'i--')} * stride;
}
stride *= length;
}
return uvec2(offset, inputStride);
}
uint getInputIndex(uvec2 coordInfo, uint index) {
return coordInfo[0] + coordInfo[1] * index;
}
void main() {
const uvec2 coordInfo = getInputCoordInfo();
uint bestIndex = 0;
float bestValue = x[getInputIndex(coordInfo, bestIndex)];
const uint Length = ${indexInputShape('axis')};
const uint WorkPerThread = DIV_CEIL(Length, WorkGroupSize);
for (uint w = 0; w < WorkPerThread; ++w) {
uint i = gl_GlobalInvocationID.x * WorkPerThread + w;
if (i < Length) {
float candidate = x[getInputIndex(coordInfo, i)];
if (candidate ${op} bestValue && !isnan(candidate)) {
bestValue = candidate;
bestIndex = i;
}
}
}
const uint flatOutputIndex = gl_GlobalInvocationID.y;
${reduceInSharedMemory ? sharedMemoryReduceSnippet :
'setOutput(flatOutputIndex, int(bestIndex));'}
}
`;
}
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL statement snippets operating on two floats named `a` and `b`.
// NOTE(review): presumably passed as `op` to BinaryOpProgram, which places
// the snippet inside `float binaryOperation(float a, float b)` - confirm
// against the call sites.
const MUL = 'return a * b;';
const ADD = 'return a + b;';
const SUB = 'return a - b;';
// Integer division: rounds both operands to ints and delegates to idiv(),
// passing the product of the operands' signs as its `sign` argument.
const INT_DIV = `
float s = sign(a) * sign(b);
int ia = int(round(a));
int ib = int(round(b));
return float(idiv(ia, ib, s));
`;
/**
 * Program description for an element-wise binary operation on two inputs
 * `A` and `B` with broadcasting.
 *
 * Every invocation handles `workPerThread` consecutive flat output indices.
 * For each index inside the output it converts the index to coordinates,
 * samples both inputs at those coordinates and writes the result of the
 * operation.
 */
class BinaryOpProgram {
  /**
   * @param op GLSL statements forming the body of
   *     `float binaryOperation(float a, float b)`.
   * @param aShape Shape of the first input.
   * @param bShape Shape of the second input.
   */
  constructor(op, aShape, bShape) {
    this.variableNames = ['A', 'B'];
    this.workPerThread = 4;
    this.workGroupSize = [1, 1, 1];
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const { outputShape, workPerThread, workGroupSize } = this;
    const elementCount = tf.util.sizeFromShape(outputShape);
    const layout = flatDispatchLayout(outputShape);
    this.dispatchLayout = layout;
    this.dispatch = computeDispatch(layout, outputShape, workGroupSize, [workPerThread, 1, 1]);
    const coordsType = getCoordsDataType(outputShape.length);
    this.userCode = `
float binaryOperation(float a, float b) {
${op}
}
void main() {
int index = int(gl_GlobalInvocationID.x);
for(int i = 0; i < ${workPerThread}; i++) {
int flatIndex = index * ${workPerThread} + i;
if(flatIndex < ${elementCount}) {
${coordsType} coords = getCoordsFromFlatIndex(flatIndex);
float a = getAAtOutCoords(coords);
float b = getBAtOutCoords(coords);
setOutput(flatIndex, binaryOperation(a, b));
}
}
}
`;
  }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description that concatenates 2D inputs along axis 1 (columns).
 * Input `i` is exposed to the shader as `T<i>`; the shader picks the source
 * input for each output column with an if / else-if chain over the cumulative
 * column offsets.
 */
class ConcatProgram {
    /**
     * @param {number[][]} shapes Shapes of the inputs, in concatenation order.
     *     Only `shape[1]` (the column count) of each entry is read here.
     */
    constructor(shapes) {
        this.outputShape =
            tf.backend_util.computeOutShape(shapes, 1 /* axis */);
        this.variableNames = shapes.map((_, i) => `T${i}`);
        this.dispatchLayout = { x: [0], y: [1] };
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
        // offsets[i] is the exclusive end column, within the output, of input i.
        // The last input needs no entry because it owns every remaining column.
        const lastIndex = shapes.length - 1;
        const offsets = [];
        let endColumn = 0;
        for (let i = 0; i < lastIndex; i++) {
            endColumn += shapes[i][1];
            offsets.push(endColumn);
        }
        const snippets = offsets.map((end, i) => {
            const keyword = i === 0 ? 'if' : 'else if';
            const column = i === 0 ? 'yC' : `yC-${offsets[i - 1]}`;
            return `${keyword} (yC < ${end}) ` +
                `setOutput(coords.x, coords.y, getT${i}(yR, ${column}));`;
        });
        if (lastIndex > 0) {
            snippets.push(`else setOutput(coords.x, coords.y, getT${lastIndex}(yR, yC-${offsets[lastIndex - 1]}));`);
        }
        else {
            // Single input: there is no T1 to fall through to, so copy T0
            // unconditionally instead of emitting an if/else chain.
            snippets.push('setOutput(coords.x, coords.y, getT0(yR, yC));');
        }
        this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int yR = coords.x;
int yC = coords.y;
${snippets.join('\n ')}
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL forward declarations that makeMatMulSource and makeMatMulPackedSource
// prepend to their output: the mm_readA / mm_readB / mm_write callbacks (which
// each program using those sources defines in its own userCode) and the
// mm_matMul routine itself.
const matMulHeader = `
float mm_readA(uint row, uint col);
float mm_readB(uint row, uint col);
void mm_write(uint row, uint col, float value);
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter);`;
/**
 * Builds the GLSL source for `mm_matMul`, a tiled matrix multiply in which
 * every invocation accumulates one output element.
 *
 * The shader stages square tiles of both operands in `shared` arrays whose
 * side is `gl_WorkGroupSize.x`, with a `barrier()` after each tile load and
 * after each tile's accumulation loop. Elements are fetched and stored only
 * through `mm_readA`, `mm_readB` and `mm_write`, which are merely declared
 * here (via matMulHeader) and must be defined by the program embedding this
 * source.
 *
 * @returns {string} GLSL source text.
 */
function makeMatMulSource() {
return `
${matMulHeader}
const uint MatTileSize = gl_WorkGroupSize.x; // .x == .y
shared float mm_Asub[MatTileSize][MatTileSize];
shared float mm_Bsub[MatTileSize][MatTileSize];
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
uint localRow = gl_LocalInvocationID.y; // 0..MatTileSize
uint localCol = gl_LocalInvocationID.x; // 0..MatTileSize
uint globalRow = gl_GlobalInvocationID.y; // AOuter
uint globalCol = gl_GlobalInvocationID.x; // Inner
float acc = 0.0;
uint numTiles = (dimInner - 1) / MatTileSize + 1;
for (uint t = 0; t < numTiles; t++) {
// Load one tile of A and B into local memory
uint tiledACol = MatTileSize * t + localCol;
uint tiledBRow = MatTileSize * t + localRow;
mm_Asub[localRow][localCol] = mm_readA(globalRow, tiledACol);
mm_Bsub[localRow][localCol] = mm_readB(tiledBRow, globalCol);
// Synchronise to make sure the tile is loaded
barrier();
for (uint k = 0; k < MatTileSize; k++) {
acc += mm_Asub[localRow][k] * mm_Bsub[k][localCol];
}
// Synchronise before loading the next tile
barrier();
}
if (globalCol < dimBOuter && globalRow < dimAOuter) {
mm_write(globalRow, globalCol, acc);
}
}
`;
}
/**
 * Program description for a matrix multiply of inputs `A` and `B` built on
 * the tiled `mm_matMul` routine from makeMatMulSource().
 *
 * The shader reads both operands as flat row-major arrays, takes the matrix
 * dimensions from `aShape[1]`, `aShape[2]` and `bShape[2]`, and returns 0.0
 * for any read outside those bounds.
 */
class MatMulProgram {
    /**
     * @param {number[]} outputShape Shape of the result; stored as-is and
     *     forwarded to computeDispatch.
     */
    constructor(outputShape) {
        // The tiling shader assumes a square work group (.x == .y).
        const squareWorkGroup = [16, 16, 1]; // Must be square.
        const layout = { x: [1], y: [2], z: [0] };
        this.variableNames = ['A', 'B'];
        this.workGroupSize = squareWorkGroup;
        this.outputShape = outputShape;
        this.dispatchLayout = layout;
        this.dispatch = computeDispatch(layout, outputShape, squareWorkGroup);
        const tiledMatMul = makeMatMulSource();
        this.userCode = `
uint dimAOuter = aShape[1];
uint dimInner = aShape[2];
uint dimBOuter = bShape[2];
${tiledMatMul}
float mm_readA(uint row, uint col) {
if (row < dimAOuter && col < dimInner) {
return A[row * dimInner + col];
} else {
return 0.0;
}
}
float mm_readB(uint row, uint col) {
if (row < dimInner && col < dimBOuter) {
return B[row * dimBOuter + col];
} else {
return 0.0;
}
}
void mm_write(uint row, uint col, float value) {
setOutput(row * dimBOuter + col, value);
}
void main() {
mm_matMul(dimAOuter, dimInner, dimBOuter);
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds the GLSL source for a variant of `mm_matMul` in which every
 * invocation accumulates a `workPerThread` x `workPerThread` patch of the
 * output rather than a single element.
 *
 * The shared tiles have side `gl_WorkGroupSize.x * workPerThread`. As with
 * makeMatMulSource, elements are read and written only through `mm_readA`,
 * `mm_readB` and `mm_write`, which the embedding program must define.
 *
 * @param {number} workPerThread Side length of the per-invocation output
 *     patch; interpolated into the shader as the `WorkPerThread` constant.
 * @returns {string} GLSL source text.
 */
function makeMatMulPackedSource(workPerThread) {
return `
${matMulHeader}
const uint WorkGroupSize = gl_WorkGroupSize.x; // .x == .y
const uint WorkPerThread = ${workPerThread};
const uint MatTileSize = WorkGroupSize * WorkPerThread;
shared float mm_Asub[MatTileSize][MatTileSize];
shared float mm_Bsub[MatTileSize][MatTileSize];
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
// These are 0..MatTileSize, in increments of WorkPerThread.
uint tileRow = gl_LocalInvocationID.y * WorkPerThread;
uint tileCol = gl_LocalInvocationID.x * WorkPerThread;
// These are 0..AOuter, in increments of WorkPerThread.
uint globalRow = gl_GlobalInvocationID.y * WorkPerThread;
uint globalCol = gl_GlobalInvocationID.x * WorkPerThread;
uint numTiles = (dimInner - 1) / MatTileSize + 1;
float acc[WorkPerThread][WorkPerThread];
float ACached;
float BCached[WorkPerThread];
// Without this initialization strange values show up in acc.
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
acc[innerRow][innerCol] = 0.0;
}
}
// Loop over shared dimension.
for (uint t = 0; t < numTiles; t++) {
// Load one tile of A and B into local memory.
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
uint inputRow = tileRow + innerRow;
uint inputCol = tileCol + innerCol;
mm_Asub[inputRow][inputCol] = mm_readA(
globalRow + innerRow,
t * MatTileSize + tileCol + innerCol);
mm_Bsub[inputRow][inputCol] = mm_readB(
t * MatTileSize + tileRow + innerRow,
globalCol + innerCol);
}
}
barrier();
// Compute acc values for a single thread.
for (uint k = 0; k < MatTileSize; k++) {
for (uint inner = 0; inner < WorkPerThread; inner++) {
BCached[inner] = mm_Bsub[k][tileCol + inner];
}
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
ACached = mm_Asub[tileRow + innerRow][k];
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
acc[innerRow][innerCol] += ACached * BCached[innerCol];
}
}
}
barrier();
}
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
uint globalFlatIndex =
(globalRow + innerRow) * dimBOuter + (globalCol + innerCol);
if ((globalCol + innerCol) < dimBOuter &&
(globalRow + innerRow) < dimAOuter) {
mm_write(globalRow + innerRow,
globalCol + innerCol,
acc[innerRow][innerCol]);
}
}
}
}
`;
}
/**
 * Program description for a matrix multiply of inputs `A` and `B` built on
 * the packed `mm_matMul` routine from makeMatMulPackedSource(), where each
 * invocation produces a `workPerThread` x `workPerThread` patch of the output.
 */
class MatMulPackedProgram {
    /**
     * @param {number[]} outputShape Shape of the result; stored as-is and
     *     forwarded to computeDispatch.
     * @param {number} workPerThread Patch side length handled per invocation.
     */
    constructor(outputShape, workPerThread) {
        const workGroup = [16, 16, 1];
        const layout = { x: [1], y: [2], z: [0] };
        const perInvocation = [workPerThread, workPerThread, 1];
        this.variableNames = ['A', 'B'];
        this.workGroupSize = workGroup;
        this.outputShape = outputShape;
        this.workPerThread = workPerThread;
        this.dispatchLayout = layout;
        this.dispatch = computeDispatch(layout, outputShape, workGroup, perInvocation);
        // Consider compiling a different version of the shader that doesn't care
        // about boundary conditions when loading from Asub / Bsub when tiles fit
        // neatly inside of output. May slightly improve performance.
        const packedMatMul = makeMatMulPackedSource(workPerThread);
        this.userCode = `
uint dimAOuter = aShape[1];
uint dimInner = aShape[2];
uint dimBOuter = bShape[2];
${packedMatMul}
float mm_readA(uint row, uint col) {
if (row < dimAOuter && col < dimInner) {
return A[row * dimInner + col];
} else {
return 0.0;
}
}
float mm_readB(uint row, uint col) {
if (row < dimInner && col < dimBOuter) {
return B[row * dimBOuter + col];
} else {
return 0.0;
}
}
void mm_write(uint row, uint col, float value) {
setOutput(row * dimBOuter + col, value);
}
void main() {
mm_matMul(dimAOuter, dimInner, dimBOuter);
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for a 2D convolution expressed as a matrix multiply.
 *
 * The filter `W` is read as matrix A and the input `x` as matrix B through
 * the `mm_readA` / `mm_readB` callbacks; out-of-range coordinates read as 0.
 * Only the 'channelsLast' data format with dilation 1 is accepted.
 */
class Conv2DMMProgram {
    /**
     * @param {object} convInfo Convolution description; `outShape`,
     *     `dataFormat`, `dilationHeight` and `dilationWidth` are read here.
     * @param {number} workPerThread 0 selects the one-element-per-invocation
     *     source from makeMatMulSource(); any other value selects
     *     makeMatMulPackedSource(workPerThread).
     */
    constructor(convInfo, workPerThread) {
        this.variableNames = ['x', 'W'];
        this.uniforms = 'ivec2 filterDims, pad, stride;';
        this.workGroupSize = [16, 16, 1];
        this.outputShape = convInfo.outShape;
        tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
        tf.util.assert(convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
        const usePacked = !(workPerThread === 0);
        const elementsPerThread = usePacked ?
            [workPerThread, workPerThread, 1] :
            [1, 1, 1];
        const matMulSource = usePacked ?
            makeMatMulPackedSource(workPerThread) :
            makeMatMulSource();
        this.dispatchLayout = { x: [1], y: [2], z: [0] };
        // Dispatch over the matmul view of the output: rows and columns of the
        // output image are collapsed into a single dimension.
        const [outBatch, outRows, outCols, outChannels] = convInfo.outShape;
        const matMulOutShape = [outBatch, outRows * outCols, outChannels];
        this.dispatch = computeDispatch(this.dispatchLayout, matMulOutShape, this.workGroupSize, elementsPerThread);
        this.userCode = `
${matMulSource}
int batch;
float mm_readA(uint row, uint col) {
int r = int(row), c = int(col);
ivec4 coord = ivec4(
(c / filterDims[1]) % filterDims[0],
c % filterDims[1],
c / (filterDims[0] * filterDims[1]),
r);
ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
return coordIsValid(coord, shape) ? W[getFlatIndex(coord, shape)] : 0;
}
float mm_readB(uint row, uint col) {
int r = int(row), c = int(col);
int outRow = c / outShape[2];
int outCol = c % outShape[2];
int WRow = (r / filterDims[1]) % filterDims[0];
int WCol = r % filterDims[1];
ivec4 coord = ivec4(
batch,
pad[0] + outRow * stride[0] + WRow,
pad[1] + outCol * stride[1] + WCol,
r / (filterDims[0] * filterDims[1]));
return coordIsValid(coord, xShape) ?
x[getFlatIndex(coord, xShape)] : 0;
}
void mm_write(uint row, uint col, float value) {
ivec4 outCoord = ivec4(
batch,
col / outShape[2],
col % outShape[2],
row);
if (coordIsValid(outCoord, outShape)) {
result[getFlatIndex(outCoord, outShape)] = value;
}
}
void main() {
batch = int(gl_GlobalInvocationID.z);
int dimAOuter = outShape[3];
int dimBOuter = outShape[1] * outShape[2];
int dimInner = filterDims[0] * filterDims[1] * xShape[3];
mm_matMul(dimAOuter, dimInner, dimBOuter);
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for a direct (non-matmul) 2D convolution: each
 * invocation accumulates one output element by looping over filter rows,
 * filter columns and input channels.
 *
 * Reads outside the input or filter bounds contribute 0. Only the
 * 'channelsLast' data format with dilation 1 is accepted.
 */
class Conv2DNaiveProgram {
    /**
     * @param {object} convInfo Convolution description; `outShape`,
     *     `dataFormat`, `dilationHeight` and `dilationWidth` are read here.
     */
    constructor(convInfo) {
        this.variableNames = ['x', 'W'];
        this.uniforms = 'ivec2 filterDims, pad, stride;';
        this.workGroupSize = [4, 8, 1];
        this.outputShape = convInfo.outShape;
        this.dispatchLayout = { x: [2], y: [1], z: [0, 3] };
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
        tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
        tf.util.assert(convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
        // Shader notes:
        //  - readFilt must build its own `coord` before validating it; the
        //    identifier was previously used without being declared.
        //  - The helpers take `int` coordinates, like the equivalent helper in
        //    DepthwiseConv2DProgram. main() only ever passes ints, and this
        //    avoids depending on uint-to-int conversion when forwarding to
        //    getW / setOutput.
        this.userCode = `
float readInp(int batch, int row, int col, int chan) {
ivec4 coord = ivec4(batch, row, col, chan);
return coordIsValid(coord, xShape) ? getX(coord) : 0;
}
float readFilt(int row, int col, int xChannel, int outChannel) {
ivec4 coord = ivec4(row, col, xChannel, outChannel);
ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
return coordIsValid(coord, shape) ?
getW(row, col, xChannel, outChannel) : 0;
}
void writeResult(int batch, int row, int col, int chan, float value) {
ivec4 coord = ivec4(batch, row, col, chan);
if (coordIsValid(coord, outShape)) {
setOutput(batch, row, col, chan, value);
}
}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int outChannel = coords[3];
float acc = 0.0;
for (int row = 0; row < filterDims[0]; ++row) {
for (int col = 0; col < filterDims[1]; ++col) {
for (int xChannel = 0; xChannel < xShape[3]; ++xChannel) {
float v = readInp(batch,
pad[0] + coords[1] * stride[0] + row,
pad[1] + coords[2] * stride[1] + col, xChannel);
float f = readFilt(row, col, xChannel, outChannel);
acc += v * f;
}
}
}
writeResult(batch, coords[1], coords[2], outChannel, acc);
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for a depthwise 2D convolution. Each invocation
 * computes one output element; the strides, padding, filter size, input
 * extent and channel multiplier are baked into the shader text as literals.
 *
 * Only the 'channelsLast' data format with dilation 1 is accepted.
 */
class DepthwiseConv2DProgram {
    /**
     * @param {object} convInfo Convolution description. Reads `outShape`,
     *     `inHeight`, `inWidth`, `padInfo.top`, `padInfo.left`, the stride,
     *     dilation and filter sizes, `inChannels`, `outChannels` and
     *     `dataFormat`.
     */
    constructor(convInfo) {
        this.variableNames = ['x', 'W'];
        this.uniforms = 'ivec2 filterDims, pad, stride;';
        this.workGroupSize = [4, 8, 1];
        this.outputShape = convInfo.outShape;
        this.dispatchLayout = { x: [2], y: [1], z: [0, 3] };
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
        const {
            inHeight: xNumRows,
            inWidth: xNumCols,
            strideHeight,
            strideWidth,
            dilationHeight,
            dilationWidth,
            filterHeight,
            filterWidth,
        } = convInfo;
        const { top: padTop, left: padLeft } = convInfo.padInfo;
        // Number of output channels produced per input channel.
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        tf.util.assert(convInfo.dataFormat === 'channelsLast', () => 'TODO: NCHW is unimplemented');
        tf.util.assert(dilationHeight === 1 && dilationWidth === 1, () => 'TODO: Dilation is unimplemented');
        this.userCode = `
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void writeResult(int batch, int row, int col, int chan, float value) {
ivec4 coord = ivec4(batch, row, col, chan);
if (coordIsValid(coord, outShape)) {
setOutput(batch, row, col, chan, value);
}
}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords[3];
int d1 = d2 / ${channelMul};
int q = d2 - d1 * ${channelMul};
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
// TODO(xing.xu): Flatten the two for loops and vec4 the operations.
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${xNumRows}) {
continue;
}
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${xNumCols}) {
continue;
}
float xVal = getX(batch, xR, xC, d1);
float wVal = getW(wR, wC, d1, q);
dotProd += xVal * wVal;
}
}
writeResult(batch, coords[1], coords[2], d2, dotProd);
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for 2D max pooling. Each invocation scans one pooling
 * window of input `x` and writes the largest in-bounds value.
 *
 * Window positions that fall outside the input are skipped, and the running
 * maximum is seeded from the first in-bounds value rather than from 0.0, so
 * windows that contain only negative values produce the correct (negative)
 * maximum.
 */
class MaxPoolProgram {
    /**
     * @param {object} convInfo Pooling description; only `outShape` is read
     *     here. Padding, stride, dilation, input extent and filter size reach
     *     the shader through the uniforms declared below.
     */
    constructor(convInfo) {
        this.variableNames = ['x'];
        this.uniforms = 'ivec2 pad, stride, dilation, convDims, filterDims;';
        this.workGroupSize = [4, 4, 1];
        this.outputShape = convInfo.outShape;
        this.dispatchLayout = { x: [1], y: [2], z: [0, 3] };
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize);
        // TODO: Parallelize max computation by thread and merge result.
        // NOTE(review): the column loop both steps by dilation.x and multiplies
        // wC by dilation.x, while the row loop only steps by dilation.y. The two
        // agree when dilation is 1; confirm the intended meaning of filterDims
        // against the caller before relying on dilation > 1.
        this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
if (all(lessThan(coords, outShape))) {
ivec2 xRCCorner = coords.yz * stride - pad;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
float minMaxValue = 0.0;
bool minMaxValueFound = false;
for (int wR = 0; wR < filterDims.y; wR += dilation.y) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= convDims.y) {
continue;
}
for (int wC = 0; wC < filterDims.x; wC += dilation.x) {
int xC = xCCorner + wC * dilation.x;
if (xC < 0 || xC >= convDims.x) {
continue;
}
float value = getX(batch, xR, xC, d);
minMaxValue = minMaxValueFound ? max(value, minMaxValue) : value;
minMaxValueFound = true;
}
}
setOutput(batch, coords[1], coords[2], d, minMaxValue);
}
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for constant padding of input `x`.
 *
 * The shader walks the output by flat index (`workPerThread` indices per
 * invocation). Output coordinates before `start` or at/after `end` on any
 * axis receive `constantValue`; all others copy the input element at
 * `outC - start`.
 */
class PadProgram {
    /**
     * @param {number[]} xShape Shape of the input.
     * @param {Array<[number, number]>} paddings Per-axis
     *     `[beforePad, afterPad]` amounts.
     * @param {number} constantValue Value written into the padded region;
     *     interpolated into the shader text.
     */
    constructor(xShape, paddings, constantValue) {
        this.variableNames = ['x'];
        this.workPerThread = 8;
        this.workGroupSize = [1, 1, 1];
        this.outputShape = paddings.map(([beforePad, afterPad], axis) => beforePad + xShape[axis] + afterPad);
        const rank = xShape.length;
        const size = tf.util.sizeFromShape(this.outputShape);
        const type = getCoordsDataType(rank);
        this.dispatchLayout = flatDispatchLayout(this.outputShape);
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape, this.workGroupSize, [this.workPerThread, 1, 1]);
        // Rank > 1 uses GLSL vector coordinates; otherwise coordinates are
        // plain scalars and need no constructor or any() reduction.
        const isMultiDim = rank > 1;
        const wrapCoords = (list) => (isMultiDim ? `${type}(${list})` : `${list}`);
        const startValue = wrapCoords(paddings.map(([beforePad]) => beforePad).join(','));
        const endValue = wrapCoords(paddings.map(([beforePad], axis) => beforePad + xShape[axis]).join(','));
        const [leftPadCondition, rightPadCondition] = isMultiDim ?
            ['any(lessThan(outC, start))', 'any(greaterThanEqual(outC, end))'] :
            ['outC < start', 'outC >= end'];
        const unpackedCoords = isMultiDim ?
            ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank).join(',') :
            'coords';
        this.userCode = `
${type} start = ${startValue};
${type} end = ${endValue};
void main() {
int index = int(gl_GlobalInvocationID.x);
for (int i = 0; i < ${this.workPerThread}; i++) {
int flatIndex = index * ${this.workPerThread} + i;
if (flatIndex < ${size}) {
${type} outC = getCoordsFromFlatIndex(flatIndex);
if (${leftPadCondition} || ${rightPadCondition}) {
setOutput(flatIndex, ${constantValue});
} else {
${type} coords = outC - start;
setOutput(flatIndex, getX(${unpackedCoords}));
}
}
}
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for bilinear resizing of a 4D input `x`. The output
 * keeps the first and last input dimensions and takes the new height and
 * width for the middle two.
 *
 * Each invocation maps its output (row, col) to a fractional source position,
 * reads the four surrounding input values and blends them.
 */
class ResizeBilinearProgram {
    /**
     * @param {number[]} inputShape Shape of the 4D input.
     * @param {number} newHeight Output height.
     * @param {number} newWidth Output width.
     * @param {boolean} alignCorners When truthy, and the corresponding new
     *     size is greater than 1, the scale ratio for that axis is computed
     *     from sizes reduced by one.
     */
    constructor(inputShape, newHeight, newWidth, alignCorners) {
        this.variableNames = ['x'];
        this.outputShape = [inputShape[0], newHeight, newWidth, inputShape[3]];
        this.dispatchLayout = { x: [1], y: [2], z: [0, 3] };
        this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
        const adjustHeight = alignCorners && newHeight > 1;
        const adjustWidth = alignCorners && newWidth > 1;
        // GLSL expressions for the effective sizes along each axis.
        const inHeightExpr = adjustHeight ? 'xShape.y - 1.0' : 'xShape.y';
        const inWidthExpr = adjustWidth ? 'xShape.z - 1.0' : 'xShape.z';
        const outHeightExpr = adjustHeight ? 'outShape.y - 1.0' : 'outShape.y';
        const outWidthExpr = adjustWidth ? 'outShape.z - 1.0' : 'outShape.z';
        this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
if (all(lessThan(coords, outShape))) {
int b = coords[0];
int d = coords[3];
ivec2 rc = coords.yz;
vec2 effectiveInSize = vec2(
${inHeightExpr},
${inWidthExpr});
vec2 effectiveOutSize = vec2(
${outHeightExpr},
${outWidthExpr});
vec2 effectiveInputOverOutputRatioRC =
effectiveInSize / effectiveOutSize;
// Fractional source index
vec2 sourceFracIndexRC = vec2(rc) * effectiveInputOverOutputRatioRC;
// Compute the four integer indices.
ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);
ivec2 sourceCeilRC = ivec2(
min(xShape.yz - 1.0, ceil(sourceFracIndexRC)));
float topLeft = getX(b, sourceFloorRC.x, sourceFloorRC.y, d);
float bottomLeft = getX(b, sourceCeilRC.x, sourceFloorRC.y, d);
float topRight = getX(b, sourceFloorRC.x, sourceCeilRC.y, d);
float bottomRight = getX(b, sourceCeilRC.x, sourceCeilRC.y, d);
vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
float top = topLeft + (topRight - topLeft) * fracRC.y;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
float newValue = top + (bottom - top) * fracRC.x;
setOutput(b, coords[1], coords[2], d, newValue);
}
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Program description for transposing input `A`. Output dimension `i` takes
 * its size from input dimension `newDim[i]`; the shader reads each output
 * element from `A` at the coordinates produced by getSwitchedCoords(newDim).
 */
class TransposeProgram {
    /**
     * @param {number[]} aShape Shape of the input.
     * @param {number[]} newDim Permutation: output axis i reads input axis
     *     newDim[i].
     */
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        const permutedShape = Array.from({ length: aShape.length }, (_, axis) => aShape[newDim[axis]]);
        this.outputShape = permutedShape;
        this.rank = permutedShape.length;
        const coordsType = getCoordsDataType(this.rank);
        this.dispatchLayout = flatDispatchLayout(permutedShape);
        this.dispatch = computeDispatch(this.dispatchLayout, permutedShape);
        const sourceCoords = getSwitchedCoords(newDim);
        this.userCode = `
void main() {
${coordsType} resRC = getOutputCoords();
setOutput(getFlatIndex(resRC, outShape), A[getFlatIndex(
${coordsType}(${sourceCoords}), aShape)]);
}
`;
    }
}
/**
 * Builds the comma-separated list of GLSL expressions that map output
 * coordinates (`resRC`) back to input coordinates for a transpose.
 *
 * Output axis `i` corresponds to input axis `newDim[i]`, so slot `newDim[i]`
 * of the result holds `resRC[i]`.
 *
 * @param {number[]} newDim Axis permutation.
 * @returns {string} Expressions joined with commas, in input-axis order.
 * @throws {Error} If the permutation has more than 4 axes.
 */
function getSwitchedCoords(newDim) {
    const rank = newDim.length;
    if (rank > 4) {
        throw new Error(`Transpose for rank ${rank} is not yet supported`);
    }
    const bySourceAxis = new Array(rank);
    newDim.forEach((sourceAxis, outputAxis) => {
        bySourceAxis[sourceAxis] = `resRC[${outputAxis}]`;
    });
    return bySourceAxis.join();
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL snippets written against a single float operand named `a`, which is the
// signature of the `unaryOperation(float a)` function that UnaryOpProgram
// generates. Presumably passed as its `op` argument -- the call sites are
// outside this chunk.
const RELU = 'return max(a, 0.0);';
const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
/**
 * Program description for an element-wise unary operation on input `A`.
 *
 * The generated shader walks the output by flat index; every invocation
 * handles `workPerThread` consecutive indices.
 */
class UnaryOpProgram {
    /**
     * @param {number[]} outputShape Shape of the output.
     * @param {string} op GLSL statements used as the body of
     *     `float unaryOperation(float a)`.
     */
    constructor(outputShape, op) {
        this.variableNames = ['A'];
        this.workPerThread = 4;
        this.workGroupSize = [1, 1, 1];
        this.outputShape = outputShape;
        const { workGroupSize, workPerThread } = this;
        const elementCount = tf.util.sizeFromShape(outputShape);
        const layout = flatDispatchLayout(outputShape);
        this.dispatchLayout = layout;
        this.dispatch = computeDispatch(layout, outputShape, workGroupSize, [workPerThread, 1, 1]);
        const coordsType = getCoordsDataType(outputShape.length);
        // The flatIndex guard in the shader keeps the last invocation from
        // writing past the end when the size is not a multiple of workPerThread.
        this.userCode = `
float unaryOperation(float a) {
${op}
}
void main() {
int index = int(gl_GlobalInvocationID.x);
for(int i = 0; i < ${workPerThread}; i++) {
int flatIndex = index * ${workPerThread} + i;
if(flatIndex < ${elementCount}) {
${coordsType} coords = getCoordsFromFlatIndex(flatIndex);
float a = getAAtOutCoords(coords);
setOutput(flatIndex, unaryOperation(a));
}
}
}
`;
    }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a bind group whose binding 0 is the output, followed by the inputs
 * in order, followed by the uniforms when provided.
 *
 * @param {object} device Object exposing `createBindGroup`.
 * @param {object} bindGroupLayout Layout handed through to the bind group.
 * @param {Array<{resource: *}>} inputs Input bindings.
 * @param {{resource: *}} output Output binding.
 * @param {{resource: *}=} uniforms Optional uniform binding, placed last.
 * @returns {*} Whatever `device.createBindGroup` returns.
 */
const makeBindGroup = (device, bindGroupLayout, inputs, output, uniforms) => {
    const bound = uniforms ? [output, ...inputs, uniforms] : [output, ...inputs];
    const bindings = bound.map(({ resource }, binding) => ({ binding, resource }));
    return device.createBindGroup({ layout: bindGroupLayout, bindings });
};
/**
 * Creates a bind group layout with one compute-visible 'storage-buffer'
 * binding for the output plus one per input, and a trailing 'uniform-buffer'
 * binding when uniforms are provided.
 *
 * @param {object} device Object exposing `createBindGroupLayout`.
 * @param {Array} inputs Inputs; only the count is used.
 * @param {*} output Not read; kept for signature parity with makeBindGroup.
 * @param {*=} uniforms When truthy, a uniform-buffer binding is appended.
 * @returns {*} Whatever `device.createBindGroupLayout` returns.
 */
const makeBindGroupLayout = (device, inputs, output, uniforms) => {
    const visibility = GPUShaderStage.COMPUTE;
    const bufferTypes = Array.from({ length: 1 + inputs.length }, () => 'storage-buffer');
    if (uniforms) {
        bufferTypes.push('uniform-buffer');
    }
    const bindings = bufferTypes.map((type, binding) => ({ binding, visibility, type }));
    return device.createBindGroupLayout({ bindings });
};
/**
 * Generates the GLSL for a program, compiles it to SPIR-V, and builds the
 * compute pipeline plus its bind group layout.
 *
 * On a compile error the full shader source is logged with 1-based line
 * numbers before an Error is thrown.
 *
 * @param {object} shaderCompiler Object exposing `CompileGlslToSpv`.
 * @param {*} shaderKind Passed through to the compiler.
 * @param {*} compileOptions Passed through to the compiler.
 * @param {object} device Object exposing `createPipelineLayout`,
 *     `createShaderModule` and `createComputePipeline`.
 * @param {object} program Program description handed to makeShader.
 * @param {Array} inputsData Input descriptions handed to makeShader.
 * @param {{dtype: *, shape: *}} output Output; only dtype and shape are
 *     forwarded to makeShader.
 * @param {*=} uniforms Forwarded to makeBindGroupLayout.
 * @returns {{bindGroupLayout: *, pipeline: *}}
 * @throws {Error} If the compiler reports a non-empty error message.
 */
const compileProgram = (shaderCompiler, shaderKind, compileOptions, device, program, inputsData, output, uniforms) => {
    const { dtype, shape } = output;
    const source = makeShader(inputsData, { dtype, shape }, program);
    const result = shaderCompiler.CompileGlslToSpv(source, shaderKind, 'file', 'main', compileOptions);
    const error = result.GetErrorMessage();
    if (error.length) {
        // Dump the source with right-aligned line numbers so compiler
        // diagnostics can be matched to lines.
        const numberedLines = source
            .split('\n')
            .map((text, lineIndex) => `${String(lineIndex + 1).padStart(5, ' ')} ${text}`);
        console.error(numberedLines.join('\n'));
        throw new Error(`Shader compilation failed: ${error}`);
    }
    const bindGroupLayout = makeBindGroupLayout(device, inputsData, output, uniforms);
    const code = result.GetBinary();
    const layout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });
    const module = device.createShaderModule({ code });
    const computeStage = { module, entryPoint: 'main' };
    const pipeline = device.createComputePipeline({ layout, computeStage });
    return { bindGroupLayout, pipeline };
};
// TODO: Consider allowing each program to specify its own shader key. E.g. some
// kernels account for different work group sizes, but some don't.
// TODO: Consider uploading shape info as vec4s regardless of rank to reduce
// recompilation.
/**
 * Builds the string key identifying a compiled shader variant: the program's
 * work group size (if any) joined by commas, then the ranks joined by commas,
 * then the program's user code, concatenated with no separators in between.
 *
 * @param {{workGroupSize?: number[], userCode: string}} program
 * @param {number[]} ranks
 * @returns {string}
 */
function makeShaderKey(program, ranks) {
    const { workGroupSize, userCode } = program;
    const sizePart = workGroupSize ? workGroupSize.join(',') : '';
    return sizePart + ranks.join(',') + userCode;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Usage flags applied when acquireBuffer() is called without an explicit
// usage, and recorded for every tensor buffer created by register(): storage
// access plus being a copy source and a copy destination.
const DEFAULT_GPUBUFFER_USAGE = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
class WebGPUBackend extends tf.KernelBackend {
/**
 * Sets up backend state around a device and a shader compiler module.
 *
 * @param device Object providing getQueue(); also handed to the
 *     BufferManager and used later to create command encoders.
 * @param shaderc Module providing the Compiler and CompileOptions
 *     constructors and the optimization_level enum.
 */
constructor(device, shaderc) {
super();
// dataIds checked by disposeData(): buffers of ids found here are queued for
// release at the next submitQueue() instead of being released immediately.
this.commandQueueOwnedIds = new WeakSet();
// dataId -> { values, id, dtype, bufferInfo }, populated by register().
this.tensorMap = new WeakMap();
// bufferInfo records waiting to be released by flushDisposalQueue().
this.disposalQueue = [];
this.disposed = false;
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
this.binaryCache = {};
this.device = device;
this.queue = device.getQueue();
// Command encoders recorded since the last submitQueue().
this.commandQueue = [];
this.shaderc = shaderc;
this.compiler = new shaderc.Compiler();
const opts = new shaderc.CompileOptions();
opts.SetOptimizationLevel(shaderc.optimization_level.performance);
this.compileOpts = opts;
this.bufferManager = new BufferManager(this.device);
}
/** Returns the float precision, in bits, reported by this backend: always 32. */
floatPrecision() {
return 32;
}
/**
 * Intentionally a no-op for now; the argument is ignored.
 *
 * @param dataMover Unused.
 */
setDataMover(dataMover) {
// TODO: tfjs team to implement this. Call GPUBuffer.destroy()
}
flushDisposalQueue() {
this.disposalQueue.forEach(d => {
this.releaseBuffer(d.buffer, d.byteSize, d.usage);
});
this.disposalQueue = [];
}
disposeData(dataId) {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
const info = this.tensorMap.get(dataId);
if (this.commandQueueOwnedIds.has(dataId)) {
this.disposalQueue.push(info.bufferInfo);
}
else {
this.releaseBuffer(info.bufferInfo.buffer, info.bufferInfo.byteSize, info.bufferInfo.usage);
}
this.tensorMap.delete(dataId);
}
/**
 * Reports memory usage.
 *
 * @returns An object with `numBytesInGPU` taken from the buffer manager's
 *     `numBytesUsed`, and `unreliable` set to false.
 */
memory() {
return {
numBytesInGPU: this.bufferManager.numBytesUsed,
unreliable: false
};
}
/** Returns the BufferManager instance created in the constructor. */
getBufferManager() {
return this.bufferManager;
}
/**
 * Obtains a buffer from the buffer manager.
 *
 * @param byteSize Requested size in bytes.
 * @param usage Usage flags; defaults to DEFAULT_GPUBUFFER_USAGE.
 * @returns Whatever the buffer manager's acquireBuffer returns.
 */
acquireBuffer(byteSize, usage = DEFAULT_GPUBUFFER_USAGE) {
return this.bufferManager.acquireBuffer(byteSize, usage);
}
/**
 * Hands a buffer back to the buffer manager.
 *
 * @param buffer The buffer being released.
 * @param byteSize Size it was acquired with.
 * @param usage Usage flags it was acquired with.
 */
releaseBuffer(buffer, byteSize, usage) {
this.bufferManager.releaseBuffer(buffer, byteSize, usage);
}
register(dataId, shape, dtype) {
if (!this.tensorMap.has(dataId)) {
const byteSize = tf.util.sizeFromShape(shape) * tf.util.bytesPerElement(dtype);
const buffer = this.acquireBuffer(byteSize);
this.tensorMap.set(dataId, {
values: null,
id: -1,
dtype,
bufferInfo: { byteSize, usage: DEFAULT_GPUBUFFER_USAGE, buffer }
});
}
}
write(dataId, values) {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
const info = this.tensorMap.get(dataId);
info.values = values;
info.bufferInfo.buffer.setSubData(0, values);
this.tensorMap.set(dataId, info);
}
submitQueue() {
this.queue.submit(this.commandQueue.map(enc => enc.finish()));
this.commandQueue = [];
this.commandQueueOwnedIds = new WeakSet();
this.flushDisposalQueue();
}
async getBufferData(info) {
const staging = this.acquireBuffer(info.bufferInfo.byteSize, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
const encoder = this.device.createCommandEncoder({});
encoder.copyBufferToBuffer(info.bufferInfo.buffer, 0, staging, 0, info.bufferInfo.byteSize);
this.commandQueue.push(encoder);
this.submitQueue();
const mapped = await staging.mapReadAsync();
const values = mapped.slice(0);
staging.unmap();
this.releaseBuffer(staging, info.bufferInfo.byteSize, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
return values;
}
convertAndCacheOnCPU(dataId, data) {
const info = this.tensorMap.get(dataId);
info.values = data;
return info.values;
}
// TODO: Remove once this is fixed:
// https://github.com/tensorflow/tfjs/issues/1595
readSync(dataId) {
const texData = this.tensorMap.get(dataId);
const { values } = texData;
if (values == null) {
throw new Error('WebGPU readSync is only available for CPU-resident tensors.');
}
return values;
}
async read(dataId) {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
const info = this.tensorMap.get(dataId);
const data = await this.getBufferData(info);
const dataAsTypedArray = info.dtype === 'int32' ? new Int32Array(data) : new Float32Array(data);
this.convertAndCacheOnCPU(dataId, dataAsTypedArray);
return dataAsTypedArray;
}
async time(f) {
const oldActiveTimers = this.activeTimers;
const newActiveTimers = [];
let outerMostTime = false;
if (this.programTimersStack == null) {
this.programTimersStack = newActiveTimers;
outerMostTime = true;
}
else {
this.activeTimers.push(newActiveTimers);
}
this.activeTimers = newActiveTimers;
f();
const flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map((d) => d.query))
.filter(d => d != null);
const flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map((d) => d.name))
.filter(d => d != null);
this.activeTimers = oldActiveTimers;
if (outerMostTime) {
this.programTimersStack = null;
}
const kernelMs = await Promise.all(flattenedActiveTimerQueries);
const res = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: tf.util.sum(kernelMs),
getExtraProfileInfo: () => kernelMs.map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
.map(d => `${d.name}: ${d.ms}`)
.join(', '),
wallMs: null
};
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return res;
}
getAndSavePipeline(key, getBinary) {
if (!(key in this.binaryCache)) {
this.binaryCache[key] = getBinary();
}
return this.binaryCache[key];
}
makeOutputArray(shape, dtype) {
return tf.Tensor.make(shape, {}, dtype, this);
}
tensorToBinding(tensor) {
if (!tensor) {
return null;
}
const tensorData = this.tensorMap.get(tensor.dataId);
return {
resource: {
offset: 0,
size: tensor.size * tf.util.bytesPerElement(tensor.dtype),
buffer: tensorData.bufferInfo.buffer
}
};
}
startTimer() {
return { startMs: tf.util.now(), endMs: 0 };
}
endTimer(query) {
query.endMs = tf.util.now();
return query;
}
async getQueryTime(query) {
const timerQuery = query;
return timerQuery.endMs - timerQuery.startMs;
}
compileAndRun(program, inputs, output, programUniforms) {
if (output == null) {
output = this.makeOutputArray(program.outputShape, inputs[0].dtype);
}
let dimUniforms = [];
const bufferShapes = inputs.concat(output).map(d => d.shape);
let currentOffset = 0;
bufferShapes.forEach((d, i) => {
// Uniforms.
if (d.length === 0) {
d = [1];
}
// Complete std140 layout rules are documented here:
// tslint:disable-next-line:max-line-length
// https://www.khronos.org/registry/OpenGL/specs/gl/glspec45.core.pdf#page=159
let baseAlignment;
switch (d.length) {
case 0:
baseAlignment = 1;
break;
case 1:
baseAlignment = 1;
break;
case 2:
baseAlignment = 2;
break;
case 3:
baseAlignment = 4;
break;
case 4:
baseAlignment = 4;
break;
default:
tf.util.assert(false, () => `Unsupported ${d.length}D shape`);
}
const padding = Math.ceil(currentOffset / baseAlignment) * baseAlignment -
currentOffset;
for (let p = 0; p < padding; ++p) {
dimUniforms.push(0);
}
dimUniforms.push(...d);
currentOffset += d.length + padding;
});
// TODO: handle padding of program-specific uniforms
if (programUniforms) {
dimUniforms = dimUniforms.concat(programUniforms);
}
const uniformData = new Int32Array(dimUniforms);
const uniforms = this.makeUniforms(uniformData);
const key = makeShaderKey(program, bufferShapes.map(d => d.length));
const inputsData = inputs.map((input, i) => ({
// Returning dtype from tensorMap because it reflects dtype
// of underlying buffer, rather than abstract dtype.
dtype: this.tensorMap.get(input.dataId).dtype,
shape: input.shape,
name: program.variableNames[i]
}));
const { bindGroupLayout, pipeline } = this.getAndSavePipeline(key, () => {
return compileProgram(this.compiler, this.shaderc.shader_kind.compute, this.compileOpts, this.device, program, inputsData, output, uniforms);
});
const shouldTimeProgram = this.activeTimers != null;
let query;
if (shouldTimeProgram) {
query = this.startTimer();
}
// Creating bind groups on the fly should never be a bottleneck.
const bg = makeBindGroup(this.device, bindGroupLayout, inputs.map(t => this.tensorToBinding(t)), this.tensorToBinding(output), uniforms);
const encoder = this.device.createCommandEncoder({});
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bg);
pass.dispatch(program.dispatch[0], program.dispatch[1], program.dispatch[2]);
pass.endPass();
this.commandQueue.push(encoder);
inputs.forEach(input => {
this.commandQueueOwnedIds.add(input.dataId);
});
this.commandQueueOwnedIds.add(output.dataId);
if (tf.ENV.get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
this.submitQueue();
}
this.releaseBuffer(uniforms.resource.buffer, uniformData.byteLength, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
if (shouldTimeProgram) {
query = this.endTimer(query);
this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
}
return output;
}
makeUniforms(data) {
const dimensionsBuffer = this.acquireBuffer(data.byteLength, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
dimensionsBuffer.setSubData(0, data);
return {
resource: { offset: 0, size: data.byteLength, buffer: dimensionsBuffer }
};
}
pad(x, paddings, constantValue) {
const program = new PadProgram(x.shape, paddings, constantValue);
const output = this.makeOutputArray(program.outputShape, x.dtype);
return this.compileAndRun(program, [x], output);
}
maxPool(x, convInfo) {
const program = new MaxPoolProgram(convInfo);
const output = this.makeOutputArray(program.outputShape, x.dtype);
const dimensions = [
convInfo.padInfo.left, convInfo.padInfo.top,
convInfo.strideWidth, convInfo.strideHeight,
convInfo.dilationWidth, convInfo.dilationHeight,
convInfo.inWidth, convInfo.inHeight,
convInfo.effectiveFilterWidth,
convInfo.effectiveFilterHeight // Filter dims.
];
return this.compileAndRun(program, [x], output, dimensions);
}
binaryOp(a, b, op) {
const dtype = tf.backend_util.upcastType(a.dtype, b.dtype);
const program = new BinaryOpProgram(op, a.shape, b.shape);
const output = tf.Tensor.make(program.outputShape, {}, dtype);
return this.compileAndRun(program, [a, b], output);
}
add(a, b) {
return this.binaryOp(a, b, ADD);
}
subtract(a, b) {
return this.binaryOp(a, b, SUB);
}
conv2d(x, filter, convInfo) {
const output = tf.Tensor.make(convInfo.outShape, {}, x.dtype, this);
let program;
const workPerThread = tf.ENV.get('WEBGPU_CONV2D_WORK_PER_THREAD');
if (workPerThread === -1) {
// TODO(kainino0x): This may be obsolete, but is kept for reference.
program = new Conv2DNaiveProgram(convInfo);
}
else {
program = new Conv2DMMProgram(convInfo, workPerThread);
}
const pad = convInfo.padInfo.type === 'VALID' ?
[0, 0] :
convInfo.padInfo.type === 'SAME' ?
[
-Math.floor((convInfo.filterShape[0] - 1) / 2),
-Math.floor((convInfo.filterShape[1] - 1) / 2)
] :
[convInfo.padInfo.top, convInfo.padInfo.left];
const dimensions = [
convInfo.filterHeight, convInfo.filterWidth, ...pad,
convInfo.strideHeight, convInfo.strideWidth
];
return this.compileAndRun(program, [x, filter], output, dimensions);
}
depthwiseConv2D(x, filter, convInfo) {
const program = new DepthwiseConv2DProgram(convInfo);
return this.compileAndRun(program, [x, filter]);
}
argMinMaxReduce(x, axis, reduceType) {
const program = new ArgMinMaxProgram(x.shape, axis, reduceType);
const output = this.makeOutputArray(program.outputShape, 'int32');
return this.compileAndRun(program, [x], output, [axis]);
}
argMin(x, axis) {
return this.argMinMaxReduce(x, axis, 'min');
}
argMax(x, axis) {
return this.argMinMaxReduce(x, axis, 'max');
}
concat(tensors, axis) {
if (tensors.length === 1) {
return tensors[0];
}
// Is there a maximum number of buffers that can be uploaded to a WebGPU
// program?
// if (tensors.length > MAX_SSBOS_FOR_WEBGPU_PROGRAM) {
// const midIndex = Math.floor(tensors.length / 2);
// const leftSide = this.concat(tensors.slice(0, midIndex), axis);
// const rightSide = this.concat(tensors.slice(midIndex), axis);
// return this.concat([leftSide, rightSide], axis);
// }
const outShape = tf.backend_util.computeOutShape(tensors.map(t => t.shape), axis);
const tensors2D = tensors.map(t => t.reshape([
tf.util.sizeFromShape(t.shape.slice(0, axis)),
tf.util.sizeFromShape(t.shape.slice(axis))
]));
const program = new ConcatProgram(tensors2D.map(t => t.shape));
const res = this.compileAndRun(program, tensors2D);
return res.reshape(outShape);
}
multiply(a, b) {
return this.binaryOp(a, b, MUL);
}
floorDiv(a, b) {
return this.binaryOp(a, b, INT_DIV);
}
sigmoid(x) {
const program = new UnaryOpProgram(x.shape, SIGMOID);
return this.compileAndRun(program, [x]);
}
relu(x) {
const program = new UnaryOpProgram(x.shape, RELU);
return this.compileAndRun(program, [x]);
}
resizeBilinear(x, newHeight, newWidth, alignCorners) {
const program = new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);
const output = this.makeOutputArray(program.outputShape, x.dtype);
return this.compileAndRun(program, [x], output);
}
reshape(x, shape) {
return tf.Tensor.make(shape, { dataId: x.dataId }, x.dtype);
}
cast(x, dtype) {
return tf.backend_util.castTensor(x, dtype, this);
}
transpose(x, perm) {
const program = new TransposeProgram(x.shape, perm);
return this.compileAndRun(program, [x]);
}
batchMatMul(a, b, transposeA, transposeB) {
// TODO: Support transposed inputs.
// const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
// const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
const outerShapeA = a.shape[1];
const outerShapeB = b.shape[2];
const [batch, ,] = a.shape;
const output = tf.Tensor.make([batch, outerShapeA, outerShapeB], {}, a.dtype, this);
let program;
// TODO: We should eventually use the blocked version, but keeping around
// the old version while we try to understand conditions under which blocked
// is faster.
if (tf.ENV.get('WEBGPU_MATMUL_WORK_PER_THREAD') === 0) {
program = new MatMulProgram(output.shape);
}
else {
program = new MatMulPackedProgram(output.shape, tf.ENV.get('WEBGPU_MATMUL_WORK_PER_THREAD'));
}
return this.compileAndRun(program, [a, b], output);
}
fromPixels(pixels, numChannels) {
if (pixels == null) {
throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
}
const outShape = [pixels.height, pixels.width, numChannels];
let imageData = pixels.data;
if (tf.ENV.getBool('IS_BROWSER')) {
if (!(pixels instanceof HTMLVideoElement) &&
!(pixels instanceof HTMLImageElement) &&
!(pixels instanceof HTMLCanvasElement) &&
!(pixels instanceof ImageData) &&
!(pixels.data instanceof Uint8Array)) {
throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
`HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData` +
` or {data: Uint32Array, width: number, height: number}, ` +
`but was ${pixels.constructor.name}`);
}
if (pixels instanceof HTMLVideoElement ||
pixels instanceof HTMLImageElement ||
pixels instanceof HTMLCanvasElement) {
if (this.fromPixels2DContext == null) {
this.fromPixels2DContext =
document.createElement('canvas').getContext('2d');
this.fromPixels2DContext.canvas.width = pixels.width;
this.fromPixels2DContext.canvas.height = pixels.height;
}
this.fromPixels2DContext.drawImage(pixels, 0, 0, pixels.width, pixels.height);
pixels = this.fromPixels2DContext.canvas;
}
// TODO: Remove this once we figure out how to upload textures directly to
// WebGPU.
const imageDataLivesOnGPU = pixels instanceof HTMLVideoElement ||
pixels instanceof HTMLImageElement ||
pixels instanceof HTMLCanvasElement;
if (imageDataLivesOnGPU) {
imageData = this.fromPixels2DContext
.getImageData(0, 0, pixels.width, pixels.height)
.data;
}
}
// TODO: Encoding should happen on GPU once we no longer have to download
// image data to the CPU.
let pixelArray = imageData;
if (numChannels != null && numChannels !== 4) {
pixelArray = new Uint8Array(pixels.width * pixels.height * numChannels);
for (let i = 0; i < imageData.length; i++) {
if (i % 4 < numChannels) {
const pixelIndex = Math.floor(i / 4);
pixelArray[pixelIndex * numChannels + i % 4] = imageData[i];
}
}
}
const output = this.makeOutputArray(outShape, 'int32');
this.write(output.dataId, Int32Array.from(pixelArray));
return output;
}
dispose() {
if (this.disposed) {
return;
}
this.bufferManager.dispose();
this.disposed = true;
}
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Registers the 'webgpu' backend. The factory loads the shaderc compiler,
// requests a GPU adapter and device, and wraps them in a WebGPUBackend.
tf.registerBackend('webgpu', async () => {
  // Fail with a descriptive error instead of an opaque TypeError when WebGPU
  // is unavailable.
  // @ts-ignore navigator.gpu is required
  if (typeof navigator === 'undefined' || navigator.gpu == null) {
    throw new Error('WebGPU is not available: navigator.gpu is undefined.');
  }
  const shaderc = await dist_1();
  const adapter = await navigator.gpu.requestAdapter({});
  // requestAdapter() may resolve to null when no suitable adapter exists.
  if (adapter == null) {
    throw new Error('WebGPU is not available: no suitable GPU adapter was found.');
  }
  const device = await adapter.requestDevice({});
  return new WebGPUBackend(device, shaderc);
}, 3 /*priority*/);
})));
//# sourceMappingURL=tf-webgpu.js.map