28 lines
714 B
TypeScript
28 lines
714 B
TypeScript
import * as tf from '../../dist/tfjs.esm.js';
|
|
|
|
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
|
|
|
|
export function extractFCParamsFactory(
|
|
extractWeights: ExtractWeightsFunction,
|
|
paramMappings: ParamMapping[],
|
|
) {
|
|
return (
|
|
channelsIn: number,
|
|
channelsOut: number,
|
|
mappedPrefix: string,
|
|
): FCParams => {
|
|
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
|
|
const fc_bias = tf.tensor1d(extractWeights(channelsOut));
|
|
|
|
paramMappings.push(
|
|
{ paramPath: `${mappedPrefix}/weights` },
|
|
{ paramPath: `${mappedPrefix}/bias` },
|
|
);
|
|
|
|
return {
|
|
weights: fc_weights,
|
|
bias: fc_bias,
|
|
};
|
|
};
|
|
}
|