bfloat16 GPU performance


Eugene Kuznetsov

Jul 13, 2019, 12:08:35 AM
to tensor2tensor
I've been trying to get tensor2tensor to work with mixed precision on a Volta GPU (Tesla V100). (Who could say no to 110 TFlops?)

I had to fix a bug in TensorFlow and add a number of bfloat16-related ops just to get it to run without constantly falling back to the CPU (see my issue on GitHub). And with all that done, it ends up actually slower with bfloat16 than with fp32: I see 19 s / 100 steps training vanilla fp32 transformer_base, and 27 s / 100 steps with weight_dtype=bfloat16.

So, I did a GPU trace. Transformer is not a particularly simple model to begin with, but, with bfloat16 enabled, it's something else. Each training step involves approximately 10,000(!) separate GPU calls, nearly half of which have thread configuration (1,1,1). They all add up to a lot of overhead. 

One of the most visible sources of bloat is the peculiar and convoluted way tensor2tensor converts from fp32 to bfloat16 in utils/quantization.py. It involves a large number of GPU kernels, and it is performed at least 18 times on different tensors in every step. Dropping this conversion in favor of a simple tf.to_bfloat16() alone takes me from 27 s to 23 s (though I can't say whether it hurts accuracy). Is that conversion really needed if you're not on a TPU?
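For reference, the replacement is essentially just a direct cast (a minimal sketch, not the exact code path in quantization.py):

import tensorflow as tf

def to_bfloat16_simple(x):
  # One elementwise cast kernel, instead of the multi-kernel conversion
  # pipeline in utils/quantization.py. tf.to_bfloat16(x) is equivalent.
  return tf.cast(x, tf.bfloat16)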

I can put more work into this. For example, I can probably code the t2t-flavor conversion directly into TensorFlow. But first I want to know: am I missing something obvious (I thought this was all supposed to be implemented already; we've had GPUs with fp16 tensor cores out for ages)? And second, is there any interest in merging my future work into the project?


Eugene Kuznetsov

Jul 15, 2019, 3:20:59 PM
to tensor2tensor
I've finally noticed that the mixed-precision pull request (#1362) used the type float16, not bfloat16. An easy mix-up to make, I suppose, but bfloat16 is a completely different type. I was also unsure whether float16 is the same as CUDA half; it almost certainly is (the CUDA docs say half is the same as IEEE 754-2008 binary16).
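For anyone else confused by the two 16-bit formats, a quick sanity check (TF 1.x session style; the numeric behavior is the only point here):

import tensorflow as tf

# bfloat16 keeps float32's 8-bit exponent but only a 7-bit mantissa;
# float16 (IEEE binary16, i.e. CUDA half) has a 5-bit exponent and a
# 10-bit mantissa, so it overflows much earlier.
x = tf.constant([3.14159265, 1e5], dtype=tf.float32)
with tf.Session() as sess:
  print(sess.run(tf.cast(x, tf.bfloat16)))  # ~3 significant digits, 1e5 is fine
  print(sess.run(tf.cast(x, tf.float16)))   # more precision, but 1e5 -> inf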

Also, I did some profiling in FP32 mode on a 1080ti.

Even there, only about 50% of the GPU time is spent doing convolutions (the part that Volta & Turing accelerate really well with FP16). The rest is all over the place. About 10% is spent on layer_norm. (I've managed to write a specialized TensorFlow op for layer_norm in CUDA, and it cuts the computational cost of the operation to less than half of what it was.) About as much again is spent doing dropout, which has to fire up the RNG after every layer and go through several kernel calls. If I replace the line "return tf.nn.dropout..." in common_layers.py with "return x", I get a 10% speedup.
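To give a sense of why layer_norm costs that much in stock TF: the math is only a handful of reductions and elementwise ops, but each becomes its own kernel launch, which is exactly what a fused CUDA op collapses. Roughly (a sketch of the computation, not T2T's exact code):

import tensorflow as tf

def layer_norm_reference(x, scale, bias, epsilon=1e-6):
  # Each of these ops is a separate GPU kernel in the unfused graph.
  mean = tf.reduce_mean(x, axis=-1, keepdims=True)
  variance = tf.reduce_mean(tf.square(x - mean), axis=-1, keepdims=True)
  normalized = (x - mean) * tf.rsqrt(variance + epsilon)
  return normalized * scale + bias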

So, the gain from going FP16 on Volta is going to be limited.

Lukasz Kaiser

Jul 16, 2019, 1:30:34 PM
to Eugene Kuznetsov, tensor2tensor
Just to confirm: I personally never managed to get speedups from
float16 on GPUs. There are memory savings from bfloat16 on TPUs, and
some speedups, but it's the memory part that mattered most to me.

Lukasz

Eugene Kuznetsov

Jul 17, 2019, 1:06:15 AM
to tensor2tensor
Okay, with float16, I do see a speedup.

FP32: ~19 s
FP16 baseline: 14.9 s
CUDA layer_norm: 13.0 s
CUDA layer_norm & no dropout: 12.1 s

All figures are seconds per 100 training steps for translation with transformer_base on Google Cloud n1-standard-4 + 1x Nvidia Tesla V100, with FP16 activated via

--hparams activation_dtype=float16

I have to use the Adam optimizer, because it has an optimized GPU resource-apply op. Adafactor's resource-apply op compiles into lots of tiny kernels, and there's a visible performance impact.

With weight_dtype=float16, I get an instant "model diverged with NaN", but we probably don't need to go there, because setting activation_dtype=float16 is enough to force all convolutions into fp16.
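For reference, the combination that works for me, written out as hparams overrides (a Python sketch; in practice I just pass the flag above on the command line, with weight_dtype left at its float32 default):

from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
hparams.activation_dtype = "float16"  # run the matmuls/convolutions in fp16
hparams.weight_dtype = "float32"      # fp32 weights; float16 weights diverge to NaN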

The biggest bottleneck at that point is tf.scatter_nd() in utils/expert_utils.py, method PadRemover::restore(). It alone takes around 10% of the time. (Interestingly, its opposite number tf.gather_nd() in PadRemover::remove() is very fast.) Each call to tf.scatter_nd() takes almost 2 ms, and there are 6 of them per step; when the entire step is 120-130 ms, that's a lot. This seems to be a GPU memory issue, possibly worsened by an inefficient kernel implementation. It is actually about 2x faster to cast the input of scatter_nd() to float32, do the scatter, and cast back to float16 than to do the op directly in float16.
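The workaround is trivial to express (an illustrative sketch, not the actual PadRemover code):

import tensorflow as tf

def scatter_nd_via_fp32(indices, updates, shape):
  # Running scatter_nd in float32 and casting back was ~2x faster on the V100
  # than doing the op directly in float16.
  if updates.dtype == tf.float16:
    out = tf.scatter_nd(indices, tf.cast(updates, tf.float32), shape)
    return tf.cast(out, tf.float16)
  return tf.scatter_nd(indices, updates, shape)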

Lukasz Kaiser

Jul 18, 2019, 12:25:06 PM
to Eugene Kuznetsov, tensor2tensor
Cool speedups, thanks for letting us know!!

Lukasz

Eugene Kuznetsov

Jul 21, 2019, 5:50:04 PM
to tensor2tensor
There's just one problem. Beam decoding does not work with activation_dtype=float16 out of the box :( I get all sorts of type mismatch errors, and debugging them is hell (half of the time, TensorFlow does not actually tell you which line of code triggers the mismatch). It took me at least two hours to find a way to work around this one:

"Input 1 of node transformer/while/Merge_14 was passed half from transformer/while/NextIteration_14:0 incompatible with expected float."

In the end, I came up with the set of casts necessary for decoding to work, though some of the choices are really ugly (e.g. see the first change to beam_search.py in https://github.com/ekuznetsov139/tensor2tensor/commit/5aa0058fd99c99b38c89a0c45c2101192a91ecb5), and I might have broken some things in the process (I'll try to test it on a TPU).
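The underlying rule, for anyone hitting the same errors: every loop variable in tf.while_loop has to keep its dtype across iterations, so a float16 tensor flowing into a slot that was created as float32 produces exactly that Merge error. A toy illustration of the cast that fixes it (not the actual beam_search.py change):

import tensorflow as tf

def accumulate(scores_fp16, num_steps=10):
  i0 = tf.constant(0)
  acc0 = tf.zeros(tf.shape(scores_fp16), dtype=tf.float32)  # loop slot is float32

  def body(i, acc):
    # This slot was created as float32, so whatever we return here must stay
    # float32; returning a float16 tensor for it is what produces the
    # "passed half ... incompatible with expected float" Merge error.
    return i + 1, acc + tf.cast(scores_fp16, tf.float32)

  return tf.while_loop(lambda i, acc: i < num_steps, body, [i0, acc0])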

Eugene Kuznetsov

Jul 22, 2019, 2:08:29 AM
to tensor2tensor
Found another fairly significant optimization.

It turns out that if you add a dense layer in TensorFlow whose input has more than two dimensions and whose dimensions are not fully defined at graph construction time, the framework inserts several extra GPU operations into the graph (two gathers, one reduce_prod, and four host<->device memcpys). See the function _tensordot_reshape() in tensorflow/math_ops.py.

Fortunately, there is a way to circumvent it:
https://github.com/ekuznetsov139/tensor2tensor/commit/079f51f2c736d547cbd2b4bcfdcedb039827b93a
It's rather hacky, but it does work. With this change, the GPU trace is substantially simplified (the trace for a forward pass through one transformer_encoder layer shrinks from 72 calls to 44), and FP16 training time per step drops by at least 15%. It should help with FP32 too.
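The general idea, in case the commit is hard to read (a sketch only; the actual change differs in detail): collapse the leading dimensions by hand, run a plain 2-D matmul, and restore the shape afterwards, so tensordot's dynamic-shape bookkeeping never enters the graph.

import tensorflow as tf

def dense_3d(x, kernel):
  # x: [batch, length, d_in] with possibly dynamic batch/length;
  # kernel: [d_in, d_out] with static dimensions.
  d_in, d_out = kernel.shape.as_list()
  dyn = tf.shape(x)
  y = tf.matmul(tf.reshape(x, [-1, d_in]), kernel)
  return tf.reshape(y, [dyn[0], dyn[1], d_out])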

I am now just under 10 s / 100 steps on a V100. The trace is still very complex, but it is now down to "only" 3100 launches per step.

If anyone's curious, here's a typical trace: https://docs.google.com/spreadsheets/d/1md2CB-X3QrSB2Kz4svdExm3gRtY-DnGjxGxoK6srl2Q/edit?usp=sharing