// Mirror of https://github.com/vladmandic/human (JavaScript, 4113 lines, 262 KiB).
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
(function (global, factory) {
|
||
|
typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core'), require('path'), require('fs'), require('worker_threads'), require('perf_hooks')) :
|
||
|
typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core', 'path', 'fs', 'worker_threads', 'perf_hooks'], factory) :
|
||
|
(global = global || self, factory((global.tf = global.tf || {}, global.tf.wasm = global.tf.wasm || {}), global.tf, global.path, global.fs, global.worker_threads, global.perf_hooks));
|
||
|
}(this, (function (exports, tfjsCore, path, fs, worker_threads, perf_hooks) { 'use strict';
|
||
|
|
||
|
// When a dependency arrives as an interop wrapper exposing its real export
// under an own `default` property, unwrap it. The own-property test goes
// through Object.prototype so it also works on objects that do not inherit
// from Object.prototype (e.g. null-prototype module namespace objects), where
// calling `obj.hasOwnProperty(...)` directly would throw a TypeError.
path = path && Object.prototype.hasOwnProperty.call(path, 'default') ? path['default'] : path;
fs = fs && Object.prototype.hasOwnProperty.call(fs, 'default') ? fs['default'] : fs;
worker_threads = worker_threads && Object.prototype.hasOwnProperty.call(worker_threads, 'default') ? worker_threads['default'] : worker_threads;
perf_hooks = perf_hooks && Object.prototype.hasOwnProperty.call(perf_hooks, 'default') ? perf_hooks['default'] : perf_hooks;
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// This enum must align with the enum defined in cc/backend.h.
// Two-way lookup table: each dtype name maps to its numeric code and each
// code maps back to its name.
var CppDType;
(function (CppDType) {
    const dtypeNames = ['float32', 'int32', 'bool', 'string', 'complex64'];
    dtypeNames.forEach((name, code) => {
        CppDType[name] = code;
        CppDType[code] = name;
    });
})(CppDType || (CppDType = {}));
|
||
|
// Must match enum in cc/fusable_activations.h.
// Two-way lookup table: activation name -> numeric code and code -> name.
var FusableActivation;
(function (FusableActivation) {
    const activationNames = ['linear', 'relu', 'relu6', 'prelu'];
    activationNames.forEach((name, code) => {
        FusableActivation[name] = code;
        FusableActivation[code] = name;
    });
})(FusableActivation || (FusableActivation = {}));
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFusedMatMul;
/**
 * Binds the wasm export named by `tfjsCore._FusedMatMul` and stores the
 * resulting wrapper in `wasmFusedMatMul`.
 *
 * @param backend wasm backend whose `wasm.cwrap` performs the binding.
 */
function setup(backend) {
    // One entry per argument, in the order fusedBatchMatMul passes them.
    const argTypes = [
        'number', 'array', 'number', // aId, aShapeBytes, a rank
        'number', 'array', 'number', // bId, bShapeBytes, b rank
        'number', 'number', // transposeA, transposeB
        'number', // fused activation code
        'number', 'number', // biasId, preluActivationWeightsId
        'number', // out_id
    ];
    wasmFusedMatMul = backend.wasm.cwrap(tfjsCore._FusedMatMul, null /* void */, argTypes);
}
|
||
|
/**
 * Kernel function for `_FusedMatMul` on the wasm backend: a batched matrix
 * multiply of `a` and `b`, optionally fused with a bias and an activation,
 * executed by the `wasmFusedMatMul` binding.
 *
 * `a` and `b` are indexed as rank-3 tensors (shape[0] is the batch size, the
 * last two dims are the matrix dims, swapped when transposeA/transposeB).
 *
 * @param args.inputs `a`, `b` (both float32), optional `bias`, optional
 *     `preluActivationWeights`.
 * @param args.attrs `transposeA`, `transposeB`, `activation` (a name key of
 *     `FusableActivation`).
 * @returns TensorInfo of shape [batch, leftDim, rightDim] with `a`'s dtype.
 * @throws Error if `a` or `b` is not float32, if `bias` is not rank 1, or if
 *     `activation` is not a supported activation name.
 */
function fusedBatchMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    if (a.dtype !== 'float32' || b.dtype !== 'float32') {
        throw new Error(`_FusedMatMul for non-float32 tensors not yet supported.`);
    }
    const { transposeA, transposeB, activation } = attrs;
    const aId = backend.dataIdMap.get(a.dataId).id;
    const bId = backend.dataIdMap.get(b.dataId).id;
    // 0 is passed when no bias is given (presumably the "absent" sentinel on
    // the C++ side — TODO confirm against cc/kernels/_FusedMatMul.cc).
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        if (biasData.shape.length !== 1) {
            throw new Error(`_FusedMatMul only supports rank-1 bias but got ` +
                `rank ${biasData.shape.length}.`);
        }
        biasId = biasData.id;
    }
    const preluActivationWeightsId = preluActivationWeights == null ?
        0 :
        backend.dataIdMap.get(preluActivationWeights.dataId).id;
    // FusableActivation is a two-way table that also inherits Object.prototype,
    // so only a numeric result denotes a real activation name; reverse keys
    // ('0') and inherited keys ('toString') must be rejected too.
    const fusedActivation = FusableActivation[activation];
    if (typeof fusedActivation !== 'number') {
        throw new Error(`${activation} activation not yet supported for _FusedMatMul ` +
            `in the wasm backend.`);
    }
    const leftDim = transposeA ? a.shape[2] : a.shape[1];
    const rightDim = transposeB ? b.shape[1] : b.shape[2];
    const batchDim = a.shape[0];
    const out = backend.makeOutput([batchDim, leftDim, rightDim], a.dtype);
    const outId = backend.dataIdMap.get(out.dataId).id;
    // Shapes cross into wasm as the raw bytes of an Int32Array.
    const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);
    const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);
    wasmFusedMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, outId);
    return out;
}
|
||
|
/** Registers `fusedBatchMatMul` (bound by `setup`) as the wasm `_FusedMatMul` kernel. */
const fusedMatMulConfig = {
    kernelName: tfjsCore._FusedMatMul,
    backendName: 'wasm',
    setupFunc: setup,
    kernelFunc: fusedBatchMatMul,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Builds a wasm kernel config for an elementwise unary op whose wasm export
 * takes (x id, out id).
 *
 * @param kernelName kernel name; also used as the name of the wasm export.
 * @returns `{kernelName, backendName, setupFunc, kernelFunc}`.
 */
function createUnaryKernelConfig(kernelName) {
    let wasmUnaryOp;
    /** Binds the wasm export for this kernel. */
    function setupFunc(backend) {
        wasmUnaryOp = backend.wasm.cwrap(kernelName, null /* void */, ['number', 'number']);
    }
    /** Allocates an output like `x` and, unless it is empty, runs the op. */
    function kernelFunc(args) {
        const { backend, inputs } = args;
        const { x } = inputs;
        const xId = backend.dataIdMap.get(x.dataId).id;
        const out = backend.makeOutput(x.shape, x.dtype);
        const outId = backend.dataIdMap.get(out.dataId).id;
        // Zero-sized tensors have nothing to compute, so the wasm call is skipped.
        const isEmpty = tfjsCore.util.sizeFromShape(out.shape) === 0;
        if (!isEmpty) {
            wasmUnaryOp(xId, outId);
        }
        return out;
    }
    return { kernelName, backendName: 'wasm', setupFunc, kernelFunc };
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Elementwise absolute value: a unary wasm kernel bound to the export named by tfjsCore.Abs.
const absConfig = createUnaryKernelConfig(tfjsCore.Abs);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Builds a wasm kernel config for an elementwise binary op.
 *
 * @param kernelName kernel name; also used as the name of the wasm export.
 * @param supportsFullBroadcast when truthy, float32 inputs are sent to wasm
 *     without checking the broadcast pattern.
 * @param dtype optional output dtype; when null/undefined the output takes
 *     the dtype of `a`.
 * @returns `{kernelName, backendName, setupFunc, kernelFunc}`.
 */
function createBinaryKernelConfig(kernelName, supportsFullBroadcast, dtype) {
    let wasmBinaryOp;
    /** Binds the wasm export for this kernel. */
    function setupFunc(backend) {
        wasmBinaryOp = backend.wasm.cwrap(kernelName, null /* void */, [
            'number', 'array', 'number', // a id, a shape bytes, a rank
            'number', 'array', 'number', // b id, b shape bytes, b rank
            'number', // CppDType code of a
            'number', // out_id
        ]);
    }
    /** Allocates the broadcast output and runs the op when allowed. */
    function kernelFunc(args) {
        const { backend, inputs } = args;
        const { a, b } = inputs;
        const aId = backend.dataIdMap.get(a.dataId).id;
        const bId = backend.dataIdMap.get(b.dataId).id;
        const outputType = dtype == null ? a.dtype : dtype;
        const newShape = tfjsCore.backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
        const out = backend.makeOutput(newShape, outputType);
        // Zero-sized outputs need no computation.
        if (tfjsCore.util.sizeFromShape(newShape) === 0) {
            return out;
        }
        const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);
        const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);
        const outId = backend.dataIdMap.get(out.dataId).id;
        const runWasm = () => wasmBinaryOp(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId);
        // Only float32 inputs of kernels that opt in take the full-broadcast path.
        if (supportsFullBroadcast && a.dtype === 'float32') {
            runWasm();
            return out;
        }
        // Otherwise each input's broadcast dims must be exactly 0..k-1.
        const isLeadingRun = (dims) => dims.every((dim, i) => dim === i);
        const aDims = tfjsCore.backend_util.getBroadcastDims(a.shape, newShape);
        const bDims = tfjsCore.backend_util.getBroadcastDims(b.shape, newShape);
        if (!(isLeadingRun(aDims) && isLeadingRun(bDims))) {
            throw new Error(`Broadcasting along outer dims is not yet ` +
                `supported for ${a.dtype} ${kernelName}.`);
        }
        runWasm();
        return out;
    }
    return { kernelName, backendName: 'wasm', setupFunc, kernelFunc };
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Add opts into full broadcasting (createBinaryKernelConfig only uses that
// path for float32 inputs). No output dtype is given, so the output takes
// the dtype of `a`.
const supportsFullBroadcast = true;
const addConfig = createBinaryKernelConfig(tfjsCore.Add, supportsFullBroadcast, undefined /* dtype */);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc;
/**
 * Binds the wasm export named by `tfjsCore.AddN` into `wasmFunc`.
 * Argument order (as used by `addn`): input ids as Int32 bytes, input count,
 * CppDType code, output id.
 */
function setupFunc(backend) {
    const signature = ['array', 'number', 'number', 'number'];
    wasmFunc = backend.wasm.cwrap(tfjsCore.AddN, null /* void */, signature);
}
|
||
|
/**
 * AddN kernel: elementwise sum of every tensor in `inputs`. The output takes
 * the shape and dtype of the first input.
 */
function addn(args) {
    const { inputs, backend } = args;
    const { shape, dtype } = inputs[0];
    const out = backend.makeOutput(shape, dtype);
    // Zero-sized output: nothing for the wasm kernel to do.
    if (tfjsCore.util.sizeFromShape(out.shape) === 0) {
        return out;
    }
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const inputIds = new Int32Array(inputs.map(idOf));
    const outId = idOf(out);
    wasmFunc(new Uint8Array(inputIds.buffer), inputIds.length, CppDType[out.dtype], outId);
    return out;
}
|
||
|
/** Registers `addn` (bound by `setupFunc`) as the wasm AddN kernel. */
const addNConfig = {
    kernelName: tfjsCore.AddN,
    backendName: 'wasm',
    setupFunc,
    kernelFunc: addn,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Identity kernel: allocates a new output with `x`'s shape and dtype and
 * copies x's heap values into it.
 */
function identity(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const out = backend.makeOutput(x.shape, x.dtype);
    const source = backend.typedArrayFromHeap(x);
    const target = backend.typedArrayFromHeap(out);
    target.set(source);
    return out;
}
|
||
|
/** Registers `identity` as the wasm Identity kernel (no wasm setup needed). */
const identityConfig = {
    kernelName: tfjsCore.Identity,
    backendName: 'wasm',
    kernelFunc: identity,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmTranspose;
/** Binds the wasm export named by `tfjsCore.Transpose` into `wasmTranspose`. */
function setup$1(backend) {
    // Order matches the call in `transpose`.
    const argTypes = [
        'number', // xId
        'array', // x shape bytes
        'number', // x rank
        'number', // CppDType code
        'number', // outId
        'array', // perm bytes
        'number', // perm length
    ];
    wasmTranspose = backend.wasm.cwrap(tfjsCore.Transpose, null /* void */, argTypes);
}
|
||
|
/**
 * Transpose kernel.
 *
 * Size-1 dimensions are removed first: the lower-rank transpose kernel
 * performs better due to its simpler memory access pattern. If what remains
 * is the identity permutation, the data is just copied via `identity` and
 * given the permuted shape.
 *
 * @param args.inputs `x`, the tensor to transpose.
 * @param args.attrs `perm`, the axis permutation.
 * @returns a new TensorInfo with shape `x.shape[perm[i]]` per axis i.
 */
function transpose(args) {
    const { inputs, backend, attrs } = args;
    const [reducedShape, perm] = removeOneSizeDims(inputs.x.shape, attrs.perm);
    const isIdentityPerm = perm.every((axis, i) => axis === i);
    const outShape = computeOutShape(inputs.x.shape, attrs.perm);
    if (isIdentityPerm) {
        const cloned = identity({ inputs, backend });
        cloned.shape = outShape;
        return cloned;
    }
    const { dataId, dtype } = inputs.x;
    const out = backend.makeOutput(outShape, dtype);
    const xId = backend.dataIdMap.get(dataId).id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    const toBytes = (ints) => new Uint8Array(new Int32Array(ints).buffer);
    const permBytes = toBytes(perm);
    const xShapeBytes = toBytes(reducedShape);
    wasmTranspose(xId, xShapeBytes, reducedShape.length, CppDType[dtype], outId, permBytes, perm.length);
    return out;
}
|
||
|
/**
 * Shape of a transpose result: output axis i has the size of input axis
 * perm[i]. The result always has inShape.length entries.
 */
function computeOutShape(inShape, perm) {
    return Array.from({ length: inShape.length }, (_, axis) => inShape[perm[axis]]);
}
|
||
|
/**
 * Drops size-1 dimensions from a shape/permutation pair.
 *
 * Returns `[keptShape, keptPerm]`: `keptShape` is `shape` without its 1s;
 * `keptPerm` holds the perm entries whose source axis is not size 1,
 * renumbered to 0..k-1 while keeping their relative order.
 */
function removeOneSizeDims(shape, perm) {
    const keptShape = [];
    const keptPerm = [];
    for (let axis = 0; axis < shape.length; ++axis) {
        if (shape[axis] !== 1) {
            keptShape.push(shape[axis]);
        }
        const source = perm[axis];
        if (shape[source] !== 1) {
            keptPerm.push(source);
        }
    }
    // Repeatedly give the next rank to the smallest value not yet renumbered
    // (values already renumbered are all below `rank`); ties keep the first index.
    for (let rank = 0; rank < keptPerm.length; ++rank) {
        let best = -1;
        keptPerm.forEach((value, idx) => {
            if (value >= rank && (best === -1 || keptPerm[best] > value)) {
                best = idx;
            }
        });
        keptPerm[best] = rank;
    }
    return [keptShape, keptPerm];
}
|
||
|
/** Registers `transpose` (bound by `setup$1`) as the wasm Transpose kernel. */
const transposeConfig = {
    kernelName: tfjsCore.Transpose,
    backendName: 'wasm',
    kernelFunc: transpose,
    setupFunc: setup$1,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Compute permutation axes and do a transpose if necessary.
 *
 * Used by reduction ops. `backend_util.getAxesPermutation` yields either null
 * (no transpose needed) or a permutation; in the latter case `x` is
 * transposed with it and the reduction axes are replaced by
 * `backend_util.getInnerMostAxes(axes.length, rank)`.
 *
 * @param x input TensorInfo
 * @param axis reduction axes
 * @param backend wasm backend instance
 * @returns `{transposed, originalAxes, axes, inputWasTransposed}`:
 *     `transposed` is the transposed TensorInfo, or null when no permutation
 *     was needed; `originalAxes` are the parsed input axes; `axes` are the
 *     axes to reduce over; `inputWasTransposed` is true when `transposed` has
 *     a different backend id than `x`, so the caller must dispose of it
 *     (argmax does so).
 */
function permuteAxesAndTranspose(x, axis, backend) {
    const xRank = x.shape.length;
    const originalAxes = tfjsCore.util.parseAxisParam(axis, x.shape);
    let axes = originalAxes;
    const permutedAxes = tfjsCore.backend_util.getAxesPermutation(axes, xRank);
    let xTransposed = null;
    let inputWasTransposed = false;
    if (permutedAxes != null) {
        axes = tfjsCore.backend_util.getInnerMostAxes(axes.length, xRank);
        xTransposed =
            transpose({ inputs: { x }, attrs: { perm: permutedAxes }, backend });
        const xId = backend.dataIdMap.get(x.dataId).id;
        const transposedId = backend.dataIdMap.get(xTransposed.dataId).id;
        inputWasTransposed = transposedId !== xId;
    }
    return { transposed: xTransposed, originalAxes, axes, inputWasTransposed };
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$1;
/**
 * Binds the wasm export named by `tfjsCore.ArgMax` into `wasmFunc$1`.
 * All five arguments are numbers, in the order argmax passes them:
 * input id, CppDType code, outer size, inner size, out_id.
 */
function setup$2(backend) {
    const argTypes = new Array(5).fill('number');
    wasmFunc$1 = backend.wasm.cwrap(tfjsCore.ArgMax, null /* void */, argTypes);
}
|
||
|
/**
 * ArgMax kernel: int32 indices of the maximum along `attrs.axis`.
 *
 * The input is first run through `permuteAxesAndTranspose`; when that yields
 * a separate transposed tensor, the reduction reads from it and it is
 * disposed afterwards. The output drops the last dimension of the tensor
 * actually reduced.
 */
function argmax(args) {
    const { backend, inputs: { x }, attrs: { axis } } = args;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const { transposed, axes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
    let source = x;
    let sourceId = xId;
    if (inputWasTransposed) {
        const transposedId = backend.dataIdMap.get(transposed.dataId).id;
        // A distinct id means a real copy was made; reduce over it instead of x.
        if (transposedId !== xId) {
            source = transposed;
            sourceId = transposedId;
        }
    }
    const out = backend.makeOutput(source.shape.slice(0, -1), 'int32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    const outerSize = tfjsCore.util.sizeFromShape(out.shape);
    const innerSize = source.shape[axes[0]];
    wasmFunc$1(sourceId, CppDType[source.dtype], outerSize, innerSize, outId);
    if (inputWasTransposed) {
        // The transposed copy was only needed for this reduction.
        backend.disposeData(transposed.dataId);
    }
    return out;
}
|
||
|
/** Registers `argmax` (bound by `setup$2`) as the wasm ArgMax kernel. */
const argMaxConfig = {
    kernelName: tfjsCore.ArgMax,
    backendName: 'wasm',
    kernelFunc: argmax,
    setupFunc: setup$2,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmAvgPool;
/**
 * Binds the wasm export named by `tfjsCore.AvgPool` into `wasmAvgPool`.
 * All 14 arguments are numbers, in the order avgPool passes them: x id,
 * x.shape[0..2], filter height/width, pad top/right/bottom/left, stride
 * height/width, channels, out id.
 */
function setup$3(backend) {
    const argTypes = new Array(14).fill('number');
    wasmAvgPool = backend.wasm.cwrap(tfjsCore.AvgPool, null /* void */, argTypes);
}
|
||
|
/**
 * AvgPool kernel for the wasm backend.
 *
 * Pooling geometry (filter size, padding, strides, output shape) comes from
 * `backend_util.computePool2DInfo`, called with dilations fixed at 1. `x` is
 * passed to wasm as x.shape[0], x.shape[1], x.shape[2] plus
 * `convInfo.inChannels` — presumably [batch, height, width, channels], which
 * is why only channelsLast is accepted.
 *
 * @param args.inputs `x`, the tensor to pool.
 * @param args.attrs `filterSize`, `strides`, `pad`, `dimRoundingMode`.
 * @returns a float32 TensorInfo of shape `convInfo.outShape`.
 * @throws Error if the data format is not channelsLast or dilation is not [1, 1].
 */
function avgPool(args) {
    const { inputs, attrs, backend } = args;
    const x = inputs.x;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = tfjsCore.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    // Dilations are passed as 1 above, so this presumably always holds; it is
    // kept as a guard because the wasm kernel takes no dilation arguments.
    if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) {
        throw new Error(`wasm backend only supports average pooling with dilation = [1, 1], ` +
            `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`);
    }
    const { filterHeight, filterWidth, strideHeight, strideWidth } = convInfo;
    const channels = convInfo.inChannels;
    const { top: padTop, right: padRight, bottom: padBottom, left: padLeft } = convInfo.padInfo;
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels, outId);
    return out;
}
|
||
|
/** Registers `avgPool` (bound by `setup$3`) as the wasm AvgPool kernel. */
const avgPoolConfig = {
    kernelName: tfjsCore.AvgPool,
    backendName: 'wasm',
    setupFunc: setup$3,
    kernelFunc: avgPool,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Reshape kernel. No data is copied: the result shares `x`'s dataId and
 * dtype and carries the new shape, resolved through
 * `util.inferFromImplicitShape`.
 *
 * @throws (via util.assert) when the element counts differ.
 */
function reshape(args) {
    const { inputs: { x }, attrs: { shape } } = args;
    const { sizeFromShape, inferFromImplicitShape, assert } = tfjsCore.util;
    const elementCount = sizeFromShape(x.shape);
    const resolvedShape = inferFromImplicitShape(shape, elementCount);
    assert(elementCount === sizeFromShape(resolvedShape), () => `new shape: ${resolvedShape}, old shape: ${x.shape}. New shape and old ` +
        `shape must have the same number of elements.`);
    return { dataId: x.dataId, shape: resolvedShape, dtype: x.dtype };
}
|
||
|
/** Registers `reshape` as the wasm Reshape kernel (no wasm setup needed). */
const reshapeConfig = {
    kernelName: tfjsCore.Reshape,
    backendName: 'wasm',
    kernelFunc: reshape,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmBatchMatMul;
/** Binds the wasm export named by `tfjsCore.BatchMatMul` into `wasmBatchMatMul`. */
function setup$4(backend) {
    // Order matches the call in batchMatMul.
    const argTypes = [
        'number', 'array', 'number', // a3d id, a3d shape bytes, a3d rank
        'number', 'array', 'number', // b3d id, b3d shape bytes, b3d rank
        'number', 'number', // transposeA, transposeB
        'number', // out_id
    ];
    wasmBatchMatMul = backend.wasm.cwrap(tfjsCore.BatchMatMul, null /* void */, argTypes);
}
|
||
|
/**
 * BatchMatMul kernel for the wasm backend.
 *
 * All dimensions before the last two are flattened into a single batch
 * dimension, both operands are reshaped to rank 3, the wasm kernel runs on
 * those views, and the result is given the broadcast batch dims plus the
 * two matrix dims.
 *
 * @param args.inputs `a`, `b` — float32 tensors of rank >= 2.
 * @param args.attrs `transposeA`, `transposeB`.
 * @returns TensorInfo of shape [...batchDims, outerShapeA, outerShapeB].
 * @throws Error if either input is not float32; util.assert errors for
 *     incompatible batch dims, rank < 2, or mismatched inner dims.
 */
function batchMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b } = inputs;
    const { transposeA, transposeB } = attrs;
    if (a.dtype !== 'float32' || b.dtype !== 'float32') {
        throw new Error(`BatchMatMul for non-float32 tensors not yet supported.`);
    }
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = tfjsCore.util.sizeFromShape(outerDimsA);
    const batchDimB = tfjsCore.util.sizeFromShape(outerDimsB);
    // Batches broadcast only when equal or when one side has a single batch.
    const batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
    tfjsCore.util.assert(aRank >= 2 && bRank >= 2 && batchDimsCompatible, () => `Error in matMul: the input batch dimensions must either be the ` +
        `same or at least one input batch dimension must be 1. Got input ` +
        `batch dimensions of (${outerDimsA}) and (${outerDimsB}).`);
    const outShapeOuterDims = batchDimA > batchDimB ? a.shape.slice(0, -2) : b.shape.slice(0, -2);
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    tfjsCore.util.assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    // The rest of the implementation is designed to operate on rank-3 tensors
    const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const a3dId = backend.dataIdMap.get(a3d.dataId).id;
    const b3dId = backend.dataIdMap.get(b3d.dataId).id;
    const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
    const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
    const batchDim = Math.max(batchDimA, batchDimB);
    const out = backend.makeOutput([batchDim, leftDim, rightDim], a3d.dtype);
    const outId = backend.dataIdMap.get(out.dataId).id;
    const aShapeBytes = new Uint8Array(new Int32Array(a3d.shape).buffer);
    const bShapeBytes = new Uint8Array(new Int32Array(b3d.shape).buffer);
    wasmBatchMatMul(a3dId, aShapeBytes, a3d.shape.length, b3dId, bShapeBytes, b3d.shape.length, transposeA, transposeB, outId);
    // Restore the caller-facing (possibly higher-rank) output shape.
    out.shape = outShape;
    return out;
}
|
||
|
/**
 * Kernel config for BatchMatMul on the 'wasm' backend.
 * `setupFunc` binds the wasm export; `kernelFunc` performs the multiply.
 */
const batchMatMulConfig = {
  kernelName: tfjsCore.BatchMatMul,
  backendName: 'wasm',
  setupFunc: setup$4,
  kernelFunc: batchMatMul,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Cast kernel: copies `x` into a freshly allocated output of dtype `dtype`.
 *
 * Numeric targets go through TypedArray#set. That conversion truncates
 * floats toward zero when the destination is an Int32Array.
 *
 * A 'bool' target is handled element by element, so every non-zero input
 * (including NaN) becomes 1 and zero stays 0, i.e. `x != 0`. TypedArray#set
 * would instead apply modular integer conversion. If the bool buffer is a
 * Uint8Array (assumed; typedArrayFromHeap is defined elsewhere), 0.5 would
 * become 0 and -1 would become 255.
 *
 * @param {{inputs: {x: object}, attrs: {dtype: string}, backend: object}} args
 * @returns {object} output tensor info with the same shape as `x`
 */
function cast(args) {
  const { inputs: { x }, attrs: { dtype }, backend } = args;
  const out = backend.makeOutput(x.shape, dtype);
  const inVals = backend.typedArrayFromHeap(x);
  const outVals = backend.typedArrayFromHeap(out);
  if (dtype === 'bool') {
    for (let i = 0; i < inVals.length; i++) {
      outVals[i] = inVals[i] !== 0 ? 1 : 0;
    }
  } else {
    outVals.set(inVals);
  }
  return out;
}
|
||
|
/**
 * Kernel config for Cast on the 'wasm' backend.
 * There is no setupFunc: the cast kernel binds no wasm export. It copies
 * through backend.typedArrayFromHeap views instead.
 */
const castConfig = {
  kernelName: tfjsCore.Cast,
  backendName: 'wasm',
  kernelFunc: cast,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmClip;

/**
 * Binds the wasm ClipByValue export (void return).
 * It takes four numeric arguments: input id, min, max, and output id, in the
 * order used by the clip kernel's call.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$5(backend) {
  const argTypes = Array.from({ length: 4 }, () => 'number');
  wasmClip = backend.wasm.cwrap(tfjsCore.ClipByValue, null /* void */, argTypes);
}
|
||
|
/**
 * ClipByValue kernel for the wasm backend.
 *
 * Allocates an output with the same shape and dtype as `x`, then delegates
 * to the bound wasm ClipByValue function with the min/max bounds from attrs.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: {clipValueMin: number, clipValueMax: number}}} args
 * @returns {object} output tensor info
 */
function clip(args) {
  const { inputs: { x }, attrs: { clipValueMin, clipValueMax }, backend } = args;
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  const inputId = idOf(x);
  const result = backend.makeOutput(x.shape, x.dtype);
  wasmClip(inputId, clipValueMin, clipValueMax, idOf(result));
  return result;
}
|
||
|
/** Kernel config for ClipByValue on the 'wasm' backend (wasm export bound by setup$5). */
const clipByValueConfig = {
  kernelName: tfjsCore.ClipByValue,
  backendName: 'wasm',
  setupFunc: setup$5,
  kernelFunc: clip,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Concat kernel for the wasm backend, implemented in JS over heap views.
 *
 * Inputs with zero elements are ignored. If exactly one non-empty input
 * remains, it is returned as-is. Otherwise every batch (the product of the
 * dims before `axis`) is assembled by copying each input's contiguous inner
 * slab (the dims from `axis` on) one after another.
 *
 * The output buffer is allocated only once it is known to be needed. It is
 * allocated after the single-input shortcut and after the shape consistency
 * assertion, so neither path leaves an unused, never-freed allocation behind.
 *
 * @param {{inputs: object[], backend: object, attrs: {axis: number|number[]}}} args
 * @returns {object} output tensor info (possibly one of the inputs)
 */
function concat(args) {
  const { inputs, backend } = args;
  const axis = tfjsCore.util.parseAxisParam(args.attrs.axis, inputs[0].shape)[0];
  const outShape = tfjsCore.backend_util.computeOutShape(inputs.map(t => t.shape), axis);
  if (tfjsCore.util.sizeFromShape(outShape) === 0) {
    return backend.makeOutput(outShape, inputs[0].dtype);
  }
  // Keep only non-empty tensors (ignore tensors with 0 in their shape).
  const $inputs = inputs.filter(t => tfjsCore.util.sizeFromShape(t.shape) > 0);
  if ($inputs.length === 1) {
    // The lone non-empty input already is the result; allocating an output
    // here would leak it.
    return $inputs[0];
  }
  const shapes = $inputs.map(t => t.shape);
  tfjsCore.backend_util.assertParamsConsistent(shapes, axis);
  const out = backend.makeOutput(outShape, inputs[0].dtype);
  const batchDim = tfjsCore.util.sizeFromShape($inputs[0].shape.slice(0, axis));
  let sumInnerDims = 0;
  const innerDims = $inputs.map(input => {
    const innerDim = tfjsCore.util.sizeFromShape(input.shape.slice(axis));
    sumInnerDims += innerDim;
    return innerDim;
  });
  const inVals = $inputs.map(input => backend.typedArrayFromHeap(input));
  const outVals = backend.typedArrayFromHeap(out);
  for (let b = 0; b < batchDim; b++) {
    let outOffset = b * sumInnerDims;
    for (let i = 0; i < inVals.length; i++) {
      const innerDim = innerDims[i];
      const inOffset = b * innerDim;
      outVals.set(inVals[i].subarray(inOffset, inOffset + innerDim), outOffset);
      outOffset += innerDim;
    }
  }
  return out;
}
|
||
|
/** Kernel config for Concat on the 'wasm' backend (pure-JS kernel, no setupFunc). */
const concatConfig = {
  kernelName: tfjsCore.Concat,
  backendName: 'wasm',
  kernelFunc: concat,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmConv2d;

/**
 * Binds the wasm Conv2D export (void return).
 * All 19 parameters are plain numbers: ids, input dims, filter size, padding,
 * same-pad flag, dilations, strides, channel counts and output id, in the
 * order conv2d passes them.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$6(backend) {
  const argTypes = Array.from({ length: 19 }, () => 'number');
  wasmConv2d = backend.wasm.cwrap(tfjsCore.Conv2D, null /* void */, argTypes);
}
|
||
|
/**
 * Conv2D kernel for the wasm backend.
 *
 * Derives convolution geometry with backend_util.computeConv2DInfo and
 * rejects anything but 'channelsLast' layout. It then allocates a float32
 * output of convInfo.outShape and hands all geometry to the wasm Conv2D
 * function.
 *
 * @param {{inputs: {x: object, filter: object}, attrs: object, backend: object}} args
 * @returns {object} output tensor info
 * @throws {Error} when the resolved data format is not 'channelsLast'
 */
function conv2d(args) {
  const { inputs: { x, filter }, attrs, backend } = args;
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  const xId = idOf(x);
  const filterId = idOf(filter);
  const { strides, dilations, pad, dimRoundingMode, dataFormat } = attrs;
  const convInfo = tfjsCore.backend_util.computeConv2DInfo(
    x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false,
    tfjsCore.backend_util.convertConv2DDataFormat(dataFormat));
  if (convInfo.dataFormat !== 'channelsLast') {
    throw new Error(`wasm backend Conv2D does not support dataFormat:'${convInfo.dataFormat}'. Please use 'channelsLast'.`);
  }
  const {
    filterHeight, filterWidth, padInfo, dilationHeight, dilationWidth,
    strideHeight, strideWidth, inChannels, outChannels,
  } = convInfo;
  const sameFlag = padInfo.type === 'SAME' ? 1 : 0;
  const out = backend.makeOutput(convInfo.outShape, 'float32');
  wasmConv2d(
    xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth,
    padInfo.top, padInfo.right, padInfo.bottom, padInfo.left, sameFlag,
    dilationHeight, dilationWidth, strideHeight, strideWidth,
    inChannels, outChannels, idOf(out));
  return out;
}
|
||
|
/** Kernel config for Conv2D on the 'wasm' backend (wasm export bound by setup$6). */
const conv2DConfig = {
  kernelName: tfjsCore.Conv2D,
  backendName: 'wasm',
  setupFunc: setup$6,
  kernelFunc: conv2d,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmConv2DBackpropInput;

/**
 * Binds the wasm Conv2DBackpropInput export (void return).
 * All 27 parameters are plain numbers, in the order conv2DBackpropInput
 * passes them: ids, geometry, pads, filter strides, dx/dy layout strides and
 * the output id.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$7(backend) {
  const argTypes = Array.from({ length: 27 }, () => 'number');
  wasmConv2DBackpropInput = backend.wasm.cwrap(tfjsCore.Conv2DBackpropInput, null, argTypes);
}
|
||
|
/**
 * Conv2DBackpropInput kernel for the wasm backend.
 *
 * Computes convolution geometry for `inputShape` with dilations fixed to 1.
 * The top/left offsets passed to wasm are (filter extent - 1 - forward pad).
 * Element strides for dx (shape convInfo.inShape) and dy are passed in
 * [batch, row, col, channel] order for either channelsLast or channelsFirst
 * layout. The output is a float32 tensor of convInfo.inShape.
 *
 * @param {{backend: object, inputs: {dy: object, filter: object}, attrs: object}} args
 * @returns {object} output tensor info
 */
function conv2DBackpropInput(args) {
  const { backend, inputs: { dy, filter }, attrs } = args;
  const { strides, pad, dataFormat, dimRoundingMode, inputShape } = attrs;
  const convInfo = tfjsCore.backend_util.computeConv2DInfo(
    inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode,
    false /* depthwise */, tfjsCore.backend_util.convertConv2DDataFormat(dataFormat));
  const {
    batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth,
    outChannels, outHeight, outWidth, strideHeight, strideWidth, padInfo,
  } = convInfo;
  const topPad = filterHeight - 1 - padInfo.top;
  const leftPad = filterWidth - 1 - padInfo.left;
  const channelsLast = convInfo.dataFormat === 'channelsLast';
  // Reorders computeStrides output into [batch, row, col, channel] strides.
  const toLayoutStrides = (s) => (channelsLast ? [s[0], s[1], s[2], 1] : [s[0], s[2], 1, s[1]]);
  const dxLayout = toLayoutStrides(tfjsCore.util.computeStrides(convInfo.inShape));
  const dyLayout = toLayoutStrides(tfjsCore.util.computeStrides(dy.shape));
  const [fltS0, fltS1, fltS2] = tfjsCore.util.computeStrides(filter.shape);
  const out = backend.makeOutput(convInfo.inShape, 'float32');
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  wasmConv2DBackpropInput(
    idOf(dy), idOf(filter), batchSize, filterHeight, filterWidth,
    inHeight, inWidth, inChannels, outHeight, outWidth, outChannels,
    strideHeight, strideWidth, topPad, leftPad, fltS0, fltS1, fltS2,
    ...dxLayout, ...dyLayout, idOf(out));
  return out;
}
|
||
|
/** Kernel config for Conv2DBackpropInput on the 'wasm' backend (wasm export bound by setup$7). */
const conv2DBackpropInputConfig = {
  kernelName: tfjsCore.Conv2DBackpropInput,
  backendName: 'wasm',
  setupFunc: setup$7,
  kernelFunc: conv2DBackpropInput,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for Cos on the wasm backend, built by createUnaryKernelConfig (defined elsewhere in this bundle).
const cosConfig = createUnaryKernelConfig(tfjsCore.Cos);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Must match enum in CropAndResize.cc
// Two-way lookup table: method name -> numeric code and numeric code -> name.
var InterpolationMethod = {};
[['bilinear', 0], ['nearest', 1]].forEach(([name, code]) => {
  InterpolationMethod[name] = code;
  InterpolationMethod[code] = name;
});
|
||
|
let wasmCropAndResize;

/**
 * Binds the wasm CropAndResize export (void return).
 * The argument slots are annotated with the values cropAndResize passes.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$8(backend) {
  const argTypes = [
    'number', // images id
    'number', // boxes id
    'number', // boxInd id
    'number', // numBoxes
    'array', // image shape as Int32 bytes
    'number', // cropHeight
    'number', // cropWidth
    'number', // interpolation method code
    'number', // extrapolationValue
    'number', // out id
  ];
  wasmCropAndResize = backend.wasm.cwrap(tfjsCore.CropAndResize, null /*void*/, argTypes);
}
|
||
|
/**
 * CropAndResize kernel for the wasm backend.
 *
 * Produces a float32 tensor of shape [numBoxes, cropHeight, cropWidth,
 * image.shape[3]]. Only float32 image data is handed to wasm. Other dtypes
 * are first converted with `cast`, and that temporary is disposed after the
 * call.
 *
 * @param {{backend: object, inputs: {image: object, boxes: object, boxInd: object}, attrs: object}} args
 * @returns {object} output tensor info
 */
function cropAndResize(args) {
  const {
    backend,
    inputs: { image, boxes, boxInd },
    attrs: { method, extrapolationValue, cropSize },
  } = args;
  const numBoxes = boxes.shape[0];
  const [cropHeight, cropWidth] = cropSize;
  const outShape = [numBoxes, cropHeight, cropWidth, image.shape[3]];
  const tempFloat = image.dtype === 'float32'
    ? null
    : cast({ backend, inputs: { x: image }, attrs: { dtype: 'float32' } });
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  const imagesId = idOf(tempFloat === null ? image : tempFloat);
  const boxesId = idOf(boxes);
  const boxIndId = idOf(boxInd);
  const out = backend.makeOutput(outShape, 'float32');
  const shapeBytes = new Uint8Array(new Int32Array(image.shape).buffer);
  wasmCropAndResize(imagesId, boxesId, boxIndId, numBoxes, shapeBytes, cropHeight, cropWidth,
    InterpolationMethod[method], extrapolationValue, idOf(out));
  if (tempFloat !== null) {
    backend.disposeData(tempFloat.dataId);
  }
  return out;
}
|
||
|
/** Kernel config for CropAndResize on the 'wasm' backend (wasm export bound by setup$8). */
const cropAndResizeConfig = {
  kernelName: tfjsCore.CropAndResize,
  backendName: 'wasm',
  setupFunc: setup$8,
  kernelFunc: cropAndResize,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmCumsum;

/**
 * Binds the wasm Cumsum export (void return).
 * The argument slots are annotated with the values cumsum passes.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$9(backend) {
  const argTypes = [
    'number', // x id
    'number', // exclusive flag (0/1)
    'number', // reverse flag (0/1)
    'number', // size of the innermost dimension
    'number', // out id
    'number', // dtype
  ];
  wasmCumsum = backend.wasm.cwrap(tfjsCore.Cumsum, null /* void */, argTypes);
}
|
||
|
/**
 * Cumsum kernel for the wasm backend (float32 and int32 only).
 *
 * If getAxesPermutation reports that `axis` is not already innermost, the
 * input is transposed so it is. The wasm scan then runs along the innermost
 * dimension. When a transpose happened, the result is transposed back with
 * the undo permutation, and both permuted intermediates are disposed.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: {axis: number, exclusive: boolean, reverse: boolean}}} args
 * @returns {object} output tensor info
 */
function cumsum(args) {
  const { inputs: { x }, backend, attrs: { axis, exclusive, reverse } } = args;
  const rank = x.shape.length;
  tfjsCore.util.assert(x.dtype === 'float32' || x.dtype === 'int32', () => `cumsum does not support ${x.dtype} tensors in the WASM backend`);
  // permute required axis to inner most axis
  const perm = tfjsCore.backend_util.getAxesPermutation([axis], rank);
  const permuted = perm !== null;
  const xInner = permuted ? transpose({ inputs: { x }, attrs: { perm }, backend }) : x;
  const innerAxis = tfjsCore.backend_util.getInnerMostAxes(1, rank)[0];
  tfjsCore.backend_util.assertAxesAreInnerMostDims('cumsum', [innerAxis], rank);
  const scanned = backend.makeOutput(xInner.shape, xInner.dtype);
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  wasmCumsum(idOf(xInner), exclusive ? 1 : 0, reverse ? 1 : 0, xInner.shape[innerAxis], idOf(scanned), CppDType[x.dtype]);
  if (!permuted) {
    return scanned;
  }
  // transpose data back if permuted
  const undo = tfjsCore.backend_util.getUndoAxesPermutation(perm);
  const out = transpose({ inputs: { x: scanned }, attrs: { perm: undo }, backend });
  backend.disposeData(xInner.dataId);
  backend.disposeData(scanned.dataId);
  return out;
}
|
||
|
/** Kernel config for Cumsum on the 'wasm' backend (wasm export bound by setup$9). */
const cumsumConfig = {
  kernelName: tfjsCore.Cumsum,
  backendName: 'wasm',
  setupFunc: setup$9,
  kernelFunc: cumsum,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmDepthToSpace;

/**
 * Binds the wasm DepthToSpace export (void return).
 * The argument slots are annotated with the values depthToSpace passes.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$a(backend) {
  const argTypes = [
    'number', // x id
    'number', // blockSize
    'number', // channelsLast flag (0/1)
    'array', // x strides as Int32 bytes
    'number', // x rank - 1
    'array', // output shape as Int32 bytes
    'array', // output strides as Int32 bytes
    'number', // output rank
    'number', // out id
  ];
  wasmDepthToSpace = backend.wasm.cwrap(tfjsCore.DepthToSpace, null /*void*/, argTypes);
}
|
||
|
/**
 * DepthToSpace kernel for the wasm backend.
 *
 * Reads height/width/depth from `x.shape` according to dataFormat ('NHWC'
 * or otherwise NCHW ordering). The output grows height and width by
 * blockSize and divides depth by blockSize^2. The output is allocated as
 * float32, and shapes/strides are passed to wasm as Int32 byte arrays.
 *
 * @param {{backend: object, inputs: {x: object}, attrs: {blockSize: number, dataFormat: string}}} args
 * @returns {object} output tensor info
 */
function depthToSpace(args) {
  const { backend, inputs: { x }, attrs: { blockSize, dataFormat } } = args;
  tfjsCore.util.assert(blockSize > 1, () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);
  const isNHWC = dataFormat === 'NHWC';
  const [batchSize, d1, d2, d3] = x.shape;
  const [inputHeight, inputWidth, inputDepth] = isNHWC ? [d1, d2, d3] : [d2, d3, d1];
  const outputHeight = inputHeight * blockSize;
  const outputWidth = inputWidth * blockSize;
  const outputDepth = inputDepth / (blockSize * blockSize);
  const outputShape = isNHWC
    ? [batchSize, outputHeight, outputWidth, outputDepth]
    : [batchSize, outputDepth, outputHeight, outputWidth];
  const out = backend.makeOutput(outputShape, 'float32');
  const toBytes = (values) => new Uint8Array(new Int32Array(values).buffer);
  const xId = backend.dataIdMap.get(x.dataId).id;
  wasmDepthToSpace(
    xId, blockSize, isNHWC ? 1 : 0,
    toBytes(tfjsCore.util.computeStrides(x.shape)), x.shape.length - 1,
    toBytes(outputShape), toBytes(tfjsCore.util.computeStrides(outputShape)), outputShape.length,
    backend.dataIdMap.get(out.dataId).id);
  return out;
}
|
||
|
/** Kernel config for DepthToSpace on the 'wasm' backend (wasm export bound by setup$a). */
const depthToSpaceConfig = {
  kernelName: tfjsCore.DepthToSpace,
  backendName: 'wasm',
  setupFunc: setup$a,
  kernelFunc: depthToSpace,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmDepthwiseConv2d;

/**
 * Binds the wasm DepthwiseConv2dNative export (void return).
 * All 19 parameters are plain numbers, in the order depthwiseConv2d passes
 * them.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$b(backend) {
  const argTypes = Array.from({ length: 19 }, () => 'number');
  wasmDepthwiseConv2d = backend.wasm.cwrap(tfjsCore.DepthwiseConv2dNative, null /* void */, argTypes);
}
|
||
|
/**
 * DepthwiseConv2dNative kernel for the wasm backend.
 *
 * Dilations default to [1, 1] when absent. Geometry comes from
 * computeConv2DInfo in depthwise mode; only 'channelsLast' is accepted. The
 * output is float32 with convInfo.outShape.
 *
 * @param {{inputs: {x: object, filter: object}, attrs: object, backend: object}} args
 * @returns {object} output tensor info
 * @throws {Error} when the resolved data format is not 'channelsLast'
 */
function depthwiseConv2d(args) {
  const { inputs: { x, filter }, attrs, backend } = args;
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  const xId = idOf(x);
  const filterId = idOf(filter);
  const { strides, dilations, pad, dimRoundingMode } = attrs;
  const convInfo = tfjsCore.backend_util.computeConv2DInfo(
    x.shape, filter.shape, strides, dilations == null ? [1, 1] : dilations,
    pad, dimRoundingMode, true /* depthwise */);
  if (convInfo.dataFormat !== 'channelsLast') {
    throw new Error(`wasm backend DepthwiseConv2dNative does not support dataFormat:'${convInfo.dataFormat}'. Please use 'channelsLast'.`);
  }
  const {
    filterHeight, filterWidth, padInfo, dilationHeight, dilationWidth,
    strideHeight, strideWidth, inChannels, outChannels,
  } = convInfo;
  const sameFlag = padInfo.type === 'SAME' ? 1 : 0;
  const out = backend.makeOutput(convInfo.outShape, 'float32');
  wasmDepthwiseConv2d(
    xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth,
    padInfo.top, padInfo.right, padInfo.bottom, padInfo.left, sameFlag,
    dilationHeight, dilationWidth, strideHeight, strideWidth,
    inChannels, outChannels, idOf(out));
  return out;
}
|
||
|
/** Kernel config for DepthwiseConv2dNative on the 'wasm' backend (wasm export bound by setup$b). */
const depthwiseConv2dNativeConfig = {
  kernelName: tfjsCore.DepthwiseConv2dNative,
  backendName: 'wasm',
  setupFunc: setup$b,
  kernelFunc: depthwiseConv2d,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Div kernel config: createBinaryKernelConfig (defined elsewhere) is called
// with its supportsFullBroadcast argument set to true.
const supportsFullBroadcast$1 = true;
const divConfig = createBinaryKernelConfig(tfjsCore.Div, /* supportsFullBroadcast */ supportsFullBroadcast$1);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Equal kernel config: createBinaryKernelConfig (defined elsewhere) is called
// with supportsFullBroadcast = false and a third argument 'bool'
// (presumably the output dtype — confirm against createBinaryKernelConfig).
const supportsFullBroadcast$2 = false;
const equalConfig = createBinaryKernelConfig(tfjsCore.Equal, /* supportsFullBroadcast */ supportsFullBroadcast$2, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for Exp on the wasm backend, built by createUnaryKernelConfig (defined elsewhere in this bundle).
const expConfig = createUnaryKernelConfig(tfjsCore.Exp);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Fill kernel for the wasm backend: allocates an output of `shape`/`dtype`
 * and writes `value` into every element of its heap view.
 *
 * @param {{attrs: {shape: number[], value: number, dtype: string}, backend: object}} args
 * @returns {object} output tensor info
 */
function fill(args) {
  const { backend } = args;
  const { shape, value, dtype } = args.attrs;
  const result = backend.makeOutput(shape, dtype);
  backend.typedArrayFromHeap(result).fill(value);
  return result;
}
|
||
|
/** Kernel config for Fill on the 'wasm' backend (pure-JS kernel, no setupFunc). */
const fillConfig = {
  kernelName: tfjsCore.Fill,
  backendName: 'wasm',
  kernelFunc: fill,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFlipLeftRight;

/**
 * Binds the wasm FlipLeftRight export (void return).
 * It takes six numeric arguments: image id, batch, height, width, channels,
 * and out id, in the order flipLeftRight passes them.
 *
 * @param {object} backend wasm backend exposing `wasm.cwrap`
 */
function setup$c(backend) {
  const argTypes = Array.from({ length: 6 }, () => 'number');
  wasmFlipLeftRight = backend.wasm.cwrap(tfjsCore.FlipLeftRight, null /* void */, argTypes);
}
|
||
|
/**
 * FlipLeftRight kernel for the wasm backend.
 *
 * Allocates an output matching the image's shape and dtype, then passes the
 * four dimensions of `image.shape` (read as batch, height, width, channels)
 * to the wasm function.
 *
 * @param {{inputs: {image: object}, backend: object}} args
 * @returns {object} output tensor info
 */
function flipLeftRight(args) {
  const { inputs: { image }, backend } = args;
  const [batch, imageHeight, imageWidth, numChannels] = image.shape;
  const out = backend.makeOutput(image.shape, image.dtype);
  const { dataIdMap } = backend;
  wasmFlipLeftRight(
    dataIdMap.get(image.dataId).id, batch, imageHeight, imageWidth, numChannels,
    dataIdMap.get(out.dataId).id);
  return out;
}
|
||
|
// Kernel registration entry for FlipLeftRight on the wasm backend; setup$c
// binds the wasm export that flipLeftRight calls.
const flipLeftRightConfig = {
    kernelName: tfjsCore.FlipLeftRight,
    backendName: 'wasm',
    kernelFunc: flipLeftRight,
    setupFunc: setup$c
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for FloorDiv
// (false: presumably only the limited broadcasting path is supported).
const supportsFullBroadcast$3 = false;
const floorDivConfig =
    createBinaryKernelConfig(tfjsCore.FloorDiv, supportsFullBroadcast$3);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmBatchNorm;
/**
 * Binds the wasm FusedBatchNorm export. fusedBatchNorm calls it with seven
 * numeric arguments:
 * (xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId).
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$d(backend) {
    const argTypes = Array.from({ length: 7 }, () => 'number');
    wasmBatchNorm = backend.wasm.cwrap(tfjsCore.FusedBatchNorm, null /* void */, argTypes);
}
|
||
|
/**
 * FusedBatchNorm kernel for the wasm backend.
 *
 * `offset` and `scale` are optional; when absent, id 0 is passed to the wasm
 * export in their place. Zero-sized inputs skip the wasm call and return the
 * freshly allocated (empty) output.
 *
 * @param {{backend: Object, inputs: Object, attrs: {varianceEpsilon: number}}} args
 * @returns {Object} output tensor info with x's shape and dtype.
 */
function fusedBatchNorm(args) {
    const { backend, inputs, attrs } = args;
    const { x, mean, variance, offset, scale } = inputs;
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const optionalIdOf = (tensor) => (tensor == null ? 0 : idOf(tensor));

    const xId = idOf(x);
    const meanId = idOf(mean);
    const varianceId = idOf(variance);
    const offsetId = optionalIdOf(offset);
    const scaleId = optionalIdOf(scale);

    const out = backend.makeOutput(x.shape, x.dtype);
    if (tfjsCore.util.sizeFromShape(x.shape) !== 0) {
        wasmBatchNorm(xId, meanId, varianceId, offsetId, scaleId, attrs.varianceEpsilon, idOf(out));
    }
    return out;
}
|
||
|
// Kernel registration entry for FusedBatchNorm on the wasm backend; setup$d
// binds the wasm export used by fusedBatchNorm.
const fusedBatchNormConfig = {
    kernelName: tfjsCore.FusedBatchNorm,
    backendName: 'wasm',
    setupFunc: setup$d,
    kernelFunc: fusedBatchNorm
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFusedConv2d;
/**
 * Binds the wasm FusedConv2D export. fusedConv2d calls it with 22 numeric
 * arguments, in this order:
 *   xId, batchSize, inHeight, inWidth,
 *   filterId, filterHeight, filterWidth, biasId,
 *   padTop, padRight, padBottom, padLeft, isSamePad,
 *   dilationHeight, dilationWidth, strideHeight, strideWidth,
 *   inputChannels, outputChannels, fusedActivation,
 *   preluActivationWeightsId, outId.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$e(backend) {
    const argTypes = new Array(22).fill('number');
    wasmFusedConv2d = backend.wasm.cwrap(tfjsCore.FusedConv2D, null /* void */, argTypes);
}
|
||
|
/**
 * FusedConv2D kernel for the wasm backend: a conv2d with optional bias and a
 * fused activation (optionally with PReLU weights), executed as one wasm call.
 *
 * Checks run in this order, so the first failing one determines the error:
 * activation support, bias rank, bias length, dataFormat.
 *
 * @param {{inputs: Object, attrs: Object, backend: Object}} args
 * @returns {Object} float32 output tensor info of shape `convInfo.outShape`.
 * @throws {Error} for an activation missing from FusableActivation, a bias
 *     that is not rank 1 or whose length differs from the output channel
 *     count, or a dataFormat other than 'NHWC'.
 */
function fusedConv2d(args) {
    const { inputs, attrs, backend } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dataFormat, dimRoundingMode, activation } = attrs;
    const convInfo = tfjsCore.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode);

    const fusedActivation = FusableActivation[activation];
    if (fusedActivation == null) {
        throw new Error(`${activation} activation not yet supported for FusedConv2D ` +
            `in the wasm backend.`);
    }

    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(x);
    const filterId = idOf(filter);
    const {
        batchSize,
        inHeight,
        inWidth,
        inChannels: inputChannels,
        outChannels: outputChannels,
        filterHeight,
        filterWidth,
        dilationHeight,
        dilationWidth,
        strideHeight,
        strideWidth,
        padInfo,
    } = convInfo;

    // Bias is optional; id 0 tells the wasm side there is none.
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        const biasRank = biasData.shape.length;
        if (biasRank !== 1) {
            throw new Error(`FusedConv2D only supports rank-1 bias but got ` +
                `rank ${biasRank}.`);
        }
        if (biasData.shape[0] !== outputChannels) {
            throw new Error(`FusedConv2D bias shape (${biasData.shape}) does not ` +
                `match the number of output channels (${outputChannels})`);
        }
        biasId = biasData.id;
    }

    const isSamePad = padInfo.type === 'SAME' ? 1 : 0;
    if (dataFormat !== 'NHWC') {
        throw new Error(`wasm backend FusedConv2D does not support dataFormat:'` +
            `${dataFormat}'. Please use 'NHWC'.`);
    }

    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = idOf(out);
    const preluActivationWeightsId = preluActivationWeights == null ? 0 : idOf(preluActivationWeights);
    wasmFusedConv2d(xId, batchSize, inHeight, inWidth, filterId, filterHeight, filterWidth, biasId,
        padInfo.top, padInfo.right, padInfo.bottom, padInfo.left, isSamePad,
        dilationHeight, dilationWidth, strideHeight, strideWidth,
        inputChannels, outputChannels, fusedActivation, preluActivationWeightsId, outId);
    return out;
}
|
||
|
// Kernel registration entry for FusedConv2D on the wasm backend; setup$e
// binds the wasm export used by fusedConv2d.
const fusedConv2DConfig = {
    kernelName: tfjsCore.FusedConv2D,
    backendName: 'wasm',
    setupFunc: setup$e,
    kernelFunc: fusedConv2d
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFusedDepthwiseConv2d;
/**
 * Binds the wasm FusedDepthwiseConv2D export. fusedDepthwiseConv2d calls it
 * with 22 numeric arguments, in this order:
 *   xId, batchSize, inHeight, inWidth,
 *   filterId, filterHeight, filterWidth, biasId,
 *   padTop, padRight, padBottom, padLeft, isSamePad,
 *   dilationHeight, dilationWidth, strideHeight, strideWidth,
 *   inputChannels, outputChannels, fusedActivation,
 *   preluActivationWeightsId, outId.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$f(backend) {
    const argTypes = Array.from({ length: 22 }, () => 'number');
    wasmFusedDepthwiseConv2d =
        backend.wasm.cwrap(tfjsCore.FusedDepthwiseConv2D, null /* void */, argTypes);
}
|
||
|
/**
 * FusedDepthwiseConv2D kernel for the wasm backend: a depthwise conv2d with
 * optional bias and a fused activation (optionally with PReLU weights),
 * executed as one wasm call.
 *
 * Checks run in this order, so the first failing one determines the error:
 * activation support, bias rank, bias length, dataFormat.
 *
 * @param {{inputs: Object, attrs: Object, backend: Object}} args
 * @returns {Object} float32 output tensor info of shape `convInfo.outShape`.
 * @throws {Error} for an activation missing from FusableActivation, a bias
 *     that is not rank 1 or whose length differs from the output channel
 *     count, or a dataFormat other than 'NHWC'.
 */
function fusedDepthwiseConv2d(args) {
    const { inputs, attrs, backend } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dataFormat, dimRoundingMode, activation } = attrs;
    const convInfo = tfjsCore.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);

    const fusedActivation = FusableActivation[activation];
    if (fusedActivation == null) {
        throw new Error(`${activation} activation not yet supported for FusedDepthwiseConv2D ` +
            `in the wasm backend.`);
    }

    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(x);
    const filterId = idOf(filter);
    const {
        batchSize,
        inHeight,
        inWidth,
        inChannels: inputChannels,
        outChannels: outputChannels,
        filterHeight,
        filterWidth,
        dilationHeight,
        dilationWidth,
        strideHeight,
        strideWidth,
        padInfo,
    } = convInfo;

    // Bias is optional; id 0 tells the wasm side there is none.
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        const biasRank = biasData.shape.length;
        if (biasRank !== 1) {
            throw new Error(`FusedDepthwiseConv2D only supports rank-1 bias but got ` +
                `rank ${biasRank}.`);
        }
        if (biasData.shape[0] !== outputChannels) {
            throw new Error(`FusedDepthwiseConv2D bias shape (${biasData.shape}) does not ` +
                `match the number of output channels (${outputChannels})`);
        }
        biasId = biasData.id;
    }

    const isSamePad = padInfo.type === 'SAME' ? 1 : 0;
    if (dataFormat !== 'NHWC') {
        throw new Error(`wasm backend FusedDepthwiseConv2D does not support dataFormat:'` +
            `${dataFormat}'. Please use 'NHWC'.`);
    }

    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = idOf(out);
    const preluActivationWeightsId = preluActivationWeights == null ? 0 : idOf(preluActivationWeights);
    wasmFusedDepthwiseConv2d(xId, batchSize, inHeight, inWidth, filterId, filterHeight, filterWidth, biasId,
        padInfo.top, padInfo.right, padInfo.bottom, padInfo.left, isSamePad,
        dilationHeight, dilationWidth, strideHeight, strideWidth,
        inputChannels, outputChannels, fusedActivation, preluActivationWeightsId, outId);
    return out;
}
|
||
|
// Kernel registration entry for FusedDepthwiseConv2D on the wasm backend;
// setup$f binds the wasm export used by fusedDepthwiseConv2d.
const fusedDepthwiseConv2DConfig = {
    kernelName: tfjsCore.FusedDepthwiseConv2D,
    backendName: 'wasm',
    setupFunc: setup$f,
    kernelFunc: fusedDepthwiseConv2d
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmGatherNd;
/**
 * Binds the wasm GatherNd export. Argument types are listed next to the
 * value gatherNd passes in each position.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$g(backend) {
    const argTypes = [
        'number', // xId
        'number', // CppDType of params
        'number', // indicesId
        'number', // numSlices
        'number', // sliceRank
        'number', // sliceSize
        'array', // strides as Int32 bytes
        'number' // outId
    ];
    wasmGatherNd = backend.wasm.cwrap(tfjsCore.GatherNd, null /*void*/, argTypes);
}
|
||
|
/**
 * GatherNd kernel for the wasm backend.
 *
 * Shape bookkeeping comes from `gather_util.prepareAndValidate`. When it
 * reports zero slices, the empty output is returned without calling wasm.
 * The strides are sent to wasm as the raw bytes of an Int32Array.
 *
 * @param {{backend: Object, inputs: {params: Object, indices: Object}}} args
 * @returns {Object} output tensor info with params' dtype.
 */
