face-api/src/common/extractFCParamsFactory.ts

28 lines
714 B
TypeScript
Raw Normal View History

import * as tf from '../../dist/tfjs.esm.js';
2020-08-18 13:54:53 +02:00
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
/**
 * Creates an extractor for the parameters of an FC layer (weights + bias).
 *
 * @param extractWeights - Called with the number of values to read; its result
 *   is passed straight to the tensor constructors.
 * @param paramMappings - Array that every extractor call appends two entries
 *   to: `<mappedPrefix>/weights` and `<mappedPrefix>/bias`. It is mutated in place.
 * @returns A function that, given the input/output channel counts and a path
 *   prefix, returns the layer's `weights` (2D, shape `[channelsIn, channelsOut]`)
 *   and `bias` (1D) tensors.
 */
export function extractFCParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  return (
    channelsIn: number,
    channelsOut: number,
    mappedPrefix: string,
  ): FCParams => {
    // Order matters: the weights are requested before the bias.
    // NOTE(review): extractWeights presumably consumes a shared buffer
    // sequentially, so these two calls must not be reordered — confirm.
    const weights = tf.tensor2d(
      extractWeights(channelsIn * channelsOut),
      [channelsIn, channelsOut],
    );
    const bias = tf.tensor1d(extractWeights(channelsOut));

    // Record the parameter paths in the same order the values were read.
    paramMappings.push(
      { paramPath: `${mappedPrefix}/weights` },
      { paramPath: `${mappedPrefix}/bias` },
    );

    return { weights, bias };
  };
}