TPU pre-trained model with bfloat16: transformation to traditional TensorFlow floating point


Aug 13, 2018, 1:54:37 PM
to TPU Users

I am interested in converting a model trained with bfloat16 into regular TensorFlow code that can then be parsed by TensorRT. Is there a way of doing that?
I am lost in all the TPU documentation and would appreciate some pointers.

Thank you!

Russell Power

Aug 13, 2018, 3:37:39 PM
to Zak Stone

Usually bfloat16 training only affects the computation, not the saved weights (the variables themselves are still stored as float32). This means you should be able to train your model with bfloat16 and then use the checkpoint with TensorRT easily.
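Why the weights are unaffected: bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 mantissa bits, so a bfloat16 value is the upper 16 bits of a float32. The cast to bfloat16 can therefore happen on the fly inside the computation while the stored float32 variables stay untouched. Below is a minimal NumPy sketch that emulates the cast by truncation; the names to_bfloat16, weights and activations are illustrative only, and real conversions may round to nearest rather than truncate.

```python
import numpy as np


def to_bfloat16(x):
    """Emulate a float32 -> bfloat16 cast by truncation (round toward zero).

    Zeroing the low 16 bits of each float32 leaves exactly the bits a
    bfloat16 can hold. The result is still held in a float32 container.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # np.asarray keeps 0-d inputs as arrays so .view() works on scalars too.
    masked = np.asarray(bits & np.uint32(0xFFFF0000))
    return masked.view(np.float32)


# float32 "variables", as they would be written to the checkpoint.
weights = np.array([1.0, 1.0 / 3.0, 3.14159], dtype=np.float32)

# Only the computation sees the reduced-precision copy; weights is not modified.
activations = to_bfloat16(weights) * np.float32(2.0)
```

Because the checkpoint only ever holds the float32 values, a float32 export graph can restore them directly.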

I'm not familiar with TensorRT, but let's assume you're using an Estimator-like mechanism for training. Then you can train the model with something like:

--precision=bfloat16 --steps=1000

The way you export to TensorRT is then unchanged: you just have to make sure you disable bfloat16 ops in your model when you rebuild the graph for export. Here I'm assuming precision is controlled by a flag:

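import tensorflow as tf  # TF 1.x API; FLAGS and model_fn come from your training script

# output_node_names: names of your model's output ops (you must supply these).
def export_to_tensorrt(output_node_names, checkpoint_dir="/path/to/checkpoint/dir",
                       mode=tf.estimator.ModeKeys.PREDICT):  # build the inference graph
  # Rebuild the graph in float32 so the exported graph contains no bfloat16 ops.
  FLAGS.precision = 'float32'
  features = tf.placeholder(...)
  labels = tf.placeholder(...)
  model_fn(features, labels, mode=mode, params={...})
  saver = tf.train.Saver()
  with tf.Session() as sess:
    # Restore the existing checkpoint; restore() needs a checkpoint prefix, not a directory.
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

    # Copied from the TensorRT tutorial; freeze while the session is still open.
    graphdef = tf.get_default_graph().as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, graphdef, output_node_names)
  return tf.graph_util.remove_training_nodes(frozen_graph)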




Aug 13, 2018, 6:05:35 PM
to TPU Users

Thank you for your reply. I am new to using pre-trained TPU models and TensorFlow, so I might ask basic questions. Anyway, I was trying to create a frozen graph from the checkpoints with the following code:

import tensorflow as tf
from tensorflow.python.framework import graph_util

TPU_MODEL = "/path/to/model_checkpoint.meta"

saver = tf.train.import_meta_graph(TPU_MODEL, clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
saver.restore(sess, "model.ckpt-258931")

output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session
            input_graph_def,  # input_graph_def is useful for retrieving the nodes
            ...)

with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())


Aug 13, 2018, 6:07:15 PM
to TPU Users
Sorry, I posted accidentally before finishing. When I try that code, which works with other models, I get an error saying that it cannot find the keys in the model. Is the code wrong, or is this not supported?

Jonathan Hseu

Aug 13, 2018, 6:16:48 PM
to Zak Stone
What's the precise error you got?

If it's a float32 graphdef and you trained the model with bfloat16 on TensorFlow 1.8 or earlier, the cause may be that we prepended 'bfloat16' to all the variable names, which made the checkpoints slightly different. In TF 1.9 and later we removed that, so the checkpoints are compatible.

You can use this function to list the variables in a checkpoint: tf.contrib.framework.list_variables

If that's your issue, you can fix it without retraining by using either of these options:

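1. Warm start: pass the old checkpoint to the warm_start_from keyword argument in the TPUEstimator constructor (it accepts a checkpoint path or a tf.estimator.WarmStartSettings).
2. Manual: loop over all the variables returned by the list_variables function above, read each value with tf.contrib.framework.load_variable, and assign it to the corresponding new variable.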

I think #1 is easier, but it probably requires building the variable name map from the list_variables output.
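Option #1 needs a map from each variable's name in the new float32 graph to its name in the old checkpoint. Here is a minimal pure-Python sketch of building that map from list_variables output. It assumes the TF 1.8 prefix takes the form of a leading "bfloat16/" scope; the thread only says 'bfloat16' was prepended, so check your own list_variables output first. The function build_prev_name_map and the sample variable list are hypothetical.

```python
def build_prev_name_map(checkpoint_vars, prefix="bfloat16/"):
    """Map new (float32 graph) variable names to their names in an old checkpoint.

    checkpoint_vars: (name, shape) pairs, in the form list_variables returns.
    prefix: ASSUMED form of the scope that TF <= 1.8 bfloat16 training
    prepended; verify it against your own list_variables output.
    """
    name_map = {}
    for old_name, _shape in checkpoint_vars:
        if old_name.startswith(prefix):
            # Key: name in the new graph; value: name stored in the checkpoint.
            name_map[old_name[len(prefix):]] = old_name
    return name_map


# Hypothetical list_variables output from a TF 1.8 bfloat16 checkpoint.
old_vars = [
    ("bfloat16/resnet_model/conv2d_31/kernel", [1, 1, 1024, 256]),
    ("bfloat16/resnet_model/conv2d_32/kernel", [3, 3, 256, 256]),
    ("global_step", []),
]
prev_names = build_prev_name_map(old_vars)
```

The dict maps current name to checkpoint name, which is the direction the var_name_to_prev_var_name argument of tf.estimator.WarmStartSettings uses. For example, tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_dir, var_name_to_prev_var_name=prev_names) can be passed as warm_start_from. Keep only entries for variables you actually warm-start; by default that means trainable variables.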

Aug 13, 2018, 7:00:17 PM
to TPU Users

The error I get is "KeyError: u'InfeedEnqueueTuple'"
I ran tf.contrib.framework.list_variables(checkpoint_dir), and this is some of the output:
('resnet_model/conv2d_31/kernel', [1, 1, 1024, 256]),
('resnet_model/conv2d_31/kernel/Momentum', [1, 1, 1024, 256]),
('resnet_model/conv2d_32/kernel', [3, 3, 256, 256]),
('resnet_model/conv2d_32/kernel/Momentum', [3, 3, 256, 256])
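That output mixes two kinds of entries. Some are the model weights. The others are the /Momentum slot variables that a momentum optimizer keeps during training, which an inference or TensorRT graph does not need. Also, none of the names start with a 'bfloat16' prefix. Below is a small pure-Python sketch that separates the entries by name suffix. It assumes /Momentum is the only optimizer slot in this checkpoint, since other optimizers use different slot names. The helper split_model_and_slot_vars is hypothetical.

```python
# list_variables output as pasted above: (name, shape) pairs.
checkpoint_vars = [
    ("resnet_model/conv2d_31/kernel", [1, 1, 1024, 256]),
    ("resnet_model/conv2d_31/kernel/Momentum", [1, 1, 1024, 256]),
    ("resnet_model/conv2d_32/kernel", [3, 3, 256, 256]),
    ("resnet_model/conv2d_32/kernel/Momentum", [3, 3, 256, 256]),
]


def split_model_and_slot_vars(variables, slot_suffixes=("/Momentum",)):
    """Separate model weights from optimizer slot variables by name suffix."""
    model_vars, slot_vars = [], []
    for name, shape in variables:
        # str.endswith accepts a tuple, so several slot names can be checked.
        target = slot_vars if name.endswith(slot_suffixes) else model_vars
        target.append((name, shape))
    return model_vars, slot_vars


model_vars, slot_vars = split_model_and_slot_vars(checkpoint_vars)

# True only if the TF <= 1.8 'bfloat16' name prefix is present.
has_bfloat16_prefix = any(name.startswith("bfloat16") for name, _ in checkpoint_vars)
```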

I will try using the first option!

Thank you!

Jonathan Hseu

Aug 13, 2018, 7:49:46 PM
to Zak Stone

Jonathan Hseu

Aug 13, 2018, 7:53:03 PM
to Kevin Tsai, Zak Stone

Jonathan Hseu

Aug 13, 2018, 8:08:23 PM
to Kevin Tsai, Zak Stone
Hey Kyle,

Here are some additional things I mentioned during the call; they're easier to explain with links:
1. After you export a SavedModel, you'll need to load it and then freeze it. You can load it like this:

2. export_savedmodel takes a serving_input_receiver_fn, which is basically the input_fn but for inference:

You might want to modify it to suit your needs, depending on how you want to pass images in: