float 16 support

dav...@graphcore.ai

Mar 16, 2017, 9:25:47 AM
to XLA development
Hi,

I notice that there are constants for float16 in the XLA code, but that the support isn't complete at the moment. Are there plans to support float16 in the near future?

cheers


Peter Hawkins

Mar 16, 2017, 10:09:15 AM
to dav...@graphcore.ai, XLA development
Float16 support sounds like a great idea (e.g., for newer NVidia GPUs), but I don't believe anyone is actively working on it at the moment. We will get to it eventually, but if you need it soon, contributions are welcome!

It is probably mostly a matter of adding F16 to the set of types to test in, say, the TensorFlow/XLA unit tests, and then making mechanical fixes to things that break.

I suspect the biggest obstacle will be that LLVM may not support float16 computation on most CPUs. Even if your backend doesn't use LLVM, it's necessary to have support in the CPU backend for constant folding, etc. So we would need to add support for cases where the storage format and the computation format are not the same.

Peter


Chris Leary

Mar 16, 2017, 11:09:57 AM
to Peter Hawkins, David Norman, XLA development
Even NVVM (the LLVM backend targeting NVIDIA GPUs) doesn't have "operative" support for fp16 today. The NVIDIA headers have inline assembly that emits fp16 PTX directly. Currently, fp16 seems mostly useful as a more compressed storage format, or as a type that the HLO (high-level optimizer) could target via transformation passes if the LLO (low-level optimizer) supports it.

Longer term, we'll probably need to figure out a way to prevent precision-reducing transformations from happening at the TF level in a way that propagates to XLA. E.g. if a user sees a bunch of numerical error at a given layer because we used Winograd or whatever, ideally they'd have a way to request that we back off on that more numerically aggressive choice.

dav...@graphcore.ai

Mar 16, 2017, 11:14:20 AM
to XLA development, dav...@graphcore.ai
Indeed - thought as much. Our LLVM supports fp16, but it doesn't target x86 so it isn't upstreamed.

Perhaps constant folding could be done using fp32 and a couple of conversions?
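
For illustration, a minimal host-side sketch of that idea (not XLA code; Eigen::half is just used here as a convenient fp16 type, and the include is an assumption):

#include <Eigen/Core>  // provides Eigen::half in recent Eigen versions (assumed)

// Fold an fp16 add by widening to fp32, computing there, and rounding back:
// the "fp32 plus a couple of conversions" idea, nothing more.
Eigen::half FoldF16Add(Eigen::half a, Eigen::half b) {
  float wide = static_cast<float>(a) + static_cast<float>(b);  // compute in fp32
  return Eigen::half(wide);                                    // narrow back to fp16
}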


sc...@openai.com

Mar 17, 2017, 1:25:23 PM
to XLA development, phaw...@google.com, dav...@graphcore.ai
At OpenAI we're now starting to use mixed precision quite a bit: fp16 on the forward pass and fp32 on the backward pass. I had to replace many of the primitive TF ops, but this wasn't too much work. It would be nice if XLA also supported this. The basic idea is that you convert to float just after loading and convert back to fp16 just before storing. So I use wrappers like this:

#include <cuda.h>
#include <vector_types.h>
#include <cuda_fp16.h>
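
// Eigen::half, used in the load/store specializations below, comes from Eigen's
// fp16 support as bundled with TensorFlow; that include is assumed here.
// The two inline-PTX helpers below convert between four fp16 values packed in a
// uint2 and a float4.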

__device__ __forceinline__ float4 half2floatV(uint2 v)
{
    float4 r;
    asm("{\n\t"
        ".reg .f16 a, b, c, d;\n\t"
        "mov.b32 {a, b}, %4;\n\t"
        "mov.b32 {c, d}, %5;\n\t"
        "cvt.f32.f16 %0, a;\n\t"
        "cvt.f32.f16 %1, b;\n\t"
        "cvt.f32.f16 %2, c;\n\t"
        "cvt.f32.f16 %3, d;\n\t"
        "}" : "=f"(r.x),"=f"(r.y),"=f"(r.z),"=f"(r.w) : "r"(v.x),"r"(v.y));
    return r;
}
__device__ __forceinline__ uint2 float2halfV(float4 v)
{
    uint2 r;
    asm("{\n\t"
        ".reg .f16 a, b, c, d;\n\t"
        "cvt.rn.f16.f32 a, %2;\n\t"
        "cvt.rn.f16.f32 b, %3;\n\t"
        "cvt.rn.f16.f32 c, %4;\n\t"
        "cvt.rn.f16.f32 d, %5;\n\t"
        "mov.b32 %0, {a, b};\n\t"
        "mov.b32 %1, {c, d};\n\t"
        "}" : "=r"(r.x),"=r"(r.y) : "f"(v.x),"f"(v.y),"f"(v.z),"f"(v.w));
    return r;
}
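
// Load wrappers: read the storage type (fp32 or fp16) and always hand back fp32
// for computation. The predicate 'b' guards only the memory access; the
// conversion stays outside the conditional (see the note below the code).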

template <typename TO, typename TI> __device__ __forceinline__ void load(TO &out, const TI* __restrict__ in, int i, bool b);

template <> __device__ __forceinline__ void load<float ,float >(float  &out, const float * __restrict__ in, int i, bool b) 
{ if (b) out = in[i]; }
template <> __device__ __forceinline__ void load<float4,float4>(float4 &out, const float4* __restrict__ in, int i, bool b) 
{ if (b) out = in[i]; }

template <> __device__ __forceinline__ void load<float ,Eigen::half>(float  &out, const Eigen::half* __restrict__ in, int i, bool b) 
{ Eigen::half v; v.x=0; if (b) v = in[i]; out = __half2float((__half)v); }
template <> __device__ __forceinline__ void load<float4,      uint2>(float4 &out, const       uint2* __restrict__ in, int i, bool b) 
{ uint2       v({0,0}); if (b) v = in[i]; out =   half2floatV(v); }
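
// Store wrappers: narrow from the fp32 computation type back to the storage type
// just before the predicated write.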

template <typename TO, typename TI> __device__ __forceinline__ void store(TO* out, TI val, int i, bool b);

template <> __device__ __forceinline__ void store<float ,float >(float * out, float  v, int i, bool b) 
{ if (b) out[i] = v; }
template <> __device__ __forceinline__ void store<float4,float4>(float4* out, float4 v, int i, bool b) 
{ if (b) out[i] = v; }

template <> __device__ __forceinline__ void store<Eigen::half,float >(Eigen::half* out, float  v, int i, bool b) 
{ Eigen::half r(__float2half(v)); if (b) out[i] = r; }
template <> __device__ __forceinline__ void store<uint2,      float4>(      uint2* out, float4 v, int i, bool b) 
{ uint2       r(float2halfV(v));  if (b) out[i] = r; }


Note that I don't include the conversion in the conditional. This makes it easier for the CUDA compiler to batch the loads during unrolling. I'll be releasing this code soonish (just waiting on papers). It will include support for mixed-precision conv and gemm.

You might also include a tensor scale factor in the conversion to support integer types, and even fp16 can sometimes benefit from a re-scaling to avoid under/overflow.  Then you probably want to insert reductions in the kernel to collect stats cheaply.  This lets you predict subsequent scale factors during training.
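
A hypothetical variant of the store wrapper above, just to illustrate the scaling-plus-stats idea (the names and the bit-pattern statistic are mine, not part of Scott's code):

__device__ __forceinline__ void store_scaled(Eigen::half* out, int* max_abs_bits,
                                             float v, float scale, int i, bool b)
{
    Eigen::half r(__float2half(v * scale));   // re-scale before narrowing to fp16
    if (b)
    {
        out[i] = r;
        // Cheap per-tensor statistic: track max |v| via atomicMax on the float's
        // bit pattern (monotonic for non-negative floats); initialize *max_abs_bits
        // to 0 and read the result back with __int_as_float.
        atomicMax(max_abs_bits, __float_as_int(fabsf(v)));
    }
}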

-Scott

Bjarke Roune

Mar 17, 2017, 7:01:34 PM
to XLA development, phaw...@google.com, dav...@graphcore.ai
Thanks Scott, that's very interesting to hear. Do you have any papers / notes / experiments / blog posts about your use of FP16?

I'd be especially interested in anything on the impact that FP16 versus FP32 has on how well a model trains. E.g. I assume you are training FP16 forward and FP32 backward because you found that FP16 everywhere didn't work out, and I'd be interested to see any documents explaining that conclusion, if you have public ones to share. And definitely feel free to send a link here if you publish something later; I'd be interested to see it.

Thank you
Bjarke Hammersholt Roune

Jingyue Wu

Mar 17, 2017, 7:25:59 PM
to XLA development, phaw...@google.com, dav...@graphcore.ai, Justin Lebar
jlebar@, can you comment on LLVM/NVPTX support for fp16? 

I remember you told me it's in much better shape. So maybe we only need to tighten the XLA side? 

Justin Lebar

Mar 17, 2017, 7:30:06 PM
to Jingyue Wu, Artem Belevich, XLA development, Peter Hawkins, dav...@graphcore.ai
+tra

> Even NVVM (LLVM backend targeting NVIDIA GPUs) doesn't have "operative" support for fp16s today.

Thanks to Art, this is no longer true, as of a few weeks ago.

> The NVIDIA headers have inline assembly that emits fp16 PTX directly.

Correct, but we added fp16 support to LLVM specifically so that XLA could use it. Unfortunately priorities shifted under us and we haven't been able to hook it up to XLA, and probably won't be able to for the next few months.

-Justin

sc...@openai.com

Mar 19, 2017, 5:36:44 PM
to XLA development, phaw...@google.com, dav...@graphcore.ai

This is an area that still needs a lot of research. Most of the research that has been done is likely locked up inside the various companies building low-precision hardware. I've trained conv nets end to end in both fp16 and int16. I knew that RNNs were trickier to get working end to end in 16 bits, but with layernorm I figured the forward activations and quantized weights would be fine in low precision. You want to keep an fp32 copy of the weights to apply the gradients to, so as to get enough mantissa overlap in the addition.
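
A minimal sketch of that fp32 master-weights idea (my illustration, not Scott's code): gradients are applied to an fp32 copy so small updates aren't lost to fp16 rounding, and the fp16 weights used in the forward pass are refreshed from it.

#include <cuda_fp16.h>

__global__ void apply_grads(float* master_w, __half* w16, const __half* grad,
                            float lr, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        master_w[i] -= lr * __half2float(grad[i]);   // accumulate the update in fp32
        w16[i] = __float2half(master_w[i]);          // re-quantize for the forward pass
    }
}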

Anyway, we're not seeing any loss in accuracy with lower precision in the forward pass. I'll likely try pushing this to int8. Being able to use lower precision is nice since it can nearly double your effective memory: you save less for the backward pass. Also, RNNs and increasingly CNNs are somewhat bandwidth bound, so you can get quite a performance boost if you're in this regime. That's even more helpful if you need to recompute segments of the forward pass to save even more memory for large models. Then you can also use the low-precision instructions in custom kernels for extra speed on the forward pass.

-Scott

Artem Belevich

Mar 20, 2017, 1:03:03 PM
to XLA development
This is a re-post of the reply I sent last week which didn't make it to the list.
-------------------

The NVPTX back-end in LLVM has support for f16 and f16x2 (<2 x half>) for sm_53+.
For GPUs that don't support fp16, fp16 data gets promoted to fp32 on load and converted back to fp16 on store, so IR that uses the 'half' type should work on all GPUs (though I suspect there will be differences in precision).
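
Written out explicitly with CUDA intrinsics, the same pattern looks roughly like this (my illustration, not Artem's code): packed half2 math on sm_53+, fp32 promotion as the fallback elsewhere.

#include <cuda_fp16.h>

__global__ void add_half2(const __half2* a, const __half2* b, __half2* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
#if __CUDA_ARCH__ >= 530
        out[i] = __hadd2(a[i], b[i]);                           // native packed f16x2 add
#else
        float2 fa = __half22float2(a[i]);                       // promote to fp32...
        float2 fb = __half22float2(b[i]);
        out[i] = __floats2half2_rn(fa.x + fb.x, fa.y + fb.y);   // ...and narrow back
#endif
    }
}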

There's currently no clang-side plumbing to replace CUDA's inline assembler for fp16 ops with LLVM intrinsics, nor do we have fp16 type support enabled in clang (though it does, in principle, have support for fp16, as it's used in OpenCL).

I did tinker a bit with enabling fp16 in XLA and it worked for trivial examples. However, I quickly ran into the fact that there are quite a few places where we need to handle Eigen::half before we can do anything more interesting.

--Artem

dav...@graphcore.ai

May 15, 2017, 6:42:52 AM
to XLA development
Hmmmm,

I had just added F16 to literal_util, then I did a pull and discovered that although F16 hasn't been added, the GetMutable..... interface has changed to use a MutableArraySlice - which is just what I was doing. Lucky :)

dav...@graphcore.ai

May 15, 2017, 8:32:28 AM
to XLA development
I've created 'https://github.com/tensorflow/tensorflow/pull/9913' in case it is of any use. It works well on our backend.

Peter Hawkins

May 15, 2017, 10:07:38 AM
to dav...@graphcore.ai, XLA development
Thanks for the contribution! It looks good, but I sent you some minor review comments.

Peter
