2020-12-23 18:58:47 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
2021-03-19 23:46:36 +01:00
|
|
|
import { disposeUnusedWeightTensors, extractWeightEntryFactory, loadSeparableConvParamsFactory, ParamMapping } from '../common/index';
|
2020-08-18 13:54:53 +02:00
|
|
|
import { loadConvParamsFactory } from '../common/loadConvParamsFactory';
|
2020-12-19 17:46:41 +01:00
|
|
|
import { range } from '../utils/index';
|
2020-08-18 13:54:53 +02:00
|
|
|
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';
|
|
|
|
|
|
|
|
/**
 * Builds the parameter extractors used to read TinyXception weights from a weight map.
 *
 * `weightMap` and `paramMappings` are handed to `extractWeightEntryFactory`. The
 * resulting entry extractor is shared by the conv and separable-conv loaders.
 * Presumably each read appends to `paramMappings` (TODO confirm). If so, the order
 * in which the returned extractors are invoked determines the mapping order.
 *
 * NOTE(review): `weightMap` is typed `any`. Consider tightening it (e.g. to
 * `tf.NamedTensorMap`) once `extractWeightEntryFactory`'s accepted type is confirmed.
 *
 * @param weightMap - weights to read from; forwarded unchanged to `extractWeightEntryFactory`.
 * @param paramMappings - mapping list forwarded to `extractWeightEntryFactory`.
 * @returns extractors for plain convs, separable convs, reduction blocks and main blocks.
 */
function loadParamsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const extractEntry = extractWeightEntryFactory(weightMap, paramMappings);
  const convParams = loadConvParamsFactory(extractEntry);
  const separableConvParams = loadSeparableConvParamsFactory(extractEntry);

  // Reduction block: two separable convs followed by the expansion conv.
  // Object-literal properties evaluate left to right, so the read order is fixed.
  const reductionBlock = (mappedPrefix: string): ReductionBlockParams => ({
    separable_conv0: separableConvParams(`${mappedPrefix}/separable_conv0`),
    separable_conv1: separableConvParams(`${mappedPrefix}/separable_conv1`),
    expansion_conv: convParams(`${mappedPrefix}/expansion_conv`),
  });

  // Main (middle-flow) block: three separable convs, read in index order.
  const mainBlock = (mappedPrefix: string): MainBlockParams => ({
    separable_conv0: separableConvParams(`${mappedPrefix}/separable_conv0`),
    separable_conv1: separableConvParams(`${mappedPrefix}/separable_conv1`),
    separable_conv2: separableConvParams(`${mappedPrefix}/separable_conv2`),
  });

  return {
    extractConvParams: convParams,
    extractSeparableConvParams: separableConvParams,
    extractReductionBlockParams: reductionBlock,
    extractMainBlockParams: mainBlock,
  };
}
|
|
|
|
|
2021-01-12 16:14:33 +01:00
|
|
|
/**
 * Assembles TinyXception parameters from a loaded weight map.
 *
 * Reads, in this order:
 * 1. the entry-flow input conv and its two reduction blocks;
 * 2. `numMainBlocks` middle-flow main blocks (`main_block_0` … `main_block_{numMainBlocks - 1}`);
 * 3. the exit-flow reduction block and separable conv.
 *
 * After reading, `disposeUnusedWeightTensors` is called with the weight map and the
 * collected mappings. Presumably this releases tensors that no mapping references
 * (TODO confirm).
 *
 * @param weightMap - named weight tensors to read from.
 * @param numMainBlocks - how many middle-flow main blocks to read.
 * @returns the assembled params together with the mappings collected while reading them.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
  numMainBlocks: number,
): { params: TinyXceptionParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];

  const {
    extractConvParams: readConv,
    extractSeparableConvParams: readSeparableConv,
    extractReductionBlockParams: readReductionBlock,
    extractMainBlockParams: readMainBlock,
  } = loadParamsFactory(weightMap, paramMappings);

  // Object-literal properties evaluate left to right, so the reads below happen in
  // the same fixed order as the field order. This matters if each read appends to
  // paramMappings.
  const entry_flow = {
    conv_in: readConv('entry_flow/conv_in'),
    reduction_block_0: readReductionBlock('entry_flow/reduction_block_0'),
    reduction_block_1: readReductionBlock('entry_flow/reduction_block_1'),
  };

  const middle_flow: Record<`main_block_${number}`, MainBlockParams> = {};
  range(numMainBlocks, 0, 1).forEach((blockIdx) => {
    const blockKey: `main_block_${number}` = `main_block_${blockIdx}`;
    middle_flow[blockKey] = readMainBlock(`middle_flow/${blockKey}`);
  });

  const exit_flow = {
    reduction_block: readReductionBlock('exit_flow/reduction_block'),
    separable_conv: readSeparableConv('exit_flow/separable_conv'),
  };

  disposeUnusedWeightTensors(weightMap, paramMappings);

  return { params: { entry_flow, middle_flow, exit_flow }, paramMappings };
}
|