function gatherNd(args) {
    const { backend, inputs: { params, indices } } = args;
    const [resultShape, numSlices, sliceSize, strides] = tfjsCore.gather_util.prepareAndValidate(params, indices);
    const out = backend.makeOutput(resultShape, params.dtype);
    if (numSlices === 0) {
        return out;
    }
    // The last indices dimension is the length of each index tuple.
    const sliceRank = indices.shape[indices.shape.length - 1];
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(params);
    const indicesId = idOf(indices);
    const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);
    wasmGatherNd(xId, CppDType[params.dtype], indicesId, numSlices, sliceRank, sliceSize, stridesBytes, idOf(out));
    return out;
}
|
||
|
// Kernel registration entry for GatherNd on the wasm backend; setup$g binds
// the wasm export used by gatherNd.
const gatherNdConfig = {
    kernelName: tfjsCore.GatherNd,
    backendName: 'wasm',
    setupFunc: setup$g,
    kernelFunc: gatherNd
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmGather;
/**
 * Binds the wasm 'Gather' export (looked up by the literal name 'Gather').
 * Argument types are listed next to the value gatherV2 passes in each
 * position.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$h(backend) {
    const argTypes = [
        'number', // xId
        'number', // CppDType of x
        'array', // x strides as Int32 bytes
        'number', // number of strides (rank - 1)
        'number', // indicesId
        'number', // axis
        'array', // output strides as Int32 bytes
        'number' // outId
    ];
    wasmGather = backend.wasm.cwrap('Gather', null /*void*/, argTypes);
}
|
||
|
/**
 * GatherV2 kernel for the wasm backend: gathers slices of `x` along `axis`
 * at the positions listed in `indices`.
 *
 * `attrs.axis` may be negative (counted from the end of x.shape). It is
 * normalized with `util.parseAxisParam` before it indexes the shape or reaches
 * the wasm kernel. Previously the raw value was used: `newShape[-1] = …` set a
 * non-index property, which left the output shape/strides wrong, and `-1` was
 * passed through to wasm.
 *
 * @param {{backend: Object, inputs: {x: Object, indices: Object}, attrs: {axis: number}}} args
 * @returns {Object} output tensor info whose shape is
 *     `collectGatherOpShapeInfo(x, indices, axis).outputShape` (presumably
 *     x.shape[:axis] ++ indices.shape ++ x.shape[axis+1:]). This also holds
 *     for empty `x`, which previously returned the flattened-indices shape.
 */
