face-api/src/common/loadConvParamsFactory.ts

12 lines
420 B
TypeScript
Raw Normal View History

2020-10-09 03:31:31 +02:00
import * as tf from '@tensorflow/tfjs';
2020-08-18 13:54:53 +02:00
import { ConvParams } from './types';
/**
 * Creates a loader for the parameters of a single convolution layer.
 *
 * @param extractWeightEntry - Callback invoked with a weight path and a rank
 *   number; whatever it returns is passed through unchanged.
 * @returns A function that takes a path prefix, requests `${prefix}/filters`
 *   (rank 4) followed by `${prefix}/bias` (rank 1) from `extractWeightEntry`,
 *   and returns both results as a `ConvParams` object.
 */
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
  return (prefix: string): ConvParams => {
    // The filters entry is requested first, then the bias entry.
    const convFilters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4);
    const convBias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);

    return { filters: convFilters, bias: convBias };
  };
}