face-api/src/common/extractWeightEntryFactory.ts

21 lines
587 B
TypeScript
Raw Normal View History

2020-12-19 17:46:41 +01:00
import { isTensor } from '../utils/index';
2020-08-18 13:54:53 +02:00
import { ParamMapping } from './types';
/**
 * Builds an accessor that reads entries out of `weightMap` and records, for
 * every entry it hands back, which key it came from and under which parameter
 * path it should be known.
 *
 * @param weightMap - Object indexed by the original weight path.
 * @param paramMappings - Array the returned accessor appends to; exactly one
 *   `{ originalPath, paramPath }` entry is pushed per successful lookup, and
 *   nothing is pushed when a lookup fails.
 * @returns An accessor `(originalPath, paramRank, mappedPath?) => T` that:
 *   - looks up `weightMap[originalPath]`;
 *   - throws an `Error` if `isTensor(entry, paramRank)` rejects it
 *     (presumably a "tensor of rank `paramRank`" check — confirm in utils);
 *   - otherwise records the mapping and returns the entry.
 *   `T` is not verified at runtime; it is whatever the caller asks for.
 */
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
  function extractWeightEntry<T>(originalPath: string, paramRank: number, mappedPath?: string): T {
    const entry = weightMap[originalPath];

    if (isTensor(entry, paramRank)) {
      // A falsy mappedPath (undefined or '') falls back to the original key.
      const paramPath = mappedPath || originalPath;
      paramMappings.push({ originalPath, paramPath });
      return entry;
    }

    throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${entry}`);
  }

  return extractWeightEntry;
}