Hi,
Usually bfloat16 training only affects the computation, not the saved weights: variables are typically kept in float32 and only cast to bfloat16 for the actual math, so the checkpoint still contains float32 weights. This means you should be able to train your model in bfloat16 and then use the resulting checkpoint with TensorRT without any extra conversion.
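If you want to convince yourself of that: the usual TF 1.x way to get bfloat16 compute is tf.contrib.tpu.bfloat16_scope, which uses a custom variable getter that creates float32 variables and casts them on the fly. A quick sanity check, assuming TF 1.x with contrib available (the layer here is just an example):

import tensorflow as tf

with tf.contrib.tpu.bfloat16_scope():
    # Inside the scope the layer computes in bfloat16...
    x = tf.zeros([1, 4], dtype=tf.bfloat16)
    y = tf.layers.dense(x, 2)

# ...but the variables that end up in the checkpoint are float32.
for v in tf.trainable_variables():
    print(v.name, v.dtype.base_dtype)  # -> float32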
I'm not familiar with TensorRT, but let's assume you're using an Estimator-like mechanism for training. Then you can run something like:
train.py --precision=bfloat16 --steps=1000

to train the model.
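For concreteness, here is a minimal sketch of how such a flag might gate bfloat16 inside the model function; FLAGS.precision and build_network are assumptions about your setup, not anything standard (flag parsing via app.run is elided):

import tensorflow as tf

tf.flags.DEFINE_string('precision', 'float32', "Either 'float32' or 'bfloat16'.")
FLAGS = tf.flags.FLAGS

def build_network(features):
    # Hypothetical model body standing in for your real network.
    return tf.layers.dense(features, 10)

def model_fn(features, labels, mode, params):
    if FLAGS.precision == 'bfloat16':
        # Compute in bfloat16; variables still come out float32 (see above).
        with tf.contrib.tpu.bfloat16_scope():
            logits = build_network(tf.cast(features, tf.bfloat16))
        logits = tf.cast(logits, tf.float32)  # keep the loss in float32
    else:
        logits = build_network(features)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)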
Exporting to TensorRT is then unchanged: you just have to make sure bfloat16 ops are disabled in the exported graph. Here I'm assuming this is controlled by the same flag:
def export_to_tensorrt(mode=tf.estimator.ModeKeys.PREDICT):
    # Rebuild the graph in float32 so no bfloat16 ops end up in the export;
    # PREDICT mode keeps the exported graph to inference ops only.
    FLAGS.precision = 'float32'
    features = tf.placeholder(...)
    labels = tf.placeholder(...)
    model_fn(features, labels, mode=mode, params={...})
    with tf.Session() as sess:
        # Restore the existing checkpoint into the float32 graph.
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint("/path/to/checkpoint/dir"))
        # Copied from the TensorRT tutorial: freeze variables into constants
        # and strip training-only nodes.
        graphdef = tf.get_default_graph().as_graph_def()
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess, graphdef, OUTPUT_NAMES)  # OUTPUT_NAMES: list of output node names
        return tf.graph_util.remove_training_nodes(frozen_graph)
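From there, assuming the TF 1.x contrib TensorRT integration (tensorflow.contrib.tensorrt), the frozen graph can be handed straight to the converter. Note that precision_mode here controls the TensorRT engine precision and is independent of having trained in bfloat16:

from tensorflow.contrib import tensorrt as trt

frozen_graph = export_to_tensorrt()
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=OUTPUT_NAMES,              # same output node names as above
    max_batch_size=1,
    max_workspace_size_bytes=1 << 30,
    precision_mode='FP16')             # or 'FP32' / 'INT8'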
HTH,
R