function gatherV2(args) {
    const { backend, inputs, attrs } = args;
    const { x, indices } = inputs;
    const { axis } = attrs;
    const parsedAxis = tfjsCore.util.parseAxisParam(axis, x.shape)[0];

    // The output buffer is laid out with all indices flattened into a single
    // dimension at parsedAxis; the indices dims are re-expanded via out.shape.
    const newShape = x.shape.slice();
    newShape[parsedAxis] = tfjsCore.util.sizeFromShape(indices.shape);
    const stridesSize = x.shape.length - 1;
    const out = backend.makeOutput(newShape, x.dtype);
    const shapeInfo = tfjsCore.backend_util.segment_util.collectGatherOpShapeInfo(x, indices, parsedAxis);

    if (tfjsCore.util.sizeFromShape(x.shape) === 0) {
        // Nothing to gather, but report the same final shape as the
        // non-empty path.
        out.shape = shapeInfo.outputShape;
        return out;
    }

    const xId = backend.dataIdMap.get(x.dataId).id;
    const indicesId = backend.dataIdMap.get(indices.dataId).id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    const xStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(x.shape)).buffer);
    const outStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(newShape)).buffer);
    wasmGather(xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, parsedAxis, outStridesBytes, outId);

    // Reshape from the flattened-indices layout to the final output shape.
    out.shape = shapeInfo.outputShape;
    return out;
}
|
||
|
// Kernel registration entry for GatherV2 on the wasm backend; setup$h binds
// the wasm 'Gather' export used by gatherV2.
const gatherV2Config = {
    kernelName: tfjsCore.GatherV2,
    backendName: 'wasm',
    setupFunc: setup$h,
    kernelFunc: gatherV2
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for Greater. The third
// argument 'bool' presumably selects the output dtype.
const supportsFullBroadcast$4 = false;
const greaterConfig =
    createBinaryKernelConfig(tfjsCore.Greater, supportsFullBroadcast$4, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for GreaterEqual. The
// third argument 'bool' presumably selects the output dtype.
const supportsFullBroadcast$5 = false;
const greaterEqualConfig =
    createBinaryKernelConfig(tfjsCore.GreaterEqual, supportsFullBroadcast$5, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for Less. The third
// argument 'bool' presumably selects the output dtype.
const supportsFullBroadcast$6 = false;
const lessConfig =
    createBinaryKernelConfig(tfjsCore.Less, supportsFullBroadcast$6, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for LessEqual. The third
// argument 'bool' presumably selects the output dtype.
const supportsFullBroadcast$7 = false;
const lessEqualConfig =
    createBinaryKernelConfig(tfjsCore.LessEqual, supportsFullBroadcast$7, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Log is registered through the shared unary-kernel config factory.
const logConfig =
    createUnaryKernelConfig(tfjsCore.Log);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for LogicalAnd. The third
// argument 'bool' presumably selects the output dtype.
const supportsFullBroadcast$8 = false;
const logicalAndConfig =
    createBinaryKernelConfig(tfjsCore.LogicalAnd, supportsFullBroadcast$8, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmMax;
/**
 * Binds the wasm Max export. `max` calls it as
 * wasmMax(inputId, reduceSize, outId): three numeric arguments.
 *
 * The type list previously was ['number, number, number'], a single string
 * entry that is not a valid cwrap type. It declared the wrong arity and, in
 * Emscripten's cwrap, prevented the all-'number' fast path that returns the
 * raw export. Every call then went through the generic ccall marshalling.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$i(backend) {
    wasmMax = backend.wasm.cwrap(tfjsCore.Max, null /*void*/, ['number', 'number', 'number']);
}
|
||
|
/**
 * Max reduction kernel for the wasm backend.
 *
 * `permuteAxesAndTranspose` (defined elsewhere) reports whether `x` had to be
 * transposed so the reduced axes are innermost. When it was, the transposed
 * tensor is reduced and then disposed here. The wasm call is skipped for an
 * empty input. With `keepDims`, the reduced axes are re-inserted into the
 * output shape via expandShapeToKeepDim.
 *
 * @param {{backend: Object, inputs: {x: Object}, attrs: {reductionIndices: number|number[], keepDims: boolean}}} args
 * @returns {Object} output tensor info with x's dtype.
 */
function max(args) {
    const { backend, inputs: { x }, attrs: { reductionIndices: axis, keepDims } } = args;
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(x);

    const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
    const input = inputWasTransposed ? transposed : x;
    const inputId = inputWasTransposed ? idOf(transposed) : xId;

    tfjsCore.backend_util.assertAxesAreInnerMostDims('max', axes, input.shape.length);
    const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, axes);
    const reduceSize = tfjsCore.util.sizeFromShape(reduceShape);

    const out = backend.makeOutput(outShape, x.dtype);
    if (tfjsCore.util.sizeFromShape(input.shape) !== 0) {
        wasmMax(inputId, reduceSize, idOf(out));
    }
    if (inputWasTransposed) {
        // The transposed copy was only needed for the reduction.
        backend.disposeData(transposed.dataId);
    }
    if (keepDims) {
        out.shape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, originalAxes);
    }
    return out;
}
|
||
|
// Kernel registration entry for Max on the wasm backend; setup$i binds the
// wasm export used by max.
const maxConfig = {
    kernelName: tfjsCore.Max,
    backendName: 'wasm',
    setupFunc: setup$i,
    kernelFunc: max
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Broadcast flag handed to createBinaryKernelConfig for Maximum.
const supportsFullBroadcast$9 = false;
const maximumConfig =
    createBinaryKernelConfig(tfjsCore.Maximum, supportsFullBroadcast$9);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmMaxPool;
/**
 * Binds the wasm MaxPool export. maxPool calls it with 17 numeric arguments:
 *   xId, batch, inHeight, inWidth, filterHeight, filterWidth,
 *   padTop, padRight, padBottom, padLeft,
 *   dilationHeight, dilationWidth, strideHeight, strideWidth,
 *   inputChannels, outputChannels, outId.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$j(backend) {
    const argTypes = Array.from({ length: 17 }, () => 'number');
    wasmMaxPool = backend.wasm.cwrap(tfjsCore.MaxPool, null /* void */, argTypes);
}
|
||
|
/**
 * MaxPool kernel for the wasm backend.
 *
 * Pooling geometry comes from `backend_util.computePool2DInfo` with a fixed
 * dilation of 1. Only 'channelsLast' data is accepted. The first three
 * entries of x.shape are passed to wasm as batch, height and width.
 *
 * @param {{inputs: {x: Object}, attrs: Object, backend: Object}} args
 * @returns {Object} float32 output tensor info of shape `convInfo.outShape`.
 * @throws {Error} when the computed dataFormat is not 'channelsLast'.
 */
function maxPool(args) {
    const { inputs, attrs, backend } = args;
    const { x } = inputs;
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(x);
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = tfjsCore.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);

    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }

    const {
        filterHeight,
        filterWidth,
        padInfo: { top, right, bottom, left },
        dilationHeight,
        dilationWidth,
        strideHeight,
        strideWidth,
        inChannels,
        outChannels,
    } = convInfo;
    const [batch, inHeight, inWidth] = x.shape;

    const out = backend.makeOutput(convInfo.outShape, 'float32');
    wasmMaxPool(xId, batch, inHeight, inWidth, filterHeight, filterWidth,
        top, right, bottom, left, dilationHeight, dilationWidth,
        strideHeight, strideWidth, inChannels, outChannels, idOf(out));
    return out;
}
|
||
|
// Kernel registration entry for MaxPool on the wasm backend; setup$j binds
// the wasm export used by maxPool.
const maxPoolConfig = {
    kernelName: tfjsCore.MaxPool,
    backendName: 'wasm',
    setupFunc: setup$j,
    kernelFunc: maxPool
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmMin;
/**
 * Binds the wasm Min export. `min` calls it as
 * wasmMin(inputId, reduceSize, outId): three numeric arguments.
 *
 * The type list previously was ['number, number, number'], a single string
 * entry that is not a valid cwrap type. It declared the wrong arity and, in
 * Emscripten's cwrap, prevented the all-'number' fast path that returns the
 * raw export. Every call then went through the generic ccall marshalling.
 *
 * @param {Object} backend - backend whose `wasm` member provides `cwrap`.
 */
function setup$k(backend) {
    wasmMin = backend.wasm.cwrap(tfjsCore.Min, null /*void*/, ['number', 'number', 'number']);
}
|
||
|
/**
 * Min reduction kernel for the wasm backend.
 *
 * `permuteAxesAndTranspose` (defined elsewhere) reports whether `x` was
 * transposed so the reduced axes are innermost. The transposed tensor is only
 * reduced when its backend id differs from x's. If the ids match, `x` itself is
 * reduced. Whenever a transpose was reported, its data is disposed afterwards.
 * The wasm call is skipped for an empty input. With `keepDims`, the reduced
 * axes are re-inserted via expandShapeToKeepDim.
 *
 * @param {{backend: Object, inputs: {x: Object}, attrs: {axis: number|number[], keepDims: boolean}}} args
 * @returns {Object} output tensor info with the reduced input's dtype.
 */
function min(args) {
    const { backend, inputs: { x }, attrs: { axis, keepDims } } = args;
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const xId = idOf(x);

    const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
    const transposedId = inputWasTransposed ? idOf(transposed) : xId;
    // A transpose that kept x's id is treated as a no-op: reduce x directly.
    const useTransposed = inputWasTransposed && transposedId !== xId;
    const input = useTransposed ? transposed : x;
    const inputId = useTransposed ? transposedId : xId;

    tfjsCore.backend_util.assertAxesAreInnerMostDims('min', axes, input.shape.length);
    const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, axes);
    const reduceSize = tfjsCore.util.sizeFromShape(reduceShape);

    const out = backend.makeOutput(outShape, input.dtype);
    if (tfjsCore.util.sizeFromShape(input.shape) !== 0) {
        wasmMin(inputId, reduceSize, idOf(out));
    }
    if (inputWasTransposed) {
        backend.disposeData(transposed.dataId);
    }
    if (keepDims) {
        out.shape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, originalAxes);
    }
    return out;
}
|
||
|
// Kernel registration entry for Min on the wasm backend; setup$k binds the
// wasm export used by min.
const minConfig = {
    kernelName: tfjsCore.Min,
    backendName: 'wasm',
    setupFunc: setup$k,
    kernelFunc: min
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Minimum is built by the shared binary-kernel helper with
// `supportsFullBroadcast` set to false.
const supportsFullBroadcast$a = false;
const minimumConfig =
  createBinaryKernelConfig(tfjsCore.Minimum, supportsFullBroadcast$a);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Multiply is built by the shared binary-kernel helper with
// `supportsFullBroadcast` set to true.
const supportsFullBroadcast$b = true;
const multiplyConfig =
  createBinaryKernelConfig(tfjsCore.Multiply, supportsFullBroadcast$b);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Negate is registered through the shared unary-kernel helper.
const negateConfig = createUnaryKernelConfig(
  tfjsCore.Negate
);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Decode the `Result` struct returned by the C++ non-max-suppression kernels.
 *
 * The struct is read as four consecutive int32 values starting at `resOffset`:
 * selected-indices pointer, selected count, selected-scores pointer, and
 * valid-outputs pointer. The struct itself is freed here. The buffers its
 * pointers refer to are left for the caller to consume or free.
 *
 * @param {Object} backend WASM backend exposing `wasm.HEAPU8` and `wasm._free`.
 * @param {number} resOffset Heap byte offset of the struct.
 * @returns {{pSelectedIndices: number, selectedSize: number,
 *     pSelectedScores: number, pValidOutputs: number}}
 */
function parseResultStruct(backend, resOffset) {
  const [pSelectedIndices, selectedSize, pSelectedScores, pValidOutputs] =
    new Int32Array(backend.wasm.HEAPU8.buffer, resOffset, 4);
  // The fields have been copied out above, so the heap struct can be released.
  backend.wasm._free(resOffset);
  return { pSelectedIndices, selectedSize, pSelectedScores, pValidOutputs };
}
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$2;
/**
 * Bind the C++ NonMaxSuppressionV3 entry point, which returns a `Result*`
 * heap offset.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$l(backend) {
  const argTypes = [
    'number', // boxesId
    'number', // scoresId
    'number', // maxOutputSize
    'number', // iouThreshold
    'number', // scoreThreshold
  ];
  wasmFunc$2 = backend.wasm.cwrap(tfjsCore.NonMaxSuppressionV3, 'number', argTypes);
}
|
||
|
/**
 * NonMaxSuppressionV3 on the WASM backend.
 *
 * Runs the C++ kernel and wraps the selected box indices in an int32 tensor of
 * shape [selectedSize]. V3 returns neither scores nor a valid-output count, so
 * both of those heap buffers are released here.
 *
 * @param {{backend: Object, inputs: {boxes: Object, scores: Object},
 *     attrs: {iouThreshold: number, maxOutputSize: number, scoreThreshold: number}}} args
 * @returns {Object} int32 tensor info holding the selected indices.
 */
function kernelFunc(args) {
  const {
    backend,
    inputs: { boxes, scores },
    attrs: { iouThreshold, maxOutputSize, scoreThreshold },
  } = args;
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const resOffset = wasmFunc$2(idOf(boxes), idOf(scores), maxOutputSize, iouThreshold, scoreThreshold);
  const result = parseResultStruct(backend, resOffset);
  backend.wasm._free(result.pSelectedScores);
  backend.wasm._free(result.pValidOutputs);
  return backend.makeOutput([result.selectedSize], 'int32', result.pSelectedIndices);
}
|
||
|
/** WASM registration for NonMaxSuppressionV3. */
const nonMaxSuppressionV3Config = {
  kernelName: tfjsCore.NonMaxSuppressionV3,
  backendName: 'wasm',
  setupFunc: setup$l,
  kernelFunc,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$3;
/**
 * Bind the C++ NonMaxSuppressionV4 entry point, which returns a `Result*`
 * heap offset.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$m(backend) {
  const argTypes = [
    'number', // boxesId
    'number', // scoresId
    'number', // maxOutputSize
    'number', // iouThreshold
    'number', // scoreThreshold
    'bool', // padToMaxOutputSize
  ];
  wasmFunc$3 = backend.wasm.cwrap(tfjsCore.NonMaxSuppressionV4, 'number', argTypes);
}
|
||
|
/**
 * NonMaxSuppressionV4 on the WASM backend.
 *
 * Returns [selectedIndices (int32, shape [selectedSize]), validOutputs
 * (int32 scalar)]. The selected-scores buffer is not part of V4's output and
 * is freed.
 *
 * @param {{backend: Object, inputs: {boxes: Object, scores: Object}, attrs: Object}} args
 * @returns {Object[]} the two output tensor infos.
 */
function nonMaxSuppressionV4(args) {
  const { backend, inputs: { boxes, scores }, attrs } = args;
  const { iouThreshold, maxOutputSize, scoreThreshold, padToMaxOutputSize } = attrs;
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const resOffset = wasmFunc$3(
    idOf(boxes), idOf(scores), maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
  const { pSelectedIndices, selectedSize, pSelectedScores, pValidOutputs } =
    parseResultStruct(backend, resOffset);
  backend.wasm._free(pSelectedScores);
  return [
    backend.makeOutput([selectedSize], 'int32', pSelectedIndices),
    backend.makeOutput([], 'int32', pValidOutputs),
  ];
}
|
||
|
/** WASM registration for NonMaxSuppressionV4. */
const nonMaxSuppressionV4Config = { kernelName: tfjsCore.NonMaxSuppressionV4, backendName: 'wasm', setupFunc: setup$m, kernelFunc: nonMaxSuppressionV4 };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$4;
/**
 * Bind the C++ NonMaxSuppressionV5 entry point, which returns a `Result*`
 * heap offset.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$n(backend) {
  const argTypes = [
    'number', // boxesId
    'number', // scoresId
    'number', // maxOutputSize
    'number', // iouThreshold
    'number', // scoreThreshold
    'number', // softNmsSigma
  ];
  wasmFunc$4 = backend.wasm.cwrap(tfjsCore.NonMaxSuppressionV5, 'number', argTypes);
}
|
||
|
/**
 * NonMaxSuppressionV5 on the WASM backend.
 *
 * Returns [selectedIndices (int32), selectedScores (float32)], both of shape
 * [selectedSize]. V5 does not expose the valid-output count, so that heap
 * buffer is freed.
 *
 * @param {{backend: Object, inputs: {boxes: Object, scores: Object}, attrs: Object}} args
 * @returns {Object[]} the two output tensor infos.
 */
function kernelFunc$1(args) {
  const { backend, inputs: { boxes, scores }, attrs } = args;
  const { iouThreshold, maxOutputSize, scoreThreshold, softNmsSigma } = attrs;
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const resOffset = wasmFunc$4(
    idOf(boxes), idOf(scores), maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
  const { pSelectedIndices, selectedSize, pSelectedScores, pValidOutputs } =
    parseResultStruct(backend, resOffset);
  backend.wasm._free(pValidOutputs);
  const shape = [selectedSize];
  return [
    backend.makeOutput(shape, 'int32', pSelectedIndices),
    backend.makeOutput([selectedSize], 'float32', pSelectedScores),
  ];
}
|
||
|
/** WASM registration for NonMaxSuppressionV5. */
const nonMaxSuppressionV5Config = { kernelName: tfjsCore.NonMaxSuppressionV5, backendName: 'wasm', setupFunc: setup$n, kernelFunc: kernelFunc$1 };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// NotEqual is built by the shared binary-kernel helper with
// `supportsFullBroadcast` set to false. The third argument 'bool' is
// presumably the output dtype (helper defined elsewhere; confirm there).
const supportsFullBroadcast$c = false;
const notEqualConfig =
  createBinaryKernelConfig(tfjsCore.NotEqual, supportsFullBroadcast$c, 'bool');
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmOneHot;
/**
 * Bind the C++ OneHot entry point (no return value).
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$o(backend) {
  const argTypes = [
    'number', // indicesId
    'number', // depth
    'number', // onValue
    'number', // offValue
    'number', // outId
  ];
  wasmOneHot = backend.wasm.cwrap(tfjsCore.OneHot, null /* void */, argTypes);
}
|
||
|
/**
 * OneHot on the WASM backend.
 *
 * Allocates an int32 output of shape [...indices.shape, depth] and lets the
 * C++ kernel fill it from `indices`, `onValue` and `offValue`.
 *
 * @param {{inputs: {indices: Object}, backend: Object,
 *     attrs: {depth: number, onValue: number, offValue: number}}} args
 * @returns {Object} the int32 output tensor info.
 */
function oneHot(args) {
  const { inputs: { indices }, backend, attrs: { depth, onValue, offValue } } = args;
  const out = backend.makeOutput([...indices.shape, depth], 'int32');
  const lookupId = (dataId) => backend.dataIdMap.get(dataId).id;
  wasmOneHot(lookupId(indices.dataId), depth, onValue, offValue, lookupId(out.dataId));
  return out;
}
|
||
|
/** WASM registration for OneHot. */
const oneHotConfig = { kernelName: tfjsCore.OneHot, backendName: 'wasm', setupFunc: setup$o, kernelFunc: oneHot };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * OnesLike on the WASM backend: allocate an output with `x`'s shape and dtype
 * and set every element to 1 directly through the heap view.
 *
 * @param {{inputs: {x: Object}, backend: Object}} args
 * @returns {Object} the output tensor info.
 */
function onesLike(args) {
  const { backend } = args;
  const { x } = args.inputs;
  const out = backend.makeOutput(x.shape, x.dtype);
  backend.typedArrayFromHeap(out).fill(1);
  return out;
}
|
||
|
/** WASM registration for OnesLike (pure JS, so no setup function). */
const onesLikeConfig = { kernelName: tfjsCore.OnesLike, backendName: 'wasm', kernelFunc: onesLike };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmPadV2;
/**
 * Bind the C++ PadV2 entry point (no return value). The 'array' arguments
 * receive byte views of int32 arrays.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$p(backend) {
  const argTypes = [
    'number', // xId
    'array', // x shape (int32 bytes)
    'number', // x rank
    'number', // CppDType of x
    'array', // pre-paddings (int32 bytes)
    'array', // post-paddings (int32 bytes)
    'number', // constantValue
    'number', // outId
  ];
  wasmPadV2 = backend.wasm.cwrap(tfjsCore.PadV2, null /* void */, argTypes);
}
|
||
|
/**
 * PadV2 on the WASM backend.
 *
 * Output dimension i is paddings[i][0] + x.shape[i] + paddings[i][1]; the C++
 * kernel fills it using `constantValue` for the padded region.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {paddings: number[][], constantValue: number}}} args
 * @returns {Object} the padded output tensor info.
 */
