face-api/src/common/extractFCParamsFactory.ts

32 lines
708 B
TypeScript
Raw Normal View History

2020-10-09 03:31:31 +02:00
import * as tf from '@tensorflow/tfjs';
2020-08-18 13:54:53 +02:00
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
/**
 * Creates an extractor for the parameters of a fully connected layer.
 *
 * @param extractWeights - Callback invoked with the number of values needed;
 *   its result is passed straight to the tensor constructors.
 * @param paramMappings - Array that receives one entry per extracted tensor.
 *   It is mutated in place on every call of the returned function.
 * @returns A function that extracts the weights and bias for one layer.
 */
export function extractFCParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[]
) {
  /**
   * Extracts one layer: a [channelsIn, channelsOut] weight matrix followed by
   * a bias vector built from `channelsOut` requested values, and records both
   * under `mappedPrefix`.
   */
  return function(
    channelsIn: number,
    channelsOut: number,
    mappedPrefix: string
  ): FCParams {
    // Order matters: the weight values are requested before the bias values.
    const weightValues = extractWeights(channelsIn * channelsOut)
    const weights = tf.tensor2d(weightValues, [channelsIn, channelsOut])

    const biasValues = extractWeights(channelsOut)
    const bias = tf.tensor1d(biasValues)

    // Register the paths only after both tensors were built successfully.
    const weightsMapping = { paramPath: `${mappedPrefix}/weights` }
    const biasMapping = { paramPath: `${mappedPrefix}/bias` }
    paramMappings.push(weightsMapping, biasMapping)

    return { weights, bias }
  }
}