37 lines
1.3 KiB
JavaScript
37 lines
1.3 KiB
JavaScript
"use strict";
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.residualDown = exports.residual = void 0;
|
|
const tf = require("@tensorflow/tfjs-core");
|
|
const convLayer_1 = require("./convLayer");
|
|
function residual(x, params) {
|
|
let out = convLayer_1.conv(x, params.conv1);
|
|
out = convLayer_1.convNoRelu(out, params.conv2);
|
|
out = tf.add(out, x);
|
|
out = tf.relu(out);
|
|
return out;
|
|
}
|
|
exports.residual = residual;
|
|
function residualDown(x, params) {
|
|
let out = convLayer_1.convDown(x, params.conv1);
|
|
out = convLayer_1.convNoRelu(out, params.conv2);
|
|
let pooled = tf.avgPool(x, 2, 2, 'valid');
|
|
const zeros = tf.zeros(pooled.shape);
|
|
const isPad = pooled.shape[3] !== out.shape[3];
|
|
const isAdjustShape = pooled.shape[1] !== out.shape[1] || pooled.shape[2] !== out.shape[2];
|
|
if (isAdjustShape) {
|
|
const padShapeX = [...out.shape];
|
|
padShapeX[1] = 1;
|
|
const zerosW = tf.zeros(padShapeX);
|
|
out = tf.concat([out, zerosW], 1);
|
|
const padShapeY = [...out.shape];
|
|
padShapeY[2] = 1;
|
|
const zerosH = tf.zeros(padShapeY);
|
|
out = tf.concat([out, zerosH], 2);
|
|
}
|
|
pooled = isPad ? tf.concat([pooled, zeros], 3) : pooled;
|
|
out = tf.add(pooled, out);
|
|
out = tf.relu(out);
|
|
return out;
|
|
}
|
|
exports.residualDown = residualDown;
|
|
//# sourceMappingURL=residualLayer.js.map
|