// face-api/build/faceRecognitionNet/extractParamsFromWeigthMap.js
//
// 75 lines
// 3.0 KiB
// JavaScript
// Raw Normal View History

"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.extractParamsFromWeigthMap = void 0;
const common_1 = require("../common");
const utils_1 = require("../utils");
// 2020-08-18 14:04:33 +02:00
/**
 * Builds helpers that read the face recognition net's conv and residual
 * layer tensors out of a loaded weight map.
 *
 * Every lookup goes through the single entry reader created by
 * `common_1.extractWeightEntryFactory(weightMap, paramMappings)`. The same
 * `paramMappings` array is later handed to `disposeUnusedWeightTensors`, so
 * the order in which entries are read here is kept stable on purpose.
 *
 * @param {Object} weightMap - lookup from weight path (e.g.
 *   "conv32_1/conv1/conv/filters") to tensor.
 * @param {Array<Object>} paramMappings - accumulator passed straight through
 *   to `extractWeightEntryFactory`.
 * @returns {{extractConvLayerParams: function(string): Object,
 *   extractResidualLayerParams: function(string): Object}}
 */
function extractorsFactory(weightMap, paramMappings) {
    // The numeric second argument is presumably the expected tensor rank
    // (4 for conv filters, 1 for biases/scales) -- TODO confirm against
    // common/extractWeightEntryFactory.
    const extractWeightEntry = common_1.extractWeightEntryFactory(weightMap, paramMappings);
    /** Reads `<prefix>/scale/weights` then `<prefix>/scale/biases`. */
    function extractScaleLayerParams(prefix) {
        const weights = extractWeightEntry(`${prefix}/scale/weights`, 1);
        const biases = extractWeightEntry(`${prefix}/scale/biases`, 1);
        return { weights, biases };
    }
    /**
     * Reads one conv layer: `<prefix>/conv/filters`, `<prefix>/conv/bias`,
     * then its scale layer (in that order).
     */
    function extractConvLayerParams(prefix) {
        const filters = extractWeightEntry(`${prefix}/conv/filters`, 4);
        const bias = extractWeightEntry(`${prefix}/conv/bias`, 1);
        const scale = extractScaleLayerParams(prefix);
        return { conv: { filters, bias }, scale };
    }
    /** Reads a residual block made of `<prefix>/conv1` followed by `<prefix>/conv2`. */
    function extractResidualLayerParams(prefix) {
        return {
            conv1: extractConvLayerParams(`${prefix}/conv1`),
            conv2: extractConvLayerParams(`${prefix}/conv2`)
        };
    }
    return {
        extractConvLayerParams,
        extractResidualLayerParams
    };
}
/**
 * Extracts the face recognition net's parameters from a loaded weight map.
 *
 * Reads the initial `conv32_down` conv layer, then the fourteen residual
 * blocks listed below (in network order), then the `fc` tensor. After that
 * it passes the weight map and the collected mappings to
 * `common_1.disposeUnusedWeightTensors`. (The "Weigth" spelling is part of
 * the exported name and is kept for compatibility.)
 *
 * @param {Object} weightMap - lookup from weight path to tensor.
 * @returns {{params: Object, paramMappings: Array<Object>}} `params` is keyed
 *   by layer name (conv32_down ... conv256_down_out, fc). `paramMappings`
 *   lists every weight path that was read.
 * @throws {Error} if `weightMap['fc']` fails `utils_1.isTensor2D`. In that
 *   case `disposeUnusedWeightTensors` is not reached.
 */
function extractParamsFromWeigthMap(weightMap) {
    // Residual blocks in extraction order. The order matters because it fixes
    // both the order of paramMappings entries and the key order of params.
    const residualLayerNames = [
        'conv32_1', 'conv32_2', 'conv32_3',
        'conv64_down', 'conv64_1', 'conv64_2', 'conv64_3',
        'conv128_down', 'conv128_1', 'conv128_2',
        'conv256_down', 'conv256_1', 'conv256_2',
        'conv256_down_out'
    ];
    const paramMappings = [];
    const { extractConvLayerParams, extractResidualLayerParams } = extractorsFactory(weightMap, paramMappings);
    const params = { conv32_down: extractConvLayerParams('conv32_down') };
    for (const name of residualLayerNames) {
        params[name] = extractResidualLayerParams(name);
    }
    // fc is a single tensor, not a conv/residual group, so it is read and
    // recorded directly instead of through the extractors.
    const fc = weightMap['fc'];
    paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });
    if (!utils_1.isTensor2D(fc)) {
        throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`);
    }
    params.fc = fc;
    common_1.disposeUnusedWeightTensors(weightMap, paramMappings);
    return { params, paramMappings };
}
exports.extractParamsFromWeigthMap = extractParamsFromWeigthMap;
// 2020-08-18 14:04:33 +02:00
//# sourceMappingURL=extractParamsFromWeigthMap.js.map