main
Vladimir Mandic 2022-09-01 19:15:33 -04:00
parent 0de5c713bb
commit d66ef08dfd
2 changed files with 57 additions and 36 deletions

View File

@ -1,36 +0,0 @@
import os
import sys
import glob
import tensorflow as tf
import tfjs_graph_converter.api as tfjs
graphDir = 'models/'
savedDir = 'saved/'
tfliteDir = 'tflite/'
def main() -> None:
for f in glob.glob(os.path.join(graphDir, '*.json')):
modelName = os.path.basename(f).split('.')[0]
print('graph model: ' + modelName + ' path: ' + f)
savedModel = os.path.join(savedDir, modelName)
try:
tfjs.graph_model_to_saved_model(f, savedModel) # type: ignore
except:
print('saved convert failed')
else:
converter = tf.lite.TFLiteConverter.from_saved_model(savedModel)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]
converter.target_spec.supported_types = [tf.float16]
tfliteModel = os.path.join(tfliteDir, modelName)
try:
tflite_model = converter.convert()
except:
print('tflite convert failed')
else:
with open(tfliteModel, 'wb') as f:
f.write(tflite_model)
print('saved:' + savedModel + ' tflite: ' + tfliteModel)
if __name__ == '__main__':
main()

57
src/tfjs2tflite.py Normal file
View File

@ -0,0 +1,57 @@
import os
import glob
import tensorflow as tf
import tfjs_graph_converter.api as tfjs
graphDir = 'models/'
savedDir = 'saved/'
tfliteDir = 'tflite/'
def saved2tflite(savedModelDir, tfliteModelName):
if (os.path.isfile(os.path.join(savedModelDir, 'saved_model.pb'))):
converter = tf.lite.TFLiteConverter.from_saved_model(savedModelDir)
converter.optimizations = [tf.lite.Optimize.DEFAULT] # type: ignore
converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]
# converter.target_spec.experimental_supported_backends = ["CPU", "GPU"]
converter.target_spec.supported_types = [tf.float16]
converter.allow_custom_ops = True
converter.exclude_conversion_metadata = True
try:
tfliteModel = converter.convert()
# tf.lite.experimental.Analyzer.analyze(model_content = tfliteModel)
except:
print(' tflite convert failed')
else:
with open(tfliteModelName, 'wb') as f:
f.write(tfliteModel)
print(' tflite model', tfliteModelName)
else:
print(' tf saved model missing:', savedModelDir)
def tfjs2saved(graphJsonFile, savedModelDir):
if (not os.path.exists(savedModelDir)):
try:
tfjs.graph_model_to_saved_model(graphJsonFile, savedModelDir) # type: ignore
except:
print(' tf saved convert failed:', graphJsonFile)
else:
print(' tf saved model:', savedModelDir)
else:
print(' tf saved model exists:', savedModelDir)
def main():
tf.compat.v1.enable_control_flow_v2()
for graphJsonFile in glob.glob(os.path.join(graphDir, '*.json')):
modelName = os.path.basename(graphJsonFile).split('.')[0]
print('model:', modelName)
print(' tfjs graph model:', graphJsonFile)
savedModelDir = os.path.join(savedDir, modelName)
tfliteModelFile = os.path.join(tfliteDir, modelName) + '.tflite'
tfjs2saved(graphJsonFile, savedModelDir)
saved2tflite(savedModelDir, tfliteModelFile)
if __name__ == '__main__':
main()