mirror of https://github.com/vladmandic/human
100 lines
3.6 KiB
JavaScript
100 lines
3.6 KiB
JavaScript
#!/usr/bin/env -S node --no-deprecation --trace-warnings
|
|
|
|
/**
|
|
* Helper app that analyzes any TensorFlow SavedModel or GraphModel for inputs and outputs
|
|
*/
|
|
|
|
const fs = require('fs');
|
|
const path = require('path');
|
|
const log = require('@vladmandic/pilogger');
|
|
const tf = require('@tensorflow/tfjs-node');
|
|
|
|
async function analyzeGraph(modelPath) {
|
|
const model = await tf.loadGraphModel(`file://${modelPath}`);
|
|
log.info('graph model:', path.resolve(modelPath));
|
|
log.info('size:', tf.engine().memory());
|
|
|
|
const inputs = [];
|
|
if (model.modelSignature && model.modelSignature['inputs']) {
|
|
log.info('model inputs based on signature');
|
|
for (const [key, val] of Object.entries(model.modelSignature['inputs'])) {
|
|
const shape = val.tensorShape.dim.map((a) => parseInt(a.size));
|
|
inputs.push({ name: key, dtype: val.dtype, shape });
|
|
}
|
|
// @ts-ignore accessing private property
|
|
} else if (model.executor.graph['inputs']) {
|
|
log.info('model inputs based on executor');
|
|
// @ts-ignore accessing private property
|
|
for (const t of model.executor.graph['inputs']) {
|
|
inputs.push({ name: t.name, dtype: t.attrParams.dtype.value, shape: t.attrParams.shape.value });
|
|
}
|
|
} else {
|
|
log.warn('model inputs: cannot determine');
|
|
}
|
|
|
|
const outputs = [];
|
|
let i = 0;
|
|
if (model.modelSignature && model.modelSignature['outputs'] && Object.values(model.modelSignature['outputs'])[0].dtype) {
|
|
log.info('model outputs based on signature');
|
|
for (const [key, val] of Object.entries(model.modelSignature['outputs'])) {
|
|
const shape = val.tensorShape?.dim.map((a) => parseInt(a.size));
|
|
outputs.push({ id: i++, name: key, dytpe: val.dtype, shape });
|
|
}
|
|
// @ts-ignore accessing private property
|
|
} else if (model.executor.graph['outputs']) {
|
|
log.info('model outputs based on executor');
|
|
// @ts-ignore accessing private property
|
|
for (const t of model.executor.graph['outputs']) {
|
|
outputs.push({ id: i++, name: t.name, dtype: t.attrParams.dtype?.value || t.rawAttrs.T.type, shape: t.attrParams.shape?.value });
|
|
}
|
|
} else {
|
|
log.warn('model outputs: cannot determine');
|
|
}
|
|
|
|
log.data('inputs:', inputs);
|
|
log.data('outputs:', outputs);
|
|
}
|
|
|
|
async function analyzeSaved(modelPath) {
|
|
const meta = await tf.node.getMetaGraphsFromSavedModel(modelPath);
|
|
log.info('saved model:', path.resolve(modelPath));
|
|
const sign = Object.values(meta[0].signatureDefs)[0];
|
|
log.data('tags:', meta[0].tags);
|
|
log.data('signature:', Object.keys(meta[0].signatureDefs));
|
|
const inputs = Object.values(sign.inputs)[0];
|
|
// @ts-ignore a is array
|
|
const inputShape = inputs.shape?.map((a) => a.array[0]);
|
|
log.data('inputs:', { name: inputs.name, dtype: inputs.dtype, shape: inputShape });
|
|
const outputs = [];
|
|
let i = 0;
|
|
for (const [key, val] of Object.entries(sign.outputs)) {
|
|
// @ts-ignore a is array
|
|
const shape = val.shape?.map((a) => a.array[0]);
|
|
outputs.push({ id: i++, name: key, dytpe: val.dtype, shape });
|
|
}
|
|
log.data('outputs:', outputs);
|
|
}
|
|
|
|
async function main() {
|
|
log.header();
|
|
const param = process.argv[2];
|
|
if (process.argv.length !== 3) {
|
|
log.error('path required');
|
|
process.exit(0);
|
|
} else if (!fs.existsSync(param)) {
|
|
log.error(`path does not exist: ${param}`);
|
|
process.exit(0);
|
|
}
|
|
const stat = fs.statSync(param);
|
|
log.data('created on:', stat.birthtime);
|
|
if (stat.isFile()) {
|
|
if (param.endsWith('.json')) analyzeGraph(param);
|
|
}
|
|
if (stat.isDirectory()) {
|
|
if (fs.existsSync(path.join(param, '/saved_model.pb'))) analyzeSaved(param);
|
|
if (fs.existsSync(path.join(param, '/model.json'))) analyzeGraph(path.join(param, '/model.json'));
|
|
}
|
|
}
|
|
|
|
main();
|