face-api/src/xception/extractParams.ts

81 lines
3.4 KiB
TypeScript
Raw Normal View History

2020-12-19 17:46:41 +01:00
import { extractConvParamsFactory, extractSeparableConvParamsFactory, extractWeightsFactory } from '../common/index';
2020-08-18 13:54:53 +02:00
import { ExtractWeightsFunction, ParamMapping } from '../common/types';
2020-12-19 17:46:41 +01:00
import { range } from '../utils/index';
2020-08-18 13:54:53 +02:00
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';
/**
 * Builds the block-level weight extractors used by the TinyXception loader.
 *
 * Both underlying factories receive the same `extractWeights` reader and the same
 * `paramMappings` array. The extractors presumably read weights sequentially from that
 * shared reader, so the order in which they are called matters. Keep the per-block call
 * order below unchanged.
 *
 * @param extractWeights - weight reader shared by every returned extractor.
 * @param paramMappings - mapping list handed to the underlying conv extractor factories.
 * @returns conv and separable-conv extractors, plus composite extractors for reduction
 *   blocks and main blocks.
 */
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
  const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
  const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);

  /**
   * Reads one reduction block, in this order:
   * separable conv (channelsIn -> channelsOut), separable conv (channelsOut -> channelsOut),
   * then a conv with filter-size argument 1 (channelsIn -> channelsOut).
   */
  const extractReductionBlockParams = (
    channelsIn: number,
    channelsOut: number,
    mappedPrefix: string,
  ): ReductionBlockParams => ({
    // Object-literal initializers evaluate top to bottom, which keeps the read order fixed.
    separable_conv0: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/separable_conv0`),
    separable_conv1: extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/separable_conv1`),
    expansion_conv: extractConvParams(channelsIn, channelsOut, 1, `${mappedPrefix}/expansion_conv`),
  });

  /** Reads one main block: three channel-preserving separable convs (channels -> channels). */
  const extractMainBlockParams = (channels: number, mappedPrefix: string): MainBlockParams => ({
    separable_conv0: extractSeparableConvParams(channels, channels, `${mappedPrefix}/separable_conv0`),
    separable_conv1: extractSeparableConvParams(channels, channels, `${mappedPrefix}/separable_conv1`),
    separable_conv2: extractSeparableConvParams(channels, channels, `${mappedPrefix}/separable_conv2`),
  });

  return {
    extractConvParams,
    extractSeparableConvParams,
    extractReductionBlockParams,
    extractMainBlockParams,
  };
}
/**
 * Decodes a flat weight buffer into the parameter tree of a TinyXception network.
 *
 * Blocks are extracted in this fixed order:
 * 1. Entry flow: `conv_in`, `reduction_block_0`, `reduction_block_1`.
 * 2. `numMainBlocks` middle-flow main blocks.
 * 3. Exit flow: `reduction_block`, `separable_conv`.
 *
 * `paramMappings` is handed to the extractor factories. It is presumably filled with one
 * entry per extracted tensor, using the path prefixes passed below.
 *
 * @param weights - flat buffer holding every weight of the network.
 * @param numMainBlocks - number of middle-flow main blocks to extract.
 * @returns the nested parameter tree together with the collected param mappings.
 * @throws Error if weights are still left in the buffer after all blocks were extracted.
 */
export function extractParams(weights: Float32Array, numMainBlocks: number): { params: TinyXceptionParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];

  const { extractWeights, getRemainingWeights } = extractWeightsFactory(weights);

  const {
    extractConvParams,
    extractSeparableConvParams,
    extractReductionBlockParams,
    extractMainBlockParams,
  } = extractorsFactory(extractWeights, paramMappings);

  // Entry flow: conv 3 -> 32 channels (filter-size argument 3), then reduction blocks 32 -> 64 -> 128.
  const entry_flow_conv_in = extractConvParams(3, 32, 3, 'entry_flow/conv_in');
  const entry_flow_reduction_block_0 = extractReductionBlockParams(32, 64, 'entry_flow/reduction_block_0');
  const entry_flow_reduction_block_1 = extractReductionBlockParams(64, 128, 'entry_flow/reduction_block_1');

  const entry_flow = {
    conv_in: entry_flow_conv_in,
    reduction_block_0: entry_flow_reduction_block_0,
    reduction_block_1: entry_flow_reduction_block_1,
  };

  // Middle flow: one 128-channel main block per index from range(numMainBlocks, 0, 1),
  // stored under `main_block_<idx>`.
  // Explicitly typed: a bare `{}` cannot be indexed by a string key under noImplicitAny.
  const middle_flow: Record<string, ReturnType<typeof extractMainBlockParams>> = {};
  range(numMainBlocks, 0, 1).forEach((idx) => {
    middle_flow[`main_block_${idx}`] = extractMainBlockParams(128, `middle_flow/main_block_${idx}`);
  });

  // Exit flow: reduction block 128 -> 256, then a separable conv 256 -> 512.
  const exit_flow_reduction_block = extractReductionBlockParams(128, 256, 'exit_flow/reduction_block');
  const exit_flow_separable_conv = extractSeparableConvParams(256, 512, 'exit_flow/separable_conv');

  const exit_flow = {
    reduction_block: exit_flow_reduction_block,
    separable_conv: exit_flow_separable_conv,
  };

  // Every weight must be consumed. Leftovers usually mean the buffer does not match this
  // architecture (for example, a wrong numMainBlocks).
  const remaining = getRemainingWeights().length;
  if (remaining !== 0) {
    throw new Error(`weights remaining after extract: ${remaining}`);
  }

  return {
    paramMappings,
    params: { entry_flow, middle_flow, exit_flow },
  };
}