function pad(args) {
  const { inputs: { x }, backend, attrs: { paddings, constantValue } } = args;
  const toBytes = (values) => new Uint8Array(new Int32Array(values).buffer);
  const outShape = paddings.map(([lo, hi], i) => lo + x.shape[i] + hi);
  const xId = backend.dataIdMap.get(x.dataId).id;
  const out = backend.makeOutput(outShape, x.dtype);
  const outId = backend.dataIdMap.get(out.dataId).id;
  const preBytes = toBytes(paddings.map(([lo]) => lo));
  const postBytes = toBytes(paddings.map(([, hi]) => hi));
  wasmPadV2(xId, toBytes(x.shape), x.shape.length, CppDType[x.dtype], preBytes, postBytes, constantValue, outId);
  return out;
}
|
||
|
/** WASM registration for PadV2. */
const padV2Config = { kernelName: tfjsCore.PadV2, backendName: 'wasm', kernelFunc: pad, setupFunc: setup$p };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Pow is built by the shared binary-kernel helper with
// `supportsFullBroadcast` set to false.
const supportsFullBroadcast$d = false;
const powConfig =
  createBinaryKernelConfig(tfjsCore.Pow, supportsFullBroadcast$d);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmPrelu;
/**
 * Bind the C++ Prelu entry point (no return value).
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$q(backend) {
  const argTypes = [
    'number', // xId
    'number', // alpha (weights) id
    'number', // outId
  ];
  wasmPrelu = backend.wasm.cwrap(tfjsCore.Prelu, null /* void */, argTypes);
}
|
||
|
/**
 * Prelu on the WASM backend. The output has `x`'s shape and is always float32.
 *
 * @param {{inputs: {x: Object, alpha: Object}, backend: Object}} args
 * @returns {Object} the float32 output tensor info.
 */
