From d66ef08dfdbe23f91fa93f0aacc40ab02b3323fc Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Thu, 1 Sep 2022 19:15:33 -0400
Subject: [PATCH] update

---
 src/tfjs-to-saved.py | 36 ----------------------------
 src/tfjs2tflite.py   | 57 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 36 deletions(-)
 delete mode 100644 src/tfjs-to-saved.py
 create mode 100644 src/tfjs2tflite.py

diff --git a/src/tfjs-to-saved.py b/src/tfjs-to-saved.py
deleted file mode 100644
index ee80931..0000000
--- a/src/tfjs-to-saved.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-import sys
-import glob
-import tensorflow as tf
-import tfjs_graph_converter.api as tfjs
-
-graphDir = 'models/'
-savedDir = 'saved/'
-tfliteDir = 'tflite/'
-
-def main() -> None:
-    for f in glob.glob(os.path.join(graphDir, '*.json')):
-        modelName = os.path.basename(f).split('.')[0]
-        print('graph model: ' + modelName + ' path: ' + f)
-        savedModel = os.path.join(savedDir, modelName)
-        try:
-            tfjs.graph_model_to_saved_model(f, savedModel) # type: ignore
-        except:
-            print('saved convert failed')
-        else:
-            converter = tf.lite.TFLiteConverter.from_saved_model(savedModel)
-            converter.optimizations = [tf.lite.Optimize.DEFAULT]
-            converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]
-            converter.target_spec.supported_types = [tf.float16]
-            tfliteModel = os.path.join(tfliteDir, modelName)
-            try:
-                tflite_model = converter.convert()
-            except:
-                print('tflite convert failed')
-            else:
-                with open(tfliteModel, 'wb') as f:
-                    f.write(tflite_model)
-                print('saved:' + savedModel + ' tflite: ' + tfliteModel)
-
-if __name__ == '__main__':
-    main()
diff --git a/src/tfjs2tflite.py b/src/tfjs2tflite.py
new file mode 100644
index 0000000..e2d17e7
--- /dev/null
+++ b/src/tfjs2tflite.py
@@ -0,0 +1,57 @@
+import os
+import glob
+import tensorflow as tf
+import tfjs_graph_converter.api as tfjs
+
+graphDir = 'models/'
+savedDir = 'saved/'
+tfliteDir = 'tflite/'
+
+def saved2tflite(savedModelDir, tfliteModelName):
+    if (os.path.isfile(os.path.join(savedModelDir, 'saved_model.pb'))):
+        converter = tf.lite.TFLiteConverter.from_saved_model(savedModelDir)
+        converter.optimizations = [tf.lite.Optimize.DEFAULT] # type: ignore
+        converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]
+        # converter.target_spec.experimental_supported_backends = ["CPU", "GPU"]
+        converter.target_spec.supported_types = [tf.float16]
+        converter.allow_custom_ops = True
+        converter.exclude_conversion_metadata = True
+        try:
+            tfliteModel = converter.convert()
+            # tf.lite.experimental.Analyzer.analyze(model_content = tfliteModel)
+        except:
+            print(' tflite convert failed')
+        else:
+            with open(tfliteModelName, 'wb') as f:
+                f.write(tfliteModel)
+            print(' tflite model', tfliteModelName)
+    else:
+        print(' tf saved model missing:', savedModelDir)
+
+
+def tfjs2saved(graphJsonFile, savedModelDir):
+    if (not os.path.exists(savedModelDir)):
+        try:
+            tfjs.graph_model_to_saved_model(graphJsonFile, savedModelDir) # type: ignore
+        except:
+            print(' tf saved convert failed:', graphJsonFile)
+        else:
+            print(' tf saved model:', savedModelDir)
+    else:
+        print(' tf saved model exists:', savedModelDir)
+
+
+def main():
+    tf.compat.v1.enable_control_flow_v2()
+    for graphJsonFile in glob.glob(os.path.join(graphDir, '*.json')):
+        modelName = os.path.basename(graphJsonFile).split('.')[0]
+        print('model:', modelName)
+        print(' tfjs graph model:', graphJsonFile)
+        savedModelDir = os.path.join(savedDir, modelName)
+        tfliteModelFile = os.path.join(tfliteDir, modelName) + '.tflite'
+        tfjs2saved(graphJsonFile, savedModelDir)
+        saved2tflite(savedModelDir, tfliteModelFile)
+
+
+if __name__ == '__main__':
+    main()
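
A minimal smoke-test sketch, not part of the patch above: after tfjs2tflite.py has written models into tflite/, the converted files can be loaded with the stock tf.lite.Interpreter and run once on zero-filled inputs to confirm the graph executes. This assumes each model declares fixed input shapes; models with dynamic dimensions would need resize_tensor_input before allocation, and models converted with SELECT_TF_OPS rely on the Flex support bundled with the full tensorflow package.

# smoke_test_tflite.py (hypothetical helper, not in the commit)
import os
import glob
import numpy as np
import tensorflow as tf

tfliteDir = 'tflite/'

for tflitePath in glob.glob(os.path.join(tfliteDir, '*.tflite')):
    print('checking:', tflitePath)
    interpreter = tf.lite.Interpreter(model_path=tflitePath)
    interpreter.allocate_tensors()
    # feed zero tensors of the declared shape/dtype just to confirm the model runs
    for detail in interpreter.get_input_details():
        dummy = np.zeros(detail['shape'], dtype=detail['dtype'])
        interpreter.set_tensor(detail['index'], dummy)
    interpreter.invoke()
    for detail in interpreter.get_output_details():
        print(' output:', detail['name'], interpreter.get_tensor(detail['index']).shape)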