2020-10-09 03:31:31 +02:00
|
|
|
import * as tf from '@tensorflow/tfjs';
|
2020-08-18 13:54:53 +02:00
|
|
|
|
|
|
|
import { getModelUris } from '../common/getModelUris';
|
|
|
|
import { fetchJson } from './fetchJson';
|
|
|
|
|
|
|
|
/**
 * Fetches a model's weights manifest and loads every weight tensor it references.
 *
 * `getModelUris` derives two URIs from `uri` and `defaultModelName` (see that helper for
 * the exact resolution rules):
 * - the URI of the manifest JSON;
 * - the base URI passed to `tf.io.loadWeights` as the prefix for the weight file paths
 *   listed in the manifest.
 *
 * The fetched JSON is handed to `tf.io.loadWeights` unchanged. It must therefore already
 * be a bare `tf.io.WeightsManifestConfig`. A full `model.json` whose manifest sits under a
 * nested `weightsManifest` property is NOT unwrapped here.
 *
 * Errors from `fetchJson` or `tf.io.loadWeights` are not caught and propagate to the caller.
 *
 * @param uri - Model location, or `undefined` to let `getModelUris` decide.
 * @param defaultModelName - Model name forwarded to `getModelUris` when resolving the URIs.
 * @returns The name-to-tensor map produced by `tf.io.loadWeights`.
 */
export async function loadWeightMap(
  uri: string | undefined,
  defaultModelName: string,
): Promise<tf.NamedTensorMap> {
  const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName);

  // Never reassigned: the manifest is used exactly as fetched (see contract note above).
  const manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri);

  return tf.io.loadWeights(manifest, modelBaseUri);
}
|