function prelu(args) {
  const { inputs: { x, alpha }, backend } = args;
  const idFor = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const xId = idFor(x);
  const weightsId = idFor(alpha);
  const out = backend.makeOutput(x.shape, 'float32');
  wasmPrelu(xId, weightsId, idFor(out));
  return out;
}
|
||
|
/** WASM registration for Prelu. */
const preluConfig = { kernelName: tfjsCore.Prelu, backendName: 'wasm', setupFunc: setup$q, kernelFunc: prelu };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Relu is registered through the shared unary-kernel helper.
const reluConfig = createUnaryKernelConfig(
  tfjsCore.Relu
);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Relu6 is registered through the shared unary-kernel helper.
const relu6Config = createUnaryKernelConfig(
  tfjsCore.Relu6
);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmResizeBilinear;
/**
 * Bind the C++ ResizeBilinear entry point (no return value).
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$r(backend) {
  const argTypes = [
    'number', // xId
    'number', // batch
    'number', // oldHeight
    'number', // oldWidth
    'number', // numChannels
    'number', // newHeight
    'number', // newWidth
    'number', // alignCorners (1 or 0)
    'number', // outId
  ];
  wasmResizeBilinear = backend.wasm.cwrap(tfjsCore.ResizeBilinear, null /*void*/, argTypes);
}
|
||
|
/**
 * ResizeBilinear on the WASM backend.
 *
 * Resizes `images` (destructured as [batch, height, width, channels]) to
 * `attrs.size` = [newHeight, newWidth]. The output is always float32. If the
 * input's backing data is not float32, a temporary float32 copy is made via
 * `cast` and released before returning.
 *
 * @param {{backend: Object, inputs: {images: Object},
 *     attrs: {alignCorners: boolean, size: number[]}}} args
 * @returns {Object} float32 output tensor info of shape
 *     [batch, newHeight, newWidth, numChannels].
 */
function resizeBilinear(args) {
  const { backend, inputs, attrs } = args;
  const { images } = inputs;
  const { alignCorners, size } = attrs;
  const [newHeight, newWidth] = size;
  const [batch, oldHeight, oldWidth, numChannels] = images.shape;
  const outShape = [batch, newHeight, newWidth, numChannels];
  const out = backend.makeOutput(outShape, 'float32');
  // An empty input has nothing to resample. Returning before the cast avoids
  // allocating a float32 temporary that the early return would never dispose.
  if (tfjsCore.util.sizeFromShape(images.shape) === 0) {
    return out;
  }
  let xData = backend.dataIdMap.get(images.dataId);
  let castedData = null;
  if (xData.dtype !== 'float32') {
    castedData = cast({ backend, inputs: { x: images }, attrs: { dtype: 'float32' } });
    xData = backend.dataIdMap.get(castedData.dataId);
  }
  try {
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmResizeBilinear(xData.id, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth, alignCorners ? 1 : 0, outId);
  } finally {
    // The float32 copy is only kernel input; release it on every exit path.
    if (castedData !== null) {
      backend.disposeData(castedData.dataId);
    }
  }
  return out;
}
|
||
|
/** WASM registration for ResizeBilinear. */
const resizeBilinearConfig = { kernelName: tfjsCore.ResizeBilinear, backendName: 'wasm', setupFunc: setup$r, kernelFunc: resizeBilinear };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmReverse;
/**
 * Bind the C++ Reverse entry point (no return value). The 'array' arguments
 * receive byte views of int32 arrays.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$s(backend) {
  const argTypes = [
    'number', // xId
    'array', // axes (int32 bytes)
    'number', // number of axes
    'array', // shape (int32 bytes)
    'number', // rank
    'number', // out_id
  ];
  wasmReverse = backend.wasm.cwrap(tfjsCore.Reverse, null, argTypes);
}
|
||
|
/**
 * Reverse on the WASM backend.
 *
 * `dims` is normalized with `parseAxisParam` first (so invalid axes are still
 * rejected for scalars). A rank-0 input is returned through `identity`.
 * Otherwise the C++ kernel writes into a same-shape output, which is then
 * passed through `reshape` with `x.shape`.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {dims: number|number[]}}} args
 * @returns {Object} the output tensor info.
 */
function reverse(args) {
  const { inputs: { x }, backend, attrs: { dims } } = args;
  const axes = tfjsCore.util.parseAxisParam(dims, x.shape);
  const rank = x.shape.length;
  if (rank === 0) {
    return identity({ inputs: { x }, backend });
  }
  const out = backend.makeOutput(x.shape, x.dtype);
  const asBytes = (values) => new Uint8Array(new Int32Array(values).buffer);
  const xId = backend.dataIdMap.get(x.dataId).id;
  const outId = backend.dataIdMap.get(out.dataId).id;
  wasmReverse(xId, asBytes(axes), axes.length, asBytes(x.shape), rank, outId);
  return reshape({ inputs: { x: out }, attrs: { shape: x.shape }, backend });
}
|
||
|
/** WASM registration for Reverse. */
const reverseConfig = { kernelName: tfjsCore.Reverse, backendName: 'wasm', kernelFunc: reverse, setupFunc: setup$s };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmRotate;
/**
 * Bind the C++ RotateWithOffset entry point (no return value). The 'array'
 * argument receives a byte view of the int32 fill values.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$t(backend) {
  const argTypes = [
    'number', // imageId
    'number', // batch
    'number', // imageHeight
    'number', // imageWidth
    'number', // numChannels
    'number', // radians
    'number', // centerX
    'number', // centerY
    'array', // fill values (int32 bytes)
    'number', // number of fill values
    'number', // outId
  ];
  wasmRotate = backend.wasm.cwrap(tfjsCore.RotateWithOffset, null /* void */, argTypes);
}
|
||
|
/**
 * RotateWithOffset on the WASM backend.
 *
 * `image.shape` is destructured as [batch, height, width, channels]. The fill
 * colour is sent as four-ish int32 values:
 * - a numeric `fillValue` becomes [v, v, v, alpha], where alpha is 0 for a
 *   fill of 0 and 255 otherwise;
 * - an array `fillValue` gets 255 appended.
 *
 * @param {{inputs: {image: Object}, backend: Object,
 *     attrs: {radians: number, fillValue: number|number[], center: *}}} args
 * @returns {Object} output tensor info with the image's shape and dtype.
 */
function rotateWithOffset(args) {
  const { inputs: { image }, backend, attrs: { radians, fillValue, center } } = args;
  const out = backend.makeOutput(image.shape, image.dtype);
  const [batch, imageHeight, imageWidth, numChannels] = image.shape;
  const [centerX, centerY] = tfjsCore.backend_util.getImageCenter(center, imageHeight, imageWidth);
  const fullOpacityValue = 255;
  let fillValues;
  if (typeof fillValue === 'number') {
    const alpha = fillValue === 0 ? 0 : fullOpacityValue;
    fillValues = [fillValue, fillValue, fillValue, alpha];
  } else {
    fillValues = [...fillValue, fullOpacityValue];
  }
  const fillBytes = new Uint8Array(new Int32Array(fillValues).buffer);
  const imageId = backend.dataIdMap.get(image.dataId).id;
  const outId = backend.dataIdMap.get(out.dataId).id;
  wasmRotate(imageId, batch, imageHeight, imageWidth, numChannels, radians, centerX, centerY, fillBytes, fillValues.length, outId);
  return out;
}
|
||
|
/** WASM registration for RotateWithOffset. */
const rotateWithOffsetConfig = { kernelName: tfjsCore.RotateWithOffset, backendName: 'wasm', kernelFunc: rotateWithOffset, setupFunc: setup$t };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Rsqrt is registered through the shared unary-kernel helper.
const rsqrtConfig = createUnaryKernelConfig(
  tfjsCore.Rsqrt
);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmScatterNd;
/**
 * Bind the C++ ScatterNd entry point (no return value). The 'array' argument
 * receives a byte view of the int32 strides.
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$u(backend) {
  const argTypes = [
    'number', // indicesId
    'number', // updatesId
    'number', // CppDType of updates
    'number', // sliceRank
    'number', // numUpdates
    'number', // sliceSize
    'array', // strides (int32 bytes)
    'number', // outputSize
    'number', // outId
  ];
  wasmScatterNd = backend.wasm.cwrap(tfjsCore.ScatterNd, null /*void*/, argTypes);
}
|
||
|
/**
 * ScatterNd on the WASM backend.
 *
 * Allocates an output of `attrs.shape` with the dtype of `updates`. When the
 * output is non-empty, slice geometry from `scatter_util.calculateShapes` is
 * passed to the C++ kernel. An empty output is returned without a kernel call.
 *
 * @param {{backend: Object, inputs: {indices: Object, updates: Object},
 *     attrs: {shape: number[]}}} args
 * @returns {Object} the output tensor info.
 */
function scatterNd(args) {
  const { backend, inputs: { indices, updates }, attrs: { shape } } = args;
  const out = backend.makeOutput(shape, updates.dtype);
  if (tfjsCore.util.sizeFromShape(shape) !== 0) {
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } =
      tfjsCore.scatter_util.calculateShapes(updates, indices, shape);
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);
    wasmScatterNd(
      idOf(indices), idOf(updates), CppDType[updates.dtype],
      sliceRank, numUpdates, sliceSize, stridesBytes, outputSize, idOf(out));
  }
  return out;
}
|
||
|
/** WASM registration for ScatterNd. */
const scatterNdConfig = { kernelName: tfjsCore.ScatterNd, backendName: 'wasm', setupFunc: setup$u, kernelFunc: scatterNd };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmSelect;
/**
 * Bind the C++ SelectV2 entry point (no return value).
 * @param {Object} backend WASM backend exposing `wasm.cwrap`.
 */
function setup$v(backend) {
  const argTypes = [
    'number', // conditionId
    'number', // tId
    'number', // eId
    'number', // offset
    'number', // outId
  ];
  wasmSelect = backend.wasm.cwrap(tfjsCore.SelectV2, null, argTypes);
}
|
||
|
/**
 * SelectV2 on the WASM backend.
 *
 * The output takes `t`'s shape and dtype. `offset` is
 * sizeFromShape(t.shape.slice(1)) only when the condition has rank 1 and `t`
 * does not. For every other rank combination it is 1.
 *
 * @param {{inputs: {condition: Object, t: Object, e: Object}, backend: Object}} args
 * @returns {Object} the output tensor info.
 */
