face-api/src/common/extractConvParamsFactory.ts

29 lines
766 B
TypeScript

import * as tf from '../../dist/tfjs.esm';
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
/**
 * Creates a helper that reads the parameters of a single convolution layer
 * through `extractWeights` and registers their paths in `paramMappings`.
 *
 * @param extractWeights - Called with the number of values required; whatever
 *   it returns is handed directly to the tensor constructors.
 * @param paramMappings - Mutated in place: each call of the returned function
 *   appends one `<prefix>/filters` and one `<prefix>/bias` entry.
 * @returns A function taking `(channelsIn, channelsOut, filterSize,
 *   mappedPrefix)` and returning the layer's `{ filters, bias }` tensors.
 */
export function extractConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  return (
    channelsIn: number,
    channelsOut: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams => {
    // Filter kernel: a filterSize x filterSize window per (in, out) channel pair.
    const kernelShape: [number, number, number, number] = [
      filterSize,
      filterSize,
      channelsIn,
      channelsOut,
    ];
    const kernelValueCount = channelsIn * channelsOut * filterSize * filterSize;

    // NOTE(review): filter values are requested before bias values. The order
    // is kept as-is because extractWeights is presumably stateful (sequential
    // reads) — confirm against its implementation before reordering.
    const kernelValues = extractWeights(kernelValueCount);
    const filters = tf.tensor4d(kernelValues, kernelShape);

    // One bias value per output channel.
    const biasValues = extractWeights(channelsOut);
    const bias = tf.tensor1d(biasValues);

    // Record where both tensors live, filters first.
    const filtersMapping = { paramPath: `${mappedPrefix}/filters` };
    const biasMapping = { paramPath: `${mappedPrefix}/bias` };
    paramMappings.push(filtersMapping, biasMapping);

    return { filters, bias };
  };
}