// face-api/build/common/extractSeparableConvParamsFactory.js
// 25 lines, 1.5 KiB, JavaScript

"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.loadSeparableConvParamsFactory = exports.extractSeparableConvParamsFactory = void 0;
const tf = require("@tensorflow/tfjs-core");
const types_1 = require("./types");
/**
 * Creates a reader that builds the parameters of one separable-convolution
 * layer from a sequential weight source.
 *
 * The returned function pulls weights in a fixed order (depthwise filter,
 * then pointwise filter, then bias). Callers that serialise weights must use
 * the same order.
 *
 * @param {(count: number) => *} extractWeights - Called with an element count
 *   and expected to return that many values, suitable for `tf.tensor4d` /
 *   `tf.tensor1d`. Presumably it advances through a flat weight buffer on
 *   each call; TODO confirm against the caller.
 * @param {Array<{paramPath: string}>} paramMappings - Receives one
 *   `{ paramPath }` entry per tensor created, in creation order.
 * @returns {(channelsIn: number, channelsOut: number, mappedPrefix: string) => SeparableConvParams}
 *   A function that extracts one layer and registers its three parameter
 *   paths under `mappedPrefix`.
 */
function extractSeparableConvParamsFactory(extractWeights, paramMappings) {
    return function (channelsIn, channelsOut, mappedPrefix) {
        // Depthwise kernel: 3x3 window, shape [3, 3, channelsIn, 1].
        const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
        // Pointwise kernel: 1x1 window mapping channelsIn -> channelsOut.
        const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
        // One bias value per output channel.
        const bias = tf.tensor1d(extractWeights(channelsOut));
        paramMappings.push(
            { paramPath: `${mappedPrefix}/depthwise_filter` },
            { paramPath: `${mappedPrefix}/pointwise_filter` },
            { paramPath: `${mappedPrefix}/bias` },
        );
        return new types_1.SeparableConvParams(depthwise_filter, pointwise_filter, bias);
    };
}
// Publish the extraction factory on the CommonJS exports object.
Object.assign(exports, { extractSeparableConvParamsFactory });
/**
 * Creates a loader that assembles a separable-convolution layer's parameters
 * from named weight entries.
 *
 * @param {(path: string, rank: number) => *} extractWeightEntry - Called with
 *   a weight path and an expected rank (4 for the two filters, 1 for the
 *   bias). It presumably returns the tensor stored under that path; TODO
 *   confirm against the caller.
 * @returns {(prefix: string) => SeparableConvParams} A function that reads
 *   `${prefix}/depthwise_filter`, `${prefix}/pointwise_filter` and
 *   `${prefix}/bias`, in that order, and wraps them in `SeparableConvParams`.
 */
function loadSeparableConvParamsFactory(extractWeightEntry) {
    return function (prefix) {
        const depthwise_filter = extractWeightEntry(`${prefix}/depthwise_filter`, 4);
        const pointwise_filter = extractWeightEntry(`${prefix}/pointwise_filter`, 4);
        const bias = extractWeightEntry(`${prefix}/bias`, 1);
        return new types_1.SeparableConvParams(depthwise_filter, pointwise_filter, bias);
    };
}
// Publish the named-entry loader factory on the CommonJS exports object.
Object.assign(exports, { loadSeparableConvParamsFactory });
// 2020-08-18 14:04:33 +02:00
//# sourceMappingURL=extractSeparableConvParamsFactory.js.map