function select(args) {
  const { inputs: { condition, t, e }, backend } = args;
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const conditionId = idOf(condition);
  const tId = idOf(t);
  const eId = idOf(e);
  const out = backend.makeOutput(t.shape, t.dtype);
  // Ranks are array lengths (non-negative), so "not 0 and not > 1" means exactly 1.
  const perRowCondition = condition.shape.length === 1 && t.shape.length !== 1;
  const offset = perRowCondition ? tfjsCore.util.sizeFromShape(t.shape.slice(1)) : 1;
  wasmSelect(conditionId, tId, eId, offset, idOf(out));
  return out;
}
|
||
|
const selectV2Config = {
|
||
|
kernelName: tfjsCore.SelectV2,
|
||
|
backendName: 'wasm',
|
||
|
kernelFunc: select,
|
||
|
setupFunc: setup$v
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$5;

/**
 * Binds the native Sigmoid export, which takes (input id, output id).
 * @param {object} backend - wasm backend exposing `wasm.cwrap`.
 */
function setup$w(backend) {
  wasmFunc$5 = backend.wasm.cwrap(tfjsCore.Sigmoid, null /* void */, ['number', 'number']);
}

/**
 * Sigmoid kernel: allocates an output with x's shape and dtype and runs the
 * native routine from x into it.
 * @param {{backend: object, inputs: {x: object}}} args
 * @returns {object} TensorInfo of the output.
 */
function sigmoid(args) {
  const { backend, inputs: { x } } = args;
  const xId = backend.dataIdMap.get(x.dataId).id;
  const out = backend.makeOutput(x.shape, x.dtype);
  const outId = backend.dataIdMap.get(out.dataId).id;
  // Short-circuit zero-sized tensors: nothing to compute.
  if (tfjsCore.util.sizeFromShape(out.shape) === 0) {
    return out;
  }
  wasmFunc$5(xId, outId);
  return out;
}

const sigmoidConfig = {
  // Use the shared kernel-name constant (the same one setup$w binds), as the
  // other configs in this file do, rather than a hand-typed string.
  kernelName: tfjsCore.Sigmoid,
  backendName: 'wasm',
  setupFunc: setup$w,
  kernelFunc: sigmoid
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for tfjsCore.Sin, built with the shared createUnaryKernelConfig factory.
const sinConfig = createUnaryKernelConfig(tfjsCore.Sin);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Slice kernel: copies the region of `x` that starts at `begin` and spans
 * `size` into a newly allocated output tensor.
 *
 * Contiguous regions are copied with a single typed-array `set`; ranks 2-4
 * use the specialised row copiers and any other rank falls back to
 * `genericSliceSlow`.
 *
 * @param {{inputs: {x: object}, attrs: {begin: (number|number[]), size: (number|number[])}, backend: object}} args
 * @returns {object} TensorInfo of the slice.
 */
function slice(args) {
  const { inputs: { x }, attrs: { begin, size }, backend } = args;
  const [begin_, size_] = tfjsCore.slice_util.parseSliceParams(x, begin, size);
  const isContinous = tfjsCore.slice_util.isSliceContinous(x.shape, begin_, size_);
  // Allocate the output BEFORE creating any heap views. makeOutput presumably
  // allocates on the wasm heap (TODO confirm); if that allocation grows
  // WebAssembly memory, views created earlier may be detached or stale, so
  // `xVals` must be taken afterwards.
  const out = backend.makeOutput(size_, x.dtype);
  const xVals = backend.typedArrayFromHeap(x);
  const outVals = backend.typedArrayFromHeap(out);
  const xStrides = tfjsCore.util.computeStrides(x.shape);
  if (isContinous) {
    const flatOffset = tfjsCore.slice_util.computeFlatOffset(begin_, xStrides);
    outVals.set(xVals.subarray(flatOffset, flatOffset + tfjsCore.util.sizeFromShape(size_)));
    return out;
  }
  const rank = x.shape.length;
  if (rank === 2) {
    slice2d(xVals, xStrides[0], outVals, begin_, size_);
  } else if (rank === 3) {
    slice3d(xVals, xStrides[0], xStrides[1], outVals, begin_, size_);
  } else if (rank === 4) {
    slice4d(xVals, xStrides[0], xStrides[1], xStrides[2], outVals, begin_, size_);
  } else {
    genericSliceSlow(xVals, x, outVals, begin_, size_);
  }
  return out;
}
|
||
|
/**
 * Copies a 2-D window out of a row-major buffer.
 * @param {TypedArray} xVals - source elements.
 * @param {number} xStride - elements per source row.
 * @param {TypedArray} outVals - destination, filled densely row by row.
 * @param {number[]} begin - [row, col] of the window's first element.
 * @param {number[]} size - [rows, cols] of the window.
 */
function slice2d(xVals, xStride, outVals, begin, size) {
  const [rowStart, colStart] = begin;
  const [rowCount, rowWidth] = size;
  for (let r = 0; r < rowCount; r++) {
    const src = (rowStart + r) * xStride + colStart;
    outVals.set(xVals.subarray(src, src + rowWidth), r * rowWidth);
  }
}
|
||
|
/**
 * Copies a 3-D window out of a row-major buffer, one innermost run at a time.
 * @param {TypedArray} xVals - source elements.
 * @param {number} xStride1 - source stride of axis 0 (elements).
 * @param {number} xStride2 - source stride of axis 1 (elements).
 * @param {TypedArray} outVals - destination, filled densely.
 * @param {number[]} begin - start coordinate of the window.
 * @param {number[]} size - extent of the window on each axis.
 */
function slice3d(xVals, xStride1, xStride2, outVals, begin, size) {
  const [i0, j0, k0] = begin;
  const [ni, nj, runLength] = size;
  const iEnd = i0 + ni;
  const jEnd = j0 + nj;
  let dst = 0;
  for (let i = i0; i < iEnd; i++) {
    const planeBase = i * xStride1;
    for (let j = j0; j < jEnd; j++) {
      const src = planeBase + j * xStride2 + k0;
      outVals.set(xVals.subarray(src, src + runLength), dst);
      dst += runLength;
    }
  }
}
|
||
|
/**
 * Copies a 4-D window out of a row-major buffer, one innermost run at a time.
 * @param {TypedArray} xVals - source elements.
 * @param {number} xStride1 - source stride of axis 0 (elements).
 * @param {number} xStride2 - source stride of axis 1 (elements).
 * @param {number} xStride3 - source stride of axis 2 (elements).
 * @param {TypedArray} outVals - destination, filled densely.
 * @param {number[]} begin - start coordinate of the window.
 * @param {number[]} size - extent of the window on each axis.
 */
function slice4d(xVals, xStride1, xStride2, xStride3, outVals, begin, size) {
  const [i0, j0, k0, l0] = begin;
  const [ni, nj, nk, runLength] = size;
  const iEnd = i0 + ni;
  const jEnd = j0 + nj;
  const kEnd = k0 + nk;
  let dst = 0;
  for (let i = i0; i < iEnd; i++) {
    for (let j = j0; j < jEnd; j++) {
      const rowBase = i * xStride1 + j * xStride2;
      for (let k = k0; k < kEnd; k++) {
        const src = rowBase + k * xStride3 + l0;
        outVals.set(xVals.subarray(src, src + runLength), dst);
        dst += runLength;
      }
    }
  }
}
|
||
|
/**
 * Fallback slice for ranks without a specialised copier: maps each output
 * element back to its source coordinate (output location + begin) and reads
 * it through a tfjs buffer view.
 * @param {TypedArray} xVals - source elements.
 * @param {{shape: number[], dtype: string}} xInfo - source tensor info.
 * @param {TypedArray} outVals - destination elements.
 * @param {number[]} begin - start coordinate of the slice.
 * @param {number[]} size - shape of the slice.
 */
function genericSliceSlow(xVals, xInfo, outVals, begin, size) {
  const { shape, dtype } = xInfo;
  const outBuf = tfjsCore.buffer(size, dtype, outVals);
  const xBuf = tfjsCore.buffer(shape, dtype, xVals);
  const total = outBuf.size;
  for (let flat = 0; flat < total; flat++) {
    const srcLoc = outBuf.indexToLoc(flat).map((coord, axis) => coord + begin[axis]);
    outVals[flat] = xBuf.get(...srcLoc);
  }
}
|
||
|
/** Registry entry for the wasm Slice kernel (no setup hook needed). */
const sliceConfig = { kernelName: tfjsCore.Slice, backendName: 'wasm', kernelFunc: slice };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmFunc$6;

/**
 * Binds the native Softmax export, called as
 * (input id, output id, channels, batch) by `softmax`.
 * @param {object} backend - wasm backend exposing `wasm.cwrap`.
 */
function setup$x(backend) {
  const argTypes = ['number', 'number', 'number', 'number'];
  wasmFunc$6 = backend.wasm.cwrap(tfjsCore.Softmax, null /* void */, argTypes);
}

/**
 * Softmax kernel. `channels` is the extent of `logits` along `dim` and
 * `batch` is the total element count divided by it.
 * @param {{backend: object, inputs: {logits: object}, attrs: {dim: number}}} args
 * @returns {object} TensorInfo shaped and typed like `logits`.
 */
function softmax(args) {
  const { backend, inputs: { logits }, attrs: { dim } } = args;
  const xId = backend.dataIdMap.get(logits.dataId).id;
  const out = backend.makeOutput(logits.shape, logits.dtype);
  const outId = backend.dataIdMap.get(out.dataId).id;
  // Empty tensors need no native call.
  if (tfjsCore.util.sizeFromShape(out.shape) === 0) {
    return out;
  }
  const channels = logits.shape[dim];
  const batch = tfjsCore.util.sizeFromShape(logits.shape) / channels;
  wasmFunc$6(xId, outId, channels, batch);
  return out;
}

/** Registry entry for the wasm Softmax kernel. */
const softmaxConfig = {
  kernelName: tfjsCore.Softmax,
  backendName: 'wasm',
  kernelFunc: softmax,
  setupFunc: setup$x,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * SplitV kernel: cuts `x` along `axis` into consecutive pieces whose sizes
 * come from `prepareSplitSize(numOrSizeSplits)`, returning one slice per piece.
 *
 * @param {{inputs: {x: object}, attrs: {numOrSizeSplits: (number|number[]), axis: number}, backend: object}} args
 * @returns {object[]} TensorInfos of the pieces, in order along the axis.
 */
function split(args) {
  const { inputs, attrs, backend } = args;
  const { x } = inputs;
  const { numOrSizeSplits, axis } = attrs;
  // Normalised (non-negative) axis. It must be used for BOTH the size
  // computation and the slicing; passing the raw `axis` to prepareSplitSize
  // would size the pieces from the wrong dimension when `axis` is negative.
  const $axis = tfjsCore.util.parseAxisParam(axis, x.shape)[0];
  const splitSizes = tfjsCore.backend_util.prepareSplitSize(x, numOrSizeSplits, $axis);
  const begin = new Array(x.shape.length).fill(0);
  const size = x.shape.slice();
  return splitSizes.map(s => {
    const xSliceSize = [...size];
    xSliceSize[$axis] = s;
    const xSlice = slice({ inputs: { x }, attrs: { begin, size: xSliceSize }, backend });
    // Advance the window start for the next piece.
    begin[$axis] += s;
    return xSlice;
  });
}

const splitVConfig = {
  kernelName: tfjsCore.SplitV,
  backendName: 'wasm',
  kernelFunc: split
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for tfjsCore.Sqrt, built with the shared createUnaryKernelConfig factory.
const sqrtConfig = createUnaryKernelConfig(tfjsCore.Sqrt);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for tfjsCore.Square, built with the shared createUnaryKernelConfig factory.
const squareConfig = createUnaryKernelConfig(tfjsCore.Square);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Flag handed to createBinaryKernelConfig for SquaredDifference; presumably it
// allows arbitrary broadcastable input shapes — TODO confirm against the factory.
const supportsFullBroadcast$e = true;
|
||
|
// Kernel config for tfjsCore.SquaredDifference, built with the shared createBinaryKernelConfig factory.
const squaredDifferenceConfig = createBinaryKernelConfig(tfjsCore.SquaredDifference, supportsFullBroadcast$e);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmStridedSlice;

/**
 * Binds the native StridedSlice export. Argument order matches the call in
 * `stridedSlice`: x id, x strides, x rank, begin, end, strides, output shape,
 * output strides, output rank, output id ('array' entries are int32 bytes).
 * @param {object} backend - wasm backend exposing `wasm.cwrap`.
 */
function setup$y(backend) {
  wasmStridedSlice = backend.wasm.cwrap(tfjsCore.StridedSlice, null /*void*/, [
    'number',
    'array',
    'number',
    'array',
    'array',
    'array',
    'array',
    'array',
    'number',
    'number',
  ]);
}

/**
 * StridedSlice kernel.
 *
 * Expands `x` per `newAxisMask`, normalises begin/end/strides against the
 * expanded shape, then either takes the plain `slice` fast path (all strides
 * 1) or calls the native strided copy. Axes in `shrinkAxisMask` are dropped
 * from the output shape.
 *
 * @param {{backend: object, inputs: {x: object}, attrs: object}} args
 * @returns {object} TensorInfo of the result.
 * @throws {Error} on multiple ellipses, or ellipsisMask combined with
 *   newAxisMask / shrinkAxisMask.
 */
function stridedSlice(args) {
  const { backend, inputs, attrs } = args;
  const { x } = inputs;
  // Copy begin/end: the newAxisMask handling below writes into them, and the
  // originals belong to the caller's attrs object.
  let begin = attrs.begin.slice();
  let end = attrs.end.slice();
  let { strides } = attrs;
  if (strides == null) {
    strides = new Array(begin.length);
  }
  const { beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
  const ellipsisAxes = tfjsCore.backend_util.slice_util.maskToAxes(ellipsisMask);
  if (ellipsisAxes.length > 1) {
    throw new Error('Multiple ellipses in slice is not allowed.');
  }
  if (ellipsisMask !== 0 && newAxisMask !== 0) {
    throw new Error('Using both ellipsisMask and newAxisMask is not yet supported.');
  }
  if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
    throw new Error('Using both ellipsisMask and shrinkAxisMask is not yet supported.');
  }
  const numInterpolatedAxes = x.shape.length - begin.length;
  // Expand the dims of x based on the newAxisMask.
  const expandAxes = tfjsCore.backend_util.slice_util.maskToAxes(newAxisMask);
  const newShape = x.shape.slice();
  expandAxes.forEach(axis => {
    begin[axis] = 0;
    end[axis] = 1;
    newShape.splice(axis, 0, 1);
  });
  const xReshaped = reshape({ inputs: { x }, attrs: { shape: newShape }, backend });
  const { begin: normalizedBegin, end: normalizedEnd, strides: normalizedStrides } = tfjsCore.backend_util.slice_util.getNormalizedAxes(xReshaped.shape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask);
  begin = normalizedBegin;
  end = normalizedEnd;
  strides = normalizedStrides;
  const shrinkAxes = tfjsCore.backend_util.slice_util.maskToAxes(shrinkAxisMask);
  // Adjust the ends based on the shrink mask.
  shrinkAxes.forEach(axis => {
    end[axis] = begin[axis] + 1;
    strides[axis] = 1;
  });
  // Figure out the output shape.
  const size = tfjsCore.backend_util.slice_util.computeOutShape(begin, end, strides);
  // Remove the axes based on shrinkMask.
  const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);
  const nonStrided = strides.every(v => v === 1);
  if (nonStrided) {
    // begin/size were normalised against xReshaped.shape, so slice that view
    // rather than the original x; otherwise the ranks disagree whenever
    // newAxisMask inserted axes.
    const xSliced = slice({ inputs: { x: xReshaped }, attrs: { begin, size }, backend });
    return reshape({ inputs: { x: xSliced }, attrs: { shape: outShape }, backend });
  }
  // NOTE(review): the strided path always allocates a float32 output
  // regardless of x.dtype — confirm the native kernel is float-only.
  const out = backend.makeOutput(outShape, 'float32');
  if (!outShape.some(axis => axis === 0)) {
    const xId = backend.dataIdMap.get(xReshaped.dataId).id;
    const xStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(xReshaped.shape)).buffer);
    const beginBytes = new Uint8Array(new Int32Array(begin).buffer);
    const endBytes = new Uint8Array(new Int32Array(end).buffer);
    const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);
    const outputShapeBytes = new Uint8Array(new Int32Array(outShape).buffer);
    const outStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(outShape)).buffer);
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmStridedSlice(xId, xStridesBytes, xReshaped.shape.length, beginBytes, endBytes, stridesBytes, outputShapeBytes, outStridesBytes, outShape.length, outId);
  }
  return reshape({ inputs: { x: out }, attrs: { shape: outShape }, backend });
}

const stridedSliceConfig = {
  kernelName: tfjsCore.StridedSlice,
  backendName: 'wasm',
  setupFunc: setup$y,
  kernelFunc: stridedSlice
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Flag handed to createBinaryKernelConfig for Sub; presumably it allows
// arbitrary broadcastable input shapes — TODO confirm against the factory.
const supportsFullBroadcast$f = true;
|
||
|
// Kernel config for tfjsCore.Sub, built with the shared createBinaryKernelConfig factory.
const subConfig = createBinaryKernelConfig(tfjsCore.Sub, supportsFullBroadcast$f);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmSum;

/**
 * Binds the native Sum export. `sum` calls it with three numeric arguments:
 * (input id, reduce size, output id).
 * @param {object} backend - wasm backend exposing `wasm.cwrap`.
 */
function setup$z(backend) {
  // Emscripten's cwrap expects one type string per argument, so this is an
  // array of three 'number' entries, not one comma-separated string.
  wasmSum = backend.wasm.cwrap(tfjsCore.Sum, null /*void*/, ['number', 'number', 'number']);
}
|
||
|
/**
 * Sum kernel. Reduces `x` over `axis`. When permuteAxesAndTranspose produced
 * a distinct transposed copy, the reduction runs on that copy over its
 * innermost axes, and the copy is disposed afterwards. With `keepDims`, the
 * reduced axes are restored as size-1 dimensions on the output shape.
 *
 * @param {{backend: object, inputs: {x: object}, attrs: {axis: (number|number[]), keepDims: boolean}}} args
 * @returns {object} TensorInfo of the reduced result.
 */
function sum(args) {
  const { backend, inputs: { x }, attrs: { axis, keepDims } } = args;
  const xId = backend.dataIdMap.get(x.dataId).id;
  const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);

  // Decide whether to reduce the transposed copy or x itself.
  const transposedId = inputWasTransposed ? backend.dataIdMap.get(transposed.dataId).id : xId;
  const useTransposed = transposedId !== xId;
  const input = useTransposed ? transposed : x;
  const inputId = useTransposed ? transposedId : xId;
  const reductionAxes = useTransposed
    ? tfjsCore.backend_util.getInnerMostAxes(axes.length, transposed.shape.length)
    : axes;

  tfjsCore.backend_util.assertAxesAreInnerMostDims('sum', reductionAxes, input.shape.length);
  const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, reductionAxes);
  const reduceSize = tfjsCore.util.sizeFromShape(reduceShape);
  const out = backend.makeOutput(outShape, input.dtype);
  // Empty inputs skip the native call.
  if (tfjsCore.util.sizeFromShape(input.shape) !== 0) {
    wasmSum(inputId, reduceSize, backend.dataIdMap.get(out.dataId).id);
  }

  // Release the temporary transposed tensor.
  if (inputWasTransposed) {
    backend.disposeData(transposed.dataId);
  }
  if (keepDims) {
    out.shape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, originalAxes);
  }
  return out;
}

