2020-12-02 22:46:41 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
|
|
|
import { ConvParams } from './types';
|
|
|
|
|
2020-12-23 17:26:55 +01:00
|
|
|
// eslint-disable-next-line no-unused-vars
|
2020-08-18 13:54:53 +02:00
|
|
|
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
  /**
   * Loads the weights of one convolution layer.
   *
   * `extractWeightEntry` is called twice, in this order:
   * 1. `${prefix}/filters` with rank 4, typed as `tf.Tensor4D`.
   * 2. `${prefix}/bias` with rank 1, typed as `tf.Tensor1D`.
   *
   * @param prefix - path prefix of the layer whose weights are read
   * @returns the `{ filters, bias }` pair produced by the two lookups
   */
  const loadConvParams = (prefix: string): ConvParams => {
    const filtersPath = `${prefix}/filters`;
    const biasPath = `${prefix}/bias`;
    // Object-literal properties are evaluated in source order, so filters is still requested before bias.
    return {
      filters: extractWeightEntry<tf.Tensor4D>(filtersPath, 4),
      bias: extractWeightEntry<tf.Tensor1D>(biasPath, 1),
    };
  };
  return loadConvParams;
}
|