2020-12-23 18:58:47 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
2020-12-19 17:46:41 +01:00
|
|
|
import { disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
|
|
|
|
import { isTensor2D } from '../utils/index';
|
2021-01-24 17:08:04 +01:00
|
|
|
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
|
|
|
/**
 * Creates the layer-parameter readers used by `extractParamsFromWeightMap`.
 *
 * All reads go through the single entry reader produced by
 * `extractWeightEntryFactory(weightMap, paramMappings)`, so every helper
 * returned here shares the same weight map and the same `paramMappings` array.
 *
 * @param weightMap - the loaded weights (typed `any` here; the only caller in
 *   this file passes a `tf.NamedTensorMap`).
 * @param paramMappings - array handed through to `extractWeightEntryFactory`.
 * @returns `extractConvLayerParams` and `extractResidualLayerParams`, each
 *   taking a weight-name prefix.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const readEntry = extractWeightEntryFactory(weightMap, paramMappings);

  // Reads `<prefix>/scale/weights` then `<prefix>/scale/biases`.
  // NOTE(review): the numeric argument looks like an expected tensor rank —
  // confirm against extractWeightEntryFactory.
  const readScaleParams = (prefix: string): ScaleLayerParams => {
    const weights = readEntry(`${prefix}/scale/weights`, 1);
    const biases = readEntry(`${prefix}/scale/biases`, 1);
    return { weights, biases };
  };

  // Reads the conv filters and bias first, then the scale pair under the
  // same prefix; the object literal evaluates left-to-right, preserving that order.
  const extractConvLayerParams = (prefix: string): ConvLayerParams => {
    const filters = readEntry(`${prefix}/conv/filters`, 4);
    const bias = readEntry(`${prefix}/conv/bias`, 1);
    return {
      conv: { filters, bias },
      scale: readScaleParams(prefix),
    };
  };

  // A residual layer is two conv layers stored under `<prefix>/conv1` and
  // `<prefix>/conv2`, read in that order.
  const extractResidualLayerParams = (prefix: string): ResidualLayerParams => {
    const conv1 = extractConvLayerParams(`${prefix}/conv1`);
    const conv2 = extractConvLayerParams(`${prefix}/conv2`);
    return { conv1, conv2 };
  };

  return { extractConvLayerParams, extractResidualLayerParams };
}
|
|
|
|
|
2021-01-12 16:14:33 +01:00
|
|
|
/**
 * Assembles the network parameters from a loaded weight map.
 *
 * Each conv / residual layer is read through the helpers returned by
 * `extractorsFactory`, which receives the `paramMappings` array built here.
 * The `fc` tensor is taken straight from `weightMap`, recorded in
 * `paramMappings` under the path `'fc'`, and must pass `isTensor2D`.
 * Afterwards `disposeUnusedWeightTensors` is called with the weight map and
 * the collected mappings.
 *
 * @param weightMap - tensors keyed by their weight names.
 * @returns the assembled `params` and the `paramMappings` gathered while
 *   reading them.
 * @throws Error when `weightMap.fc` does not pass `isTensor2D`.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];
  const extractors = extractorsFactory(weightMap, paramMappings);
  const convLayer = extractors.extractConvLayerParams;
  const residualLayer = extractors.extractResidualLayerParams;

  // Layers are read strictly top-to-bottom. Keep this order: the readers are
  // handed paramMappings and presumably record entries in call order
  // (assumption — confirm in extractWeightEntryFactory).
  const layers = {
    conv32_down: convLayer('conv32_down'),
    conv32_1: residualLayer('conv32_1'),
    conv32_2: residualLayer('conv32_2'),
    conv32_3: residualLayer('conv32_3'),

    conv64_down: residualLayer('conv64_down'),
    conv64_1: residualLayer('conv64_1'),
    conv64_2: residualLayer('conv64_2'),
    conv64_3: residualLayer('conv64_3'),

    conv128_down: residualLayer('conv128_down'),
    conv128_1: residualLayer('conv128_1'),
    conv128_2: residualLayer('conv128_2'),

    conv256_down: residualLayer('conv256_down'),
    conv256_1: residualLayer('conv256_1'),
    conv256_2: residualLayer('conv256_2'),
    conv256_down_out: residualLayer('conv256_down_out'),
  };

  // The fully connected weights are not read via the extractors, so their
  // mapping is registered by hand before validation.
  const { fc } = weightMap;
  paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });
  if (!isTensor2D(fc)) {
    throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`);
  }

  const params = { ...layers, fc };

  disposeUnusedWeightTensors(weightMap, paramMappings);

  return { params, paramMappings };
}
|