I've been trying to get tensor2tensor to work with mixed precision on a Volta GPU (Tesla V100). (Who could say no to 110 TFlops?)
I had to fix a bug in TensorFlow and add a number of bfloat16-related ops just to get it to run without falling back to the CPU all the time (see my issue on GitHub). And with all that done, training is actually slower with bfloat16 than with fp32: I see 19 s / 100 steps training vanilla fp32 transformer_base, and 27 s / 100 steps with weight_dtype=bfloat16.
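For reference, this is roughly how I enable it; it's equivalent to passing --hparams='weight_dtype=bfloat16' to t2t-trainer, and the hparams-set name below is just my local label:

```python
# A locally registered hparams set for the bfloat16 runs; equivalent to passing
# --hparams='weight_dtype=bfloat16' on the t2t-trainer command line.
# The name "transformer_base_bf16" is my own, not part of t2t.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_base_bf16():
  hparams = transformer.transformer_base()
  hparams.weight_dtype = "bfloat16"
  return hparams
```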
So I did a GPU trace. Transformer is not a particularly simple model to begin with, but with bfloat16 enabled it's something else: each training step involves roughly 10,000(!) separate GPU kernel launches, nearly half of which run with a (1,1,1) thread configuration. The launch overhead alone adds up to a lot.
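If anyone wants to reproduce the count: the per-kernel block shapes come from nvprof --print-gpu-trace, but the raw number of GPU op events per step can also be pulled out of a Chrome-trace timeline dumped with a stock tf.train.ProfilerHook (nothing t2t-specific). A sketch, with an example filename:

```python
# Count GPU op events in one step's Chrome-trace timeline, as dumped by
# tf.train.ProfilerHook(save_steps=..., output_dir=...). The filename below
# is just an example; the hook names files timeline-<global_step>.json.
import json
from collections import Counter

with open("timeline-1000.json") as f:
    events = json.load(f)["traceEvents"]

# Map trace pids to device names, then tally complete events on GPU devices.
devices = {e["pid"]: e.get("args", {}).get("name", "")
           for e in events
           if e.get("ph") == "M" and e.get("name") == "process_name"}
gpu_ops = Counter(e["name"] for e in events
                  if e.get("ph") == "X" and "GPU" in devices.get(e["pid"], ""))

print(sum(gpu_ops.values()), "GPU op events in this step; most frequent:")
for name, n in gpu_ops.most_common(10):
    print("%6d  %s" % (n, name))
```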
One of the most visible sources of bloat is the convoluted way tensor2tensor converts from fp32 to bfloat16 in utils/quantization.py. The conversion expands into a large number of GPU kernels, and it is performed at least 18 times on different tensors in every step. Dropping it in favor of a simple tf.to_bfloat16() alone takes me from 27 s to 23 s (though I can't yet say whether it hurts accuracy). Is that conversion really needed when you're not on a TPU?
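To be concrete, the replacement I tried amounts to the standard "fp32 master weights, cast on read" pattern. Below is a minimal sketch using the tf.variable_scope custom_getter interface; the function name is mine, and it is not an exact drop-in for the logic in quantization.py:

```python
import tensorflow as tf


def plain_bfloat16_getter(getter, *args, **kwargs):
  """Keep master weights in fp32, hand a single-cast bfloat16 view to the graph.

  Standard tf.variable_scope custom_getter signature; the name is mine. This is
  the "simple tf.to_bfloat16()" variant, i.e. it drops the multi-kernel
  unbiased-rounding conversion done in utils/quantization.py.
  """
  requested_dtype = tf.as_dtype(kwargs.get("dtype") or tf.float32)
  if requested_dtype == tf.bfloat16:
    kwargs["dtype"] = tf.float32          # store/update the variable in fp32
  var = getter(*args, **kwargs)
  if requested_dtype == tf.bfloat16:
    var = tf.to_bfloat16(var)             # one cast kernel per weight read
  return var


# Usage sketch: wrap the model body's variable scope.
# with tf.variable_scope("body", custom_getter=plain_bfloat16_getter):
#   logits = model_fn(features)
```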
I'm willing to put more work into this; for example, I can probably implement the t2t-flavor conversion as a fused op directly in TensorFlow. But I'd like to know two things first: am I missing something obvious (I thought this was all supposed to be implemented already; GPUs with fp16 tensor cores have been out for ages), and is there any interest in merging further work on this into the project?