Hi Lu,
Thanks for your reply.
Here is my Python code for exporting the model to a SavedModel (.pb):
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,
    config=run_config,
    params=estimator_params,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)
# Define the serving input function: raw int32 placeholders for the three BERT inputs.
def serving_input_fn():
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids})()
    return input_fn
estimator._export_to_tpu = False
estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn,
                            checkpoint_path=FLAGS.init_checkpoint)
print("exported pb files can be found in /" + FLAGS.export_dir)
Here is my Python code for converting the .pb (SavedModel) to TFLite:
def pb2tflite():
    saved_model_dir = os.path.join(FLAGS.export_dir, str(cfg['bert_eval']['saved_model_ts']))
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    # converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    # converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    # converter = tf.compat.v1.lite.TocoConverter.from_saved_model(saved_model_dir)  # path to the SavedModel directory
    print("start converting...")
    tflite_model = converter.convert()
    print("converting finished! start writing the tflite model")
    if not tf.io.gfile.exists(FLAGS.tflite_dir):
        tf.io.gfile.mkdir(FLAGS.tflite_dir)
    # Save the tflite model.
    # tflite_file = os.path.join(FLAGS.tflite_dir, 'model.tflite')
    tflite_file = os.path.join(FLAGS.tflite_dir, FLAGS.model_dir + '.tflite')
    with open(tflite_file, 'wb') as f:
        f.write(tflite_model)
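After the conversion I sanity-check the inputs of the .tflite file with something like the following (a rough sketch, separate from the function above; the model path is hypothetical):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")

# The converter keeps the three named int32 inputs defined in serving_input_fn.
for detail in interpreter.get_input_details():
    print('input:', detail['name'], detail['shape'], detail['dtype'])
for detail in interpreter.get_output_details():
    print('output:', detail['name'], detail['shape'])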
I've also attached my TFLite inference Java file (it is basically based on this TensorFlow example). My TFLite model is around 500 MB, so I've shared it via a Google Drive link
here; I believe you should already have access to it.
My confusion is this: to convert the TensorFlow model, I have to feed a dict of named inputs (as in serving_input_fn above). But at inference time a dict is not a basic data type in Java, so how can I feed plain arrays for these inputs from Java? Or should I use a more convenient route, such as tflite_maker released by TensorFlow, to achieve what I want? (I've put a small Python sketch of what I mean in the P.S. below.) Thanks.
Best,
Liuyi
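
P.S. To make the question more concrete, this is roughly what the dict feed corresponds to on the Python side, i.e. three separate int32 arrays, one per named input. It is only a sketch, not my actual inference code; the sequence length of 128, the model path, and the assumption that my TensorFlow build can execute the select TF ops are all hypothetical.

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")  # hypothetical path
interpreter.allocate_tensors()

# One dummy example: each named input is just a separate [1, max_seq_length] int32 array.
dummy = np.zeros((1, 128), dtype=np.int32)
for detail in interpreter.get_input_details():
    interpreter.set_tensor(detail['index'], dummy)

interpreter.invoke()
output_detail = interpreter.get_output_details()[0]
print(interpreter.get_tensor(output_detail['index']).shape)

This is essentially what I would like to do in Java with plain arrays instead of a dict.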