2020-12-23 18:58:47 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
|
|
|
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
|
|
|
|
|
|
|
|
/**
 * Creates a reader for the parameters of a single convolution layer.
 *
 * Each call of the returned reader does the following, in this order:
 * - pulls the kernel weights from `extractWeights` and wraps them in a rank-4 tensor;
 * - pulls `channelsOut` bias weights and wraps them in a rank-1 tensor;
 * - appends `<mappedPrefix>/filters` and `<mappedPrefix>/bias` to `paramMappings`.
 *
 * @param extractWeights - Callback invoked with a weight count. Its result is
 *   passed straight to the tensor constructors. Presumably it consumes weights
 *   sequentially from a shared buffer — confirm against the implementation.
 * @param paramMappings - Collector array that is mutated in place; it gains two
 *   entries per reader call.
 * @returns A function `(channelsIn, channelsOut, filterSize, mappedPrefix)` that
 *   returns the layer's `{ filters, bias }`.
 */
export function extractConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  return function extractConvParams(
    channelsIn: number,
    channelsOut: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams {
    // Square kernel: filterSize x filterSize spatial extent, channelsIn -> channelsOut.
    const kernelWeightCount = channelsIn * channelsOut * filterSize * filterSize;
    const kernelShape: [number, number, number, number] = [filterSize, filterSize, channelsIn, channelsOut];

    // Read order matters when extractWeights is cursor-based: kernel first, then bias.
    const filters = tf.tensor4d(extractWeights(kernelWeightCount), kernelShape);
    const bias = tf.tensor1d(extractWeights(channelsOut));

    // Record the paths in the same order the tensors were read.
    paramMappings.push({ paramPath: `${mappedPrefix}/filters` });
    paramMappings.push({ paramPath: `${mappedPrefix}/bias` });

    const params: ConvParams = { filters, bias };
    return params;
  };
}
|