/** Registry entry for the wasm Sum kernel. */
const sumConfig = {
  kernelName: tfjsCore.Sum,
  backendName: 'wasm',
  kernelFunc: sum,
  setupFunc: setup$z,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// Kernel config for tfjsCore.Tanh, built with the shared createUnaryKernelConfig factory.
const tanhConfig = createUnaryKernelConfig(tfjsCore.Tanh);
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
let wasmTile;

/**
 * Binds the native Tile export. Argument order matches the call in `tile`:
 * x id, x shape (int32 bytes), x rank, output shape (int32 bytes),
 * output rank, dtype code, output id.
 * @param {object} backend - wasm backend exposing `wasm.cwrap`.
 */
function setup$A(backend) {
  const argTypes = ['number', 'array', 'number', 'array', 'number', 'number', 'number'];
  wasmTile = backend.wasm.cwrap(tfjsCore.Tile, null /* void */, argTypes);
}

/**
 * Tile kernel: the output extent on each axis is x's extent times `reps`.
 * @param {{inputs: {x: object}, backend: object, attrs: {reps: number[]}}} args
 * @returns {object} TensorInfo of the tiled output.
 */
function tile(args) {
  const { inputs: { x }, backend, attrs: { reps } } = args;
  const xId = backend.dataIdMap.get(x.dataId).id;
  const newShape = x.shape.map((dim, axis) => dim * reps[axis]);
  const toInt32Bytes = (values) => new Uint8Array(new Int32Array(values).buffer);
  const xShapeBytes = toInt32Bytes(x.shape);
  const newShapeBytes = toInt32Bytes(newShape);
  const out = backend.makeOutput(newShape, x.dtype);
  const outId = backend.dataIdMap.get(out.dataId).id;
  wasmTile(xId, xShapeBytes, x.shape.length, newShapeBytes, newShape.length, CppDType[out.dtype], outId);
  return out;
}

/** Registry entry for the wasm Tile kernel. */
const tileConfig = {
  kernelName: tfjsCore.Tile,
  backendName: 'wasm',
  kernelFunc: tile,
  setupFunc: setup$A,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * Unpack (unstack) kernel: splits `value` along `axis` into
 * `value.shape[axis]` tensors, each with that axis removed.
 *
 * Every output shares `value`'s dtype; the slices are produced with `slice`
 * and relabelled with the reduced shape.
 *
 * @param {{inputs: {value: object}, backend: object, attrs: {axis: number}}} args
 * @returns {object[]} One TensorInfo per index along `axis`.
 */
function unpack(args) {
  const { inputs, backend, attrs } = args;
  const { value } = inputs;
  // Normalise the axis (negative values count from the end) the same way
  // `split` does; indexing value.shape with a raw negative axis yields
  // undefined and corrupts every later computation.
  const axis = tfjsCore.util.parseAxisParam(attrs.axis, value.shape)[0];
  const numOutputs = value.shape[axis];
  const rank = value.shape.length;
  const outShape = new Array(rank - 1);
  let outIndex = 0;
  for (let i = 0; i < rank; i++) {
    if (i !== axis) {
      outShape[outIndex++] = value.shape[i];
    }
  }
  const outs = new Array(numOutputs);
  const begin = new Array(rank).fill(0);
  const size = value.shape.slice();
  size[axis] = 1;
  for (let i = 0; i < outs.length; i++) {
    begin[axis] = i;
    outs[i] = slice({ inputs: { x: value }, attrs: { begin, size }, backend });
  }
  return outs.map(({ dataId, dtype }) => ({ dataId, dtype, shape: outShape }));
}

const unpackConfig = {
  kernelName: tfjsCore.Unpack,
  backendName: 'wasm',
  kernelFunc: unpack,
};
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/**
 * ZerosLike kernel: allocates an output with x's shape and dtype and writes
 * 0 into every element through a heap view.
 * @param {{inputs: {x: object}, backend: object}} args
 * @returns {object} TensorInfo of the zero-filled output.
 */
function zerosLike(args) {
  const { inputs, backend } = args;
  const { x } = inputs;
  const out = backend.makeOutput(x.shape, x.dtype);
  backend.typedArrayFromHeap(out).fill(0);
  return out;
}

/** Registry entry for the wasm ZerosLike kernel. */
const zerosLikeConfig = { kernelName: tfjsCore.ZerosLike, backendName: 'wasm', kernelFunc: zerosLike };
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
// List all kernel configs here
|
||
|
// Every wasm kernel config in this bundle; they are registered with tfjs-core
// in exactly this order.
const kernelConfigs = [
  absConfig, addConfig, addNConfig, argMaxConfig, avgPoolConfig,
  batchMatMulConfig, castConfig, clipByValueConfig, concatConfig,
  conv2DConfig, conv2DBackpropInputConfig, cosConfig, cropAndResizeConfig,
  cumsumConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, divConfig,
  equalConfig, expConfig, fillConfig, flipLeftRightConfig, floorDivConfig,
  fusedMatMulConfig, fusedBatchNormConfig, fusedConv2DConfig,
  fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig,
  greaterEqualConfig, identityConfig, lessConfig, lessEqualConfig, logConfig,
  logicalAndConfig, maxConfig, maximumConfig, maxPoolConfig, minConfig,
  minimumConfig, multiplyConfig, negateConfig, nonMaxSuppressionV3Config,
  nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, notEqualConfig,
  oneHotConfig, onesLikeConfig, padV2Config, powConfig, preluConfig,
  reluConfig, relu6Config, reshapeConfig, resizeBilinearConfig, reverseConfig,
  rotateWithOffsetConfig, rsqrtConfig, scatterNdConfig, selectV2Config,
  sigmoidConfig, sinConfig, sliceConfig, softmaxConfig, splitVConfig,
  sqrtConfig, squareConfig, squaredDifferenceConfig, stridedSliceConfig,
  subConfig, sumConfig, tanhConfig, tileConfig, transposeConfig, unpackConfig,
  zerosLikeConfig,
];

kernelConfigs.forEach((kernelConfig) => {
  tfjsCore.registerKernel(kernelConfig);
});
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
const ENV = tfjsCore.env();

/**
 * WASM_HAS_SIMD_SUPPORT: true when WebAssembly.validate accepts a tiny
 * module that uses SIMD instructions.
 * From: https://github.com/GoogleChromeLabs/wasm-feature-detect
 */
ENV.registerFlag('WASM_HAS_SIMD_SUPPORT', async () => {
  // WebAssembly binary of a minimal program containing SIMD instructions.
  const simdProgram = new Uint8Array([
    0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3,
    2, 1, 0, 10, 9, 1, 7, 0, 65, 0, 253, 15, 26, 11,
  ]);
  return WebAssembly.validate(simdProgram);
});

/**
 * WASM_HAS_MULTITHREAD_SUPPORT: false under Node; otherwise true only when a
 * SharedArrayBuffer can be posted over a MessageChannel and a module using
 * threaded instructions validates. Any exception means "unsupported".
 * From: https://github.com/GoogleChromeLabs/wasm-feature-detect
 */
ENV.registerFlag('WASM_HAS_MULTITHREAD_SUPPORT', async () => {
  // TODO(annxingyuan): Enable node support once this is resolved:
  // https://github.com/tensorflow/tfjs/issues/3830
  if (ENV.get('IS_NODE')) {
    return false;
  }
  try {
    // SharedArrayBuffers must be transferable (needed for Firefox):
    // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
    const probeChannel = new MessageChannel();
    probeChannel.port1.postMessage(new SharedArrayBuffer(1));
    // WebAssembly binary of a minimal program containing threaded instructions.
    const threadedProgram = new Uint8Array([
      0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 5,
      4, 1, 3, 1, 1, 10, 11, 1, 9, 0, 65, 0, 254, 16, 2, 0, 26, 11,
    ]);
    return WebAssembly.validate(threadedProgram);
  } catch (err) {
    return false;
  }
});
|
||
|
|
||
|
/**
 * Runs a CommonJS-style module body and returns what it exported.
 *
 * A fresh `{ exports: {} }` module record is always created; the incoming
 * `module` argument is ignored (it is only kept for call compatibility).
 *
 * @param fn Module body, called as `fn(module, module.exports)`.
 * @param module Ignored.
 * @returns The final value of `module.exports` after `fn` has run.
 */
function createCommonjsModule(fn, module) {
    const record = { exports: {} };
    fn(record, record.exports);
    return record.exports;
}
|
||
|
|
||
|
var tfjsBackendWasmThreadedSimd = createCommonjsModule(function (module, exports) {
|
||
|
var WasmBackendModuleThreadedSimd = (function() {
|
||
|
var _scriptDir = typeof document !== 'undefined' && document.currentScript ? document.currentScript.src : undefined;
|
||
|
if (typeof __filename !== 'undefined') _scriptDir = _scriptDir || __filename;
|
||
|
return (
|
||
|
function(WasmBackendModuleThreadedSimd) {
|
||
|
WasmBackendModuleThreadedSimd = WasmBackendModuleThreadedSimd || {};
|
||
|
|
||
|
function GROWABLE_HEAP_I8(){if(wasmMemory.buffer!=buffer){updateGlobalBufferAndViews(wasmMemory.buffer);}return HEAP8}function GROWABLE_HEAP_U8(){if(wasmMemory.buffer!=buffer){updateGlobalBufferAndViews(wasmMemory.buffer);}return HEAPU8}function GROWABLE_HEAP_I32(){if(wasmMemory.buffer!=buffer){updateGlobalBufferAndViews(wasmMemory.buffer);}return HEAP32}function GROWABLE_HEAP_U32(){if(wasmMemory.buffer!=buffer){updateGlobalBufferAndViews(wasmMemory.buffer);}return HEAPU32}function GROWABLE_HEAP_F64(){if(wasmMemory.buffer!=buffer){updateGlobalBufferAndViews(wasmMemory.buffer);}return HEAPF64}var Module=typeof WasmBackendModuleThreadedSimd!=="undefined"?WasmBackendModuleThreadedSimd:{};var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}var arguments_=[];var thisProgram="./this.program";var quit_=function(status,toThrow){throw toThrow};var ENVIRONMENT_IS_WEB=false;var ENVIRONMENT_IS_WORKER=false;var ENVIRONMENT_IS_NODE=false;var ENVIRONMENT_IS_SHELL=false;ENVIRONMENT_IS_WEB=typeof window==="object";ENVIRONMENT_IS_WORKER=typeof importScripts==="function";ENVIRONMENT_IS_NODE=typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string";ENVIRONMENT_IS_SHELL=!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_NODE&&!ENVIRONMENT_IS_WORKER;var ENVIRONMENT_IS_PTHREAD=Module["ENVIRONMENT_IS_PTHREAD"]||false;if(ENVIRONMENT_IS_PTHREAD){buffer=Module["buffer"];DYNAMIC_BASE=Module["DYNAMIC_BASE"];DYNAMICTOP_PTR=Module["DYNAMICTOP_PTR"];}var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var read_,readBinary;var nodeFS;var nodePath;if(ENVIRONMENT_IS_NODE){if(ENVIRONMENT_IS_WORKER){scriptDirectory=path.dirname(scriptDirectory)+"/";}else{scriptDirectory=__dirname+"/";}read_=function 
shell_read(filename,binary){if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);return nodeFS["readFileSync"](filename,binary?null:"utf8")};readBinary=function readBinary(filename){var ret=read_(filename,true);if(!ret.buffer){ret=new Uint8Array(ret);}assert(ret.buffer);return ret};if(process["argv"].length>1){thisProgram=process["argv"][1].replace(/\\/g,"/");}arguments_=process["argv"].slice(2);process["on"]("uncaughtException",function(ex){if(!(ex instanceof ExitStatus)){throw ex}});process["on"]("unhandledRejection",abort);quit_=function(status){process["exit"](status);};Module["inspect"]=function(){return "[Emscripten Module object]"};var nodeWorkerThreads;try{nodeWorkerThreads=worker_threads;}catch(e){console.error('The "worker_threads" module is not supported in this node.js build - perhaps a newer version is needed?');throw e}Worker=nodeWorkerThreads.Worker;}else if(ENVIRONMENT_IS_SHELL){if(typeof read!="undefined"){read_=function shell_read(f){return read(f)};}readBinary=function readBinary(f){var data;if(typeof readbuffer==="function"){return new Uint8Array(readbuffer(f))}data=read(f,"binary");assert(typeof data==="object");return data};if(typeof scriptArgs!="undefined"){arguments_=scriptArgs;}else if(typeof arguments!="undefined"){arguments_=arguments;}if(typeof quit==="function"){quit_=function(status){quit(status);};}if(typeof print!=="undefined"){if(typeof console==="undefined")console={};console.log=print;console.warn=console.error=typeof printErr!=="undefined"?printErr:print;}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href;}else if(document.currentScript){scriptDirectory=document.currentScript.src;}if(_scriptDir){scriptDirectory=_scriptDir;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}if(ENVIRONMENT_IS_NODE){read_=function 
shell_read(filename,binary){if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);return nodeFS["readFileSync"](filename,binary?null:"utf8")};readBinary=function readBinary(filename){var ret=read_(filena
|
||
|
|
||
|
|
||
|
return WasmBackendModuleThreadedSimd
|
||
|
}
|
||
|
);
|
||
|
})();
|
||
|
module.exports = WasmBackendModuleThreadedSimd;
|
||
|
});
|
||
|
|
||
|
const wasmWorkerContents = 'var threadInfoStruct=0;var selfThreadId=0;var parentThreadId=0;var Module={};function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:selfThreadId})}var err=threadPrintErr;this.alert=threadAlert;Module["instantiateWasm"]=function(info,receiveInstance){var instance=new WebAssembly.Instance(Module["wasmModule"],info);Module["wasmModule"]=null;receiveInstance(instance);return instance.exports};this.onmessage=function(e){try{if(e.data.cmd==="load"){Module["DYNAMIC_BASE"]=e.data.DYNAMIC_BASE;Module["DYNAMICTOP_PTR"]=e.data.DYNAMICTOP_PTR;Module["wasmModule"]=e.data.wasmModule;Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob==="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}Module=WasmBackendModuleThreadedSimd(Module);postMessage({"cmd":"loaded"})}else if(e.data.cmd==="objectTransfer"){Module["PThread"].receiveObjectTransfer(e.data)}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;threadInfoStruct=e.data.threadInfoStruct;Module["__register_pthread_ptr"](threadInfoStruct,0,0);selfThreadId=e.data.selfThreadId;parentThreadId=e.data.parentThreadId;var max=e.data.stackBase;var top=e.data.stackBase+e.data.stackSize;Module["establishStackSpace"](top,max);Module["_emscripten_tls_init"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].setThreadStatus(Module["_pthread_self"](),1);try{var result=Module["dynCall_ii"](e.data.start_routine,e.data.arg);if(!Module["getNoExitRuntime"]())Module["PThread"].threadExit(result)}catch(ex){if(ex==="Canceled!"){Module["PThread"].threadCancel()}else 
if(ex!="unwind"){Atomics.store(Module["HEAPU32"],threadInfoStruct+4>>2,ex instanceof Module["ExitStatus"]?ex.status:-2);Atomics.store(Module["HEAPU32"],threadInfoStruct+0>>2,1);Module["_emscripten_futex_wake"](threadInfoStruct+0,2147483647);if(!(ex instanceof Module["ExitStatus"]))throw ex}}}else if(e.data.cmd==="cancel"){if(threadInfoStruct){Module["PThread"].threadCancel()}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processThreadQueue"){if(threadInfoStruct){Module["_emscripten_current_thread_process_queued_calls"]()}}else{err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){err("worker.js onmessage() captured an uncaught exception: "+ex);if(ex.stack)err(ex.stack);throw ex}};if(typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string"){self={location:{href:__filename}};var onmessage=this.onmessage;var nodeWorkerThreads=require("worker_threads");Worker=nodeWorkerThreads.Worker;var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",function(data){onmessage({data:data})});var nodeFS=require("fs");var nodeRead=function(filename){return nodeFS.readFileSync(filename,"utf8")};function globalEval(x){global.require=require;global.Module=Module;eval.call(null,x)}importScripts=function(f){globalEval(nodeRead(f))};postMessage=function(msg){parentPort.postMessage(msg)};if(typeof performance==="undefined"){performance={now:function(){return Date.now()}}}}';
|
||
|
|
||
|
var tfjsBackendWasm = createCommonjsModule(function (module, exports) {
|
||
|
var WasmBackendModule = (function() {
|
||
|
var _scriptDir = typeof document !== 'undefined' && document.currentScript ? document.currentScript.src : undefined;
|
||
|
if (typeof __filename !== 'undefined') _scriptDir = _scriptDir || __filename;
|
||
|
return (
|
||
|
function(WasmBackendModule) {
|
||
|
WasmBackendModule = WasmBackendModule || {};
|
||
|
|
||
|
var Module=typeof WasmBackendModule!=="undefined"?WasmBackendModule:{};var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}var arguments_=[];var thisProgram="./this.program";var quit_=function(status,toThrow){throw toThrow};var ENVIRONMENT_IS_WEB=false;var ENVIRONMENT_IS_WORKER=false;var ENVIRONMENT_IS_NODE=false;var ENVIRONMENT_IS_SHELL=false;ENVIRONMENT_IS_WEB=typeof window==="object";ENVIRONMENT_IS_WORKER=typeof importScripts==="function";ENVIRONMENT_IS_NODE=typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string";ENVIRONMENT_IS_SHELL=!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_NODE&&!ENVIRONMENT_IS_WORKER;var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var read_,readBinary;var nodeFS;var nodePath;if(ENVIRONMENT_IS_NODE){if(ENVIRONMENT_IS_WORKER){scriptDirectory=path.dirname(scriptDirectory)+"/";}else{scriptDirectory=__dirname+"/";}read_=function shell_read(filename,binary){if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);return nodeFS["readFileSync"](filename,binary?null:"utf8")};readBinary=function readBinary(filename){var ret=read_(filename,true);if(!ret.buffer){ret=new Uint8Array(ret);}assert(ret.buffer);return ret};if(process["argv"].length>1){thisProgram=process["argv"][1].replace(/\\/g,"/");}arguments_=process["argv"].slice(2);process["on"]("uncaughtException",function(ex){if(!(ex instanceof ExitStatus)){throw ex}});process["on"]("unhandledRejection",abort);quit_=function(status){process["exit"](status);};Module["inspect"]=function(){return "[Emscripten Module object]"};}else if(ENVIRONMENT_IS_SHELL){if(typeof read!="undefined"){read_=function shell_read(f){return read(f)};}readBinary=function readBinary(f){var data;if(typeof readbuffer==="function"){return new 
Uint8Array(readbuffer(f))}data=read(f,"binary");assert(typeof data==="object");return data};if(typeof scriptArgs!="undefined"){arguments_=scriptArgs;}else if(typeof arguments!="undefined"){arguments_=arguments;}if(typeof quit==="function"){quit_=function(status){quit(status);};}if(typeof print!=="undefined"){if(typeof console==="undefined")console={};console.log=print;console.warn=console.error=typeof printErr!=="undefined"?printErr:print;}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href;}else if(document.currentScript){scriptDirectory=document.currentScript.src;}if(_scriptDir){scriptDirectory=_scriptDir;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}{read_=function shell_read(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.send(null);return xhr.responseText};if(ENVIRONMENT_IS_WORKER){readBinary=function readBinary(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.responseType="arraybuffer";xhr.send(null);return new Uint8Array(xhr.response)};}}}var out=Module["print"]||console.log.bind(console);var err=Module["printErr"]||console.warn.bind(console);for(key in moduleOverrides){if(moduleOverrides.hasOwnProperty(key)){Module[key]=moduleOverrides[key];}}moduleOverrides=null;if(Module["arguments"])arguments_=Module["arguments"];if(Module["thisProgram"])thisProgram=Module["thisProgram"];if(Module["quit"])quit_=Module["quit"];var wasmBinary;if(Module["wasmBinary"])wasmBinary=Module["wasmBinary"];var noExitRuntime;if(Module["noExitRuntime"])noExitRuntime=Module["noExitRuntime"];if(typeof WebAssembly!=="object"){err("no native wasm support detected");}var wasmMemory;var wasmTable=new WebAssembly.Table({"initial":147,"maximum":147+0,"element":"anyfunc"});var ABORT=false;function assert(condition,text){if(!condition){abort("Assertion failed: "+text);}}function getCFunc(ident){var 
func=Module["_"+ident];assert(func,"Cannot call unknown function "+ident+", make sure it is exported");return func}function ccall(ident,returnType,argTypes,ar
|
||
|
|
||
|
|
||
|
return WasmBackendModule
|
||
|
}
|
||
|
);
|
||
|
})();
|
||
|
module.exports = WasmBackendModule;
|
||
|
});
|
||
|
|
||
|
/**
|
||
|
* @license
|
||
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
* =============================================================================
|
||
|
*/
|
||
|
/** Priority passed to `tfjsCore.registerBackend` when registering 'wasm' below. */
const WASM_PRIORITY = 2;
/**
 * Kernel backend that keeps non-string tensor data on the heap of the
 * Emscripten-built WebAssembly module and talks to it via `wasm.tfjs`.
 */
class BackendWasm extends tfjsCore.KernelBackend {
    /**
     * @param wasm Module object produced by `init()`. Besides the `tfjs`
     *     bridge (init / registerTensor / disposeData / dispose), this class
     *     uses its `_malloc`, `_free` and `HEAPU8` members.
     */
    constructor(wasm) {
        super();
        this.wasm = wasm;
        // 0 is reserved for null data ids.
        this.dataIdNextNumber = 1;
        this.wasm.tfjs.init();
        this.dataIdMap = new tfjsCore.DataStorage(this, tfjsCore.engine());
    }
    /** Creates a new data id, stores `values` under it and returns the id. */
    write(values, shape, dtype) {
        const dataId = {};
        this.move(dataId, values, shape, dtype);
        return dataId;
    }
    /** Number of data ids currently tracked by this backend. */
    numDataIds() {
        return this.dataIdMap.numDataIds();
    }
    /**
     * Measures wall-clock time around a synchronous call of `f` and reports
     * it as kernel time.
     */
    async time(f) {
        const start = tfjsCore.util.now();
        f();
        const kernelMs = tfjsCore.util.now() - start;
        return { kernelMs };
    }
    /**
     * Registers `dataId` with this backend. String tensors keep their bytes in
     * JS (no heap allocation). Other dtypes get a heap buffer of
     * `size * bytesPerElement` bytes, registered with the wasm side and
     * filled from `values` when provided.
     */
    move(dataId, values, shape, dtype) {
        const id = this.dataIdNextNumber++;
        if (dtype === 'string') {
            const stringBytes = values;
            this.dataIdMap.set(dataId, { id, stringBytes, shape, dtype, memoryOffset: null });
            return;
        }
        const size = tfjsCore.util.sizeFromShape(shape);
        const numBytes = size * tfjsCore.util.bytesPerElement(dtype);
        const memoryOffset = this.wasm._malloc(numBytes);
        this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype });
        this.wasm.tfjs.registerTensor(id, size, memoryOffset);
        if (values != null) {
            // Copy the raw bytes of `values`, honouring its view offset.
            this.wasm.HEAPU8.set(new Uint8Array(values.buffer, values.byteOffset, numBytes), memoryOffset);
        }
    }
    async read(dataId) {
        return this.readSync(dataId);
    }
    /**
     * Returns the tensor's values. String tensors return their stored bytes;
     * other dtypes return a typed array over a copy of the heap bytes.
     */
    readSync(dataId) {
        const { memoryOffset, dtype, shape, stringBytes } = this.dataIdMap.get(dataId);
        if (dtype === 'string') {
            return stringBytes;
        }
        // slice() copies, so the result does not alias the wasm heap.
        const bytes = this.wasm.HEAPU8.slice(memoryOffset, memoryOffset + tfjsCore.util.sizeFromShape(shape) * tfjsCore.util.bytesPerElement(dtype));
        return typedArrayFromBuffer(bytes.buffer, dtype);
    }
    /**
     * Frees the heap buffer and wasm-side registration for `dataId`.
     * Ids this backend does not hold (e.g. already disposed) are ignored.
     */
    disposeData(dataId) {
        // Check membership first. get() on an unknown id would go through
        // DataStorage's lookup fallback (NOTE(review): tfjs-core's
        // DataStorage.get tries to move the data in from another backend),
        // or throw on `data.memoryOffset` for a repeated dispose.
        if (!this.dataIdMap.has(dataId)) {
            return;
        }
        const data = this.dataIdMap.get(dataId);
        this.wasm._free(data.memoryOffset);
        this.wasm.tfjs.disposeData(data.id);
        this.dataIdMap.delete(dataId);
    }
    floatPrecision() {
        return 32;
    }
    // Returns the memory offset of a tensor. Useful for debugging and unit
    // testing.
    getMemoryOffset(dataId) {
        return this.dataIdMap.get(dataId).memoryOffset;
    }
    /** Disposes the wasm-side state and drops the reference to the module. */
    dispose() {
        this.wasm.tfjs.dispose();
        this.wasm = null;
    }
    memory() {
        return { unreliable: false };
    }
    /**
     * Make a tensor info for the output of an op. If `memoryOffset` is not
     * present, this method allocates memory on the WASM heap. If `memoryOffset`
     * is present, the memory was allocated elsewhere (in c++) and we just record
     * the pointer where that memory lives.
     */
    makeOutput(shape, dtype, memoryOffset) {
        let dataId;
        if (memoryOffset == null) {
            dataId = this.write(null /* values */, shape, dtype);
        }
        else {
            dataId = {};
            const id = this.dataIdNextNumber++;
            this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype });
            const size = tfjsCore.util.sizeFromShape(shape);
            this.wasm.tfjs.registerTensor(id, size, memoryOffset);
        }
        return { dataId, shape, dtype };
    }
    /**
     * Returns a typed-array view (no copy) over the tensor's bytes on the
     * wasm heap.
     */
    typedArrayFromHeap({ shape, dtype, dataId }) {
        const buffer = this.wasm.HEAPU8.buffer;
        const { memoryOffset } = this.dataIdMap.get(dataId);
        const size = tfjsCore.util.sizeFromShape(shape);
        switch (dtype) {
            case 'float32':
                return new Float32Array(buffer, memoryOffset, size);
            case 'int32':
                return new Int32Array(buffer, memoryOffset, size);
            case 'bool':
                return new Uint8Array(buffer, memoryOffset, size);
            default:
                throw new Error(`Unknown dtype ${dtype}`);
        }
    }
}
tfjsCore.registerBackend('wasm', async () => {
    const { wasm } = await init();
    return new BackendWasm(wasm);
}, WASM_PRIORITY);
|
||
|
/**
 * Builds an Emscripten `instantiateWasm` override. It downloads the binary at
 * `path` with `tfjsCore.util.fetch`, instantiates it with the provided
 * imports, and passes the instance to Emscripten's `callback`.
 *
 * Any failure (non-OK response, rejected fetch, compile or instantiate error)
 * is reported at most once through `imports.env.a`.
 * NOTE(review): `a` is a minified Emscripten import name that upstream uses
 * as the abort hook; confirm the mapping whenever the Emscripten build changes.
 *
 * @param path URL of the `.wasm` binary to fetch.
 * @returns `(imports, callback) => {}`. It returns `{}` synchronously because
 *     the instance is delivered asynchronously via `callback`.
 */
function createInstantiateWasmFunc(path) {
    // tslint:disable-next-line:no-any
    return (imports, callback) => {
        let failed = false;
        // Report a failure only once; a second report would just duplicate output.
        const fail = (message) => {
            if (failed) {
                return;
            }
            failed = true;
            imports.env.a(message);
        };
        tfjsCore.util.fetch(path, { credentials: 'same-origin' })
            .then((response) => {
                if (!response['ok']) {
                    // Do not try to compile an error page as WebAssembly.
                    fail(`failed to load wasm binary file at '${path}'`);
                    return;
                }
                return response.arrayBuffer()
                    .then(binary => WebAssembly.instantiate(binary, imports))
                    .then(output => {
                        callback(output.instance);
                    });
            })
            .catch((error) => {
                // Without this, a rejected fetch/instantiate went unreported, and
                // init() presumably stayed pending because neither the callback
                // nor the abort hook ever ran.
                fail(`failed to load wasm binary file at '${path}': ${error}`);
            });
        return {};
    };
}
|
||
|
/**
 * Resolves where the WASM binary should be loaded from.
 *
 * Precedence: an explicit full `wasmPath` wins outright. Otherwise the binary
 * name is chosen from the feature flags (threaded-SIMD, SIMD, or vanilla) and
 * looked up in `wasmFileMap`. If there is no mapping, the name is appended to
 * `wasmModuleFolder`.
 *
 * @param simdSupported whether SIMD is supported
 * @param threadsSupported whether multithreading is supported
 * @param wasmModuleFolder the directory containing the WASM binaries.
 * @returns The path or URL of the binary to load.
 */
function getPathToWasmBinary(simdSupported, threadsSupported, wasmModuleFolder) {
    // A user-supplied full path to the vanilla .wasm binary overrides everything.
    if (wasmPath != null) {
        return wasmPath;
    }
    let binaryName;
    if (!simdSupported) {
        binaryName = 'tfjs-backend-wasm.wasm';
    }
    else if (threadsSupported) {
        binaryName = 'tfjs-backend-wasm-threaded-simd.wasm';
    }
    else {
        binaryName = 'tfjs-backend-wasm-simd.wasm';
    }
    const mappedPath = wasmFileMap == null ? null : wasmFileMap[binaryName];
    return mappedPath != null ? mappedPath : wasmModuleFolder + binaryName;
}
|
||
|
/**
 * Initializes the wasm module and creates the js <--> wasm bridge.
 *
 * Waits for the SIMD / multithreading feature flags. It then instantiates the
 * threaded-SIMD Emscripten module when both are available and no explicit
 * `wasmPath` was configured; otherwise it uses the vanilla module, which
 * serves both the plain and the SIMD binaries.
 *
 * NOTE: We wrap the wasm module in a object with property 'wasm' instead of
 * returning Promise<BackendWasmModule> to avoid freezing Chrome (last tested
 * in Chrome 76).
 *
 * @returns A promise that resolves to `{wasm}` from `onRuntimeInitialized`.
 *     It rejects with an `Error` the first time the module aborts before
 *     initialization completes.
 */
async function init() {
    const [simdSupported, threadsSupported] = await Promise.all([
        tfjsCore.env().getAsync('WASM_HAS_SIMD_SUPPORT'),
        tfjsCore.env().getAsync('WASM_HAS_MULTITHREAD_SUPPORT')
    ]);
    return new Promise((resolve, reject) => {
        const factoryConfig = {};
        /**
         * This function overrides the Emscripten module locateFile utility.
         * @param path The relative path to the file that needs to be loaded.
         * @param prefix The path to the main JavaScript file's directory.
         */
        factoryConfig.locateFile = (path, prefix) => {
            if (path.endsWith('.worker.js')) {
                // The worker script is inlined in this bundle; serve it from a blob URL.
                const response = wasmWorkerContents;
                const blob = new Blob([response], { type: 'application/javascript' });
                return URL.createObjectURL(blob);
            }
            if (path.endsWith('.wasm')) {
                return getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : prefix);
            }
            return prefix + path;
        };
        // Use the instantiateWasm override when system fetch is not available.
        // Reference:
        // https://github.com/emscripten-core/emscripten/blob/2bca083cbbd5a4133db61fbd74d04f7feecfa907/tests/manual_wasm_instantiate.html#L170
        if (customFetch) {
            factoryConfig.instantiateWasm =
                createInstantiateWasmFunc(getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : ''));
        }
        let wasm;
        // If `wasmPath` has been defined we must initialize the vanilla module.
        if (threadsSupported && simdSupported && wasmPath == null) {
            wasm = tfjsBackendWasmThreadedSimd(factoryConfig);
            // Source of the factory, presumably handed to pthread workers as
            // `urlOrBlob`; the worker's "load" handler imports it and then calls
            // WasmBackendModuleThreadedSimd(Module).
            wasm.mainScriptUrlOrBlob = new Blob([`var _scriptDir = undefined; var WasmBackendModuleThreadedSimd = ` +
                    tfjsBackendWasmThreadedSimd.toString()], { type: 'text/javascript' });
        }
        else {
            // The wasmFactory works for both vanilla and SIMD binaries.
            wasm = tfjsBackendWasm(factoryConfig);
        }
        const voidReturnType = null;
        // Using the tfjs namespace to avoid conflict with emscripten's API.
        wasm.tfjs = {
            init: wasm.cwrap('init', null, []),
            registerTensor: wasm.cwrap('register_tensor', null, [
                'number',
                'number',
                'number',
            ]),
            disposeData: wasm.cwrap('dispose_data', voidReturnType, ['number']),
            dispose: wasm.cwrap('dispose', voidReturnType, []),
        };
        let initialized = false;
        wasm.onRuntimeInitialized = () => {
            initialized = true;
            initAborted = false;
            resolve({ wasm });
        };
        wasm.onAbort = () => {
            if (initialized) {
                // Emscripten already called console.warn so no need to double log.
                return;
            }
            if (initAborted) {
                // Emscripten calls `onAbort` twice, resulting in double error
                // messages.
                return;
            }
            initAborted = true;
            const rejectMsg = 'Make sure the server can serve the `.wasm` file relative to the ' +
                'bundled js file. For more details see https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers';
            // Reject with a real Error rather than a bare `{message}` object.
            // Callers then get a stack trace and `instanceof Error` holds;
            // `.message` is unchanged.
            reject(new Error(rejectMsg));
        };
    });
}
|
||
|
/**
 * Wraps `buffer` in the typed array that matches a tensor dtype.
 *
 * @param buffer ArrayBuffer holding the tensor bytes (not copied).
 * @param dtype 'float32', 'int32' or 'bool'.
 * @returns Float32Array, Int32Array or Uint8Array over the whole buffer.
 * @throws Error for any other dtype.
 */
function typedArrayFromBuffer(buffer, dtype) {
    if (dtype === 'float32') {
        return new Float32Array(buffer);
    }
    if (dtype === 'int32') {
        return new Int32Array(buffer);
    }
    if (dtype === 'bool') {
        return new Uint8Array(buffer);
    }
    throw new Error(`Unknown dtype ${dtype}`);
}
|
||
|
/** File names of the WASM binaries this backend knows how to load. */
const wasmBinaryNames = [
    'tfjs-backend-wasm.wasm',
    'tfjs-backend-wasm-simd.wasm',
    'tfjs-backend-wasm-threaded-simd.wasm'
];
// Full path to the vanilla binary (deprecated setWasmPath); when non-null it
// takes precedence over the prefix and the file map.
let wasmPath = null;
// Directory prefix configured via setWasmPaths(string).
let wasmPathPrefix = null;
// Binary-name -> location overrides configured via setWasmPaths(object).
let wasmFileMap = {};
// Set when the module aborts during init; cleared once the runtime initializes.
let initAborted = false;
// Whether binaries are downloaded via tfjsCore.util.fetch (instantiateWasm override).
let customFetch = false;
|
||
|
/**
 * @deprecated Use `setWasmPaths` instead.
 * Sets the full path (or URL) of the `.wasm` file fetched when the wasm
 * backend initializes. Setting it forces the vanilla (non-threaded) module.
 * See
 * https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers
 * for more details.
 * @param path wasm file path or url
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 * @throws Error if backend initialization has already aborted.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
function setWasmPath(path, usePlatformFetch = false) {
    const deprecationMessage = 'setWasmPath has been deprecated in favor of setWasmPaths and' +
        ' will be removed in a future release.';
    tfjsCore.deprecationWarn(deprecationMessage);
    if (!initAborted) {
        wasmPath = path;
        customFetch = usePlatformFetch;
        return;
    }
    throw new Error('The WASM backend was already initialized. Make sure you call ' +
        '`setWasmPath()` before you call `tf.setBackend()` or `tf.ready()`');
}
|
||
|
/**
 * Configures the locations of the WASM binaries.
 *
 * ```js
 * setWasmPaths({
 *  'tfjs-backend-wasm.wasm': 'renamed.wasm',
 *  'tfjs-backend-wasm-simd.wasm': 'renamed-simd.wasm',
 *  'tfjs-backend-wasm-threaded-simd.wasm': 'renamed-threaded-simd.wasm'
 * });
 * tf.setBackend('wasm');
 * ```
 *
 * @param prefixOrFileMap This can be either a string or object:
 *  - (string) The path to the directory where the WASM binaries are located.
 *     Note that this prefix will be used to load each binary (vanilla,
 *     SIMD-enabled, threading-enabled, etc.).
 *  - (object) Mapping from names of WASM binaries to custom
 *     full paths specifying the locations of those binaries. This is useful if
 *     your WASM binaries are not all located in the same directory, or if your
 *     WASM binaries have been renamed. Every known binary name must be mapped.
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 * @throws Error if backend initialization has already aborted, or if the map
 *     lacks an entry for any binary. On error no configuration is changed.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
function setWasmPaths(prefixOrFileMap, usePlatformFetch = false) {
    if (initAborted) {
        throw new Error('The WASM backend was already initialized. Make sure you call ' +
            '`setWasmPaths()` before you call `tf.setBackend()` or ' +
            '`tf.ready()`');
    }
    if (typeof prefixOrFileMap === 'string') {
        wasmPathPrefix = prefixOrFileMap;
    }
    else {
        // Validate before committing. A rejected map must not replace the
        // current wasmFileMap, or a caller recovering from the error would be
        // left with a partial mapping that overrides any later prefix.
        const missingPaths = wasmBinaryNames.filter(name => prefixOrFileMap[name] == null);
        if (missingPaths.length > 0) {
            throw new Error(`There were no entries found for the following binaries: ` +
                `${missingPaths.join(',')}. Please either call setWasmPaths with a ` +
                `map providing a path for each binary, or with a string indicating ` +
                `the directory where all the binaries can be found.`);
        }
        wasmFileMap = prefixOrFileMap;
    }
    customFetch = usePlatformFetch;
}
|
||
|
|
||
|
/** @license See the LICENSE file. */
|
||
|
// This code is auto-generated, do not modify this file!
|
||
|
const version = '2.7.0';

// Public API of this bundle.
Object.assign(exports, {
    BackendWasm,
    setWasmPath,
    setWasmPaths,
    version_wasm: version,
});
// Conventional `__esModule` marker consulted by ESM/CommonJS interop helpers.
Object.defineProperty(exports, '__esModule', { value: true });
|
||
|
|
||
|
})));
|
||
|
//# sourceMappingURL=tf-backend-wasm.es2017.js.map
|