XLA and TF2.0


Artem Artemev

Jul 31, 2019, 5:36:41 AM
to TensorFlow Community Testing
Hello everyone,

I have a couple of questions about XLA. If this is not an appropriate place for them, could you advise where I should post them?

1. Does XLA work with TF2.0?
2. How can I make sure that XLA is turned on?
3. Where can I find proper XLA documentation (switching options, how to write XLA ops, etc.)? Most of the links point to the XLA GitHub source code.




Alexandre Passos

Jul 31, 2019, 11:36:53 AM
to Artem Artemev, Sanjoy Das, Mehdi Amini, TensorFlow Community Testing
+Sanjoy Das +Mehdi Amini for tfxla help

1. Yes
2. Use the experimental_compile argument to tf.function (added in a nightly last week) to compile a whole function; this is the most intuitive way of enabling XLA, I think (see the sketch below). You can also use tf.xla.experimental.compile and tf.xla.experimental.jit_scope for somewhat more fine-grained control, but beware that compile is not very user-friendly.
3. The best we have is in tensorflow.org/xla and in the documentation for the symbols listed above, I think.
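
A minimal sketch of the first option, assuming a recent tf-nightly build where the experimental_compile argument is available:

```
import tensorflow as tf

# Ask XLA to compile the whole function.
@tf.function(experimental_compile=True)
def dense_layer(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((8, 16))
w = tf.random.normal((16, 4))
b = tf.zeros((4,))
print(dense_layer(x, w, b).shape)  # (8, 4)
```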



--
 - Alex

Sanjoy Das

Jul 31, 2019, 11:59:00 AM
to Alexandre Passos, Artem Artemev, Mehdi Amini, TensorFlow Community Testing
You can also enable XLA via "auto-clustering" by calling
config.set_optimizer_jit(True) like this:
https://github.com/tensorflow/tensorflow/blob/170a95de67f266c9fd7fea3ceedc5a7ecb0c80c3/tensorflow/python/framework/config_test.py#L231

We're currently working on enabling this mode by default for GPUs so
that people don't even have to make this one-line change.

Note: auto-clustering only helps TF functions (i.e. when the TF
runtime creates and executes a graph) since XLA primarily boosts
performance by fusing and optimizing several TF nodes into a single
unit of execution. This cross-operation visibility is missing in TF
eager mode.

If you want to use auto-clustering on the CPU (as opposed to on the
GPU) you also need to set the TF_XLA_FLAGS environment variable to
"--tf_xla_cpu_global_jit".
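
As a rough sketch of both steps together (a minimal example; tf.config.optimizer.set_jit is the public wrapper around the setting above, and the environment variable has to be set before TensorFlow initializes):

```
import os

# Needed for auto-clustering on CPU; must be set before TensorFlow starts.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"

import tensorflow as tf

# Turn on XLA auto-clustering globally.
tf.config.optimizer.set_jit(True)

@tf.function
def model(x, y):
    return tf.reduce_sum(tf.tanh(x @ y))

x = tf.random.normal((128, 128))
y = tf.random.normal((128, 128))
print(model(x, y).numpy())
```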

You can confirm that XLA is kicking in by enabling vlogging and
grepping for log lines from xla_compilation_cache.cc that look like
"compiled cluster_N ...".

I'll also take it as an action item to revamp the tensorflow.org/xla
page a bit. The landing page needs to focus more on how TF users can
use XLA and what to expect, rather than on XLA internals.


-- Sanjoy

Artem Artemiev

Jul 31, 2019, 3:55:37 PM
to Sanjoy Das, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
Sanjoy, +Alexandre Passos, thanks a lot for your response!

Sanjoy, could you tell me how to build the HLO graph and save it to disk? I tried XLA_FLAGS="--xla_dump_to="./xla_output_file" --xla_dump_hlo_as_html", but it didn't work.
Sanjoy, +Alexandre Passos I cannot find an example of creating a custom XLA op. I would like to write an optimization for matrix-vector multiplication so that TensorFlow wouldn't store the full intermediate matrix in memory, only the result of the computation. Let's say we have functions f: X ˟ X → M and g: M → V, where M is an N⨯N matrix, V is a vector of size N, N is very large, and the N⨯N matrix doesn't fit into GPU memory. We could optimize the composition of these functions, computing the result of (g ∘ f): X ˟ X → V directly by chunking the computation. In my view, this type of optimization is one of the "must have"s for XLA.
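
As a rough illustration of the kind of chunked evaluation I mean (plain TF, with a pairwise squared-distance kernel chosen purely as an example of f):

```
import tensorflow as tf

def f(x_chunk, x):
    # Example kernel: pairwise squared distances between a block of rows and all rows.
    return tf.reduce_sum((x_chunk[:, None, :] - x[None, :, :]) ** 2, axis=-1)

def g(m_chunk, v):
    # Matrix-vector product on the block of rows.
    return tf.linalg.matvec(m_chunk, v)

def chunked_gf(x, v, chunk_size=128):
    # Evaluate (g ∘ f) block by block so the full N⨯N matrix is never materialized.
    out = []
    for start in range(0, x.shape[0], chunk_size):
        x_chunk = x[start:start + chunk_size]
        out.append(g(f(x_chunk, x), v))
    return tf.concat(out, axis=0)

x = tf.random.normal((1000, 3))
v = tf.random.normal((1000,))
result = chunked_gf(x, v)  # shape (1000,), without ever storing the 1000x1000 matrix
```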

Alexandre Passos

Jul 31, 2019, 3:58:32 PM
to Artem Artemiev, Sanjoy Das, Mehdi Amini, TensorFlow Community Testing
XLA currently isn't very extensible in the way you'd like to extend it; all optimizations must be implemented in terms of HLOs or (though it's nontrivial to do this from TF) in terms of a custom-call operation (but then you lose gradients and a lot of other useful properties).
--
 - Alex

Artem Artemiev

Jul 31, 2019, 6:01:58 PM
to Alexandre Passos, Sanjoy Das, Mehdi Amini, TensorFlow Community Testing
Okay, this is very sad news. Check out this project: https://github.com/getkeops/keops; they do exactly this type of optimization for `pytorch`.

> but then you lose gradients and a lot of other useful properties

Yes, gradients should be part of the optimization code.

Sanjoy Das

Jul 31, 2019, 6:12:39 PM
to Artem Artemiev, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
On Wed, Jul 31, 2019 at 12:55 PM Artem Artemiev <i...@artemav.com> wrote:
>
> Sanjoy, +Alexandre Passos, thanks a lot for your response!
>
> Sanjoy, could you tell how to build HLO graph and save it on the disk? I tried XLA_FLAGS="--xla_dump_to="./xla_output_file" --xla_dump_hlo_as_html", but it didn't work.

It is hard for me to debug this over email, unfortunately. But one
thing stands out -- maybe you need to escape the inner quotes? You
may also want to pass in an absolute path to be sure about the
destination location.

Btw, I usually don't use xla_dump_hlo_as_html; the text format seems
more ergonomic to me. If you just use --xla_dump_to (and no other
flags) then XLA GPU will dump the HLO graphs before and after
optimization, and some other artifacts.
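
For example, something like this should work (a sketch; the path is just an example, and the variable has to be set before TensorFlow starts):

```
import os

# No inner quotes around the path; an absolute path avoids surprises.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import tensorflow as tf

@tf.function(experimental_compile=True)
def square(x):
    return x * x

square(tf.constant(2.0))
# If XLA compiled anything, /tmp/xla_dump should now contain the HLO dumps.
```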

> Sanjoy, +Alexandre Passos I cannot find an example of creating a custom XLA op. I would like to write an optimization for the matrix vector multiplication, so that tensorflow wouldn't store an intermediate full matrix in the memory - only a result of computation. Let's say we have functions f: X ˟ X → M and g: M → V, where M is a ℝ matrix N⨯N and V is a ℝ vector of size N, and N is very big, and N⨯N matrix doesn't fit into the GPU memory. We could optimize the composition of the functions, directly computing the result of (g ∘ f): X ˟ X → V by chunking the computation. In my view, this type of the optimization is one of the "must have"s for XLA.

What you describe is called "fusion" in XLA and XLA does support it
(and is the most important optimization XLA does, as you seem to
suggest), although it may not support the specific variant you
suggest. I'd be happy to review a PR for this though.

Artem Artemiev

Aug 1, 2019, 6:43:39 AM
to Sanjoy Das, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
I'm a bit confused now, can I do this type of optimization by introducing a new XLA op, or is it just non-trivial (impossible)? :)

> XLA currently isn't very extensible in the way you'd like to extend it

+Alexandre Passos, what did you have in mind? Why do you think it will be hard (impossible) to implement?

> What you describe is called "fusion" in XLA and XLA does support it
> (and is the most important optimization XLA does, as you seem to
> suggest), although it may not support the specific variant you
> suggest. I'd be happy to review a PR for this though.

+Sanjoy Das, yes, that's a fusion operation! Can you point me to a simple example of a new XLA op (e.g. source code with a simple/naive implementation of a fusion op)? I want to make a really rough sketch for computing a matvec between a distance matrix and a vector.

Thanks!

Alexandre Passos

Aug 1, 2019, 11:16:18 AM
to Artem Artemiev, Sanjoy Das, Mehdi Amini, TensorFlow Community Testing
On Thu, Aug 1, 2019 at 3:43 AM Artem Artemiev <i...@artemav.com> wrote:
> I'm a bit confused now, can I do this type of optimization by introducing a new XLA op, or is it just non-trivial (impossible)? :)

>> XLA currently isn't very extensible in the way you'd like to extend it

> +Alexandre Passos, what did you have in mind? Why do you think it will be hard (impossible) to implement?

Implementing new atomic operations in XLA is really difficult given its current design, and unlikely to be accepted as a PR, as XLA intends HLO to be a closed set of operations the compiler is deeply aware of (so it'd involve, among other things, changes to the google-private TPU backend to make it work). There is no easy extension point to add a new operation that just does a thing.

Similarly, while the thing you want can be implemented as a fusion, you cannot as a user currently teach XLA how to do new fusions. Either your fusion is expressible in terms of HLO (so you go from a set of HLOs to another set of HLOs) or it needs to be separately implemented for each backend.

Sanjoy / Mehdi: please correct me if I'm wrong, but that's my impression.
--
 - Alex

Sanjoy Das

Aug 1, 2019, 12:05:22 PM
to Alexandre Passos, Artem Artemiev, Mehdi Amini, TensorFlow Community Testing
On Thu, Aug 1, 2019 at 8:16 AM Alexandre Passos <apa...@google.com> wrote:
> On Thu, Aug 1, 2019 at 3:43 AM Artem Artemiev <i...@artemav.com> wrote:
>>
>> I'm a bit confused now, can I do this type of optimization by introducing a new XLA op, or is it just non-trivial (impossible)? :)
>>
>>> XLA currently isn't very extensible in the way you'd like to extend it
>>
>>
>> +Alexandre Passos, what did you have in mind? Why do you think it will be hard (impossible) to implement?
>
> Implementing new atomic operations in XLA is really difficult given its current design, and unlikely to be accepted as a PR, as XLA intends HLO to be a closed set of operations the compiler is deeply aware of (so it'd involve, among other things, changes to the google-private TPU backend to make it work). There is no easy extension point to add a new operation that just does a thing.

We do allow "custom calls" in HLO.

We really don't encourage generating it from the "frontend" (i.e. the
TF/XLA bridge) and Alex is right that a PR doing this will get
pushback. However backends are free to generate custom calls when
lowering. For instance in XLA GPU we use custom calls to represent
backwards convolutions. Backwards convolutions are first lowered (by
the TF/XLA bridge) into a pad/reverse/convolution sequence (I don't
remember the exact details) that's mathematically equivalent to a
backwards conv. They're later pattern matched by
cudnn_conv_rewriter.cc to custom calls into targets like
"__cudnn$convBackwardInput" which do the whole sequence in "one step"*
faster and with less memory.

* I'm counting a single call to cudnn as "one step"

In the example you gave, are f: X * X -> M and g: M -> V specific
functions? If yes, I think our handling of cudnn backwards
convolutions would be a good fit for what you're trying to do. Most
of the machinery is in cudnn_conv_rewriter.cc.

If you want to do this more generally (i.e. f and g are general
functions that you know nothing about) then you'll probably have to
use our more general fusion machinery. This is split across several
optimization passes, see all files with "fusion" in their names in
tensorflow/compiler/xla/service/*.

Please feel free to ask additional questions if you have any.

> Similarly, while the thing you want can be implemented as a fusion, you cannot as a user currently teach XLA how to do new fusions. Either your fusion is expressible in terms of HLO (so you go from a set of HLOs to another set of HLOs) or it needs to be separately implemented for each backend.

Backends can implement custom fusions, like the
"__cudnn$convBackwardInput" target I mentioned above.

We _also_ have a way to generically say: do this sequence of N HLO
ops, but in a "single step" and we use that representation extensively
as well.

-- Sanjoy

Artem Artemiev

Aug 2, 2019, 7:31:03 PM
to Sanjoy Das, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
+Alexandre Passos , +Sanjoy Das  thanks a lot for your responses!


1. I've tried XLA_FLAGS="--xla_dump_to=/tmp/xla_output_file --xla_hlo_profile" ipython and ran this code:

```
In [1]: import tensorflow as tf

In [2]: @tf.function(experimental_compile=True)
   ...: def comp(x, y, z):
   ...:     a = tf.matmul(x, y, transpose_b=True)
   ...:     return a @ z
   ...: z = tf.random.normal((1000, 1))
   ...: y = tf.random.normal((1000, 1))
   ...: x = tf.random.normal((1000, 1))
   ...: _ = comp(x, y, z)
   ...: _ = comp(x, y, z)
```

After running this code, the dump directory under /tmp is empty. Is that a bug, or am I missing something? Here is an example in colab.

MacBook Pro, macOS Mojave 10.14.4
tf-nightly-2.0-preview==2.0.0.dev20190802
Python 3.6.8 :: Anaconda, Inc.

2. 

> In the example you gave, are f: X * X -> M and g: M -> V specific
> functions?

Let's consider the simplest example: f is an outer product of two vectors X and Y, M = X * Yᵀ, and g is a matrix-vector multiplication O = M * V, where V is a vector; in the code we always compute f first and then g. In such situations, the obvious optimization would be to compose the operations (g ∘ f) and reshuffle the order of the computation, so that instead of the memory-inefficient (X * Yᵀ) * V, the form X * (Yᵀ * V) is used. By dumping the HLO graph, I want to make sure that such an optimization takes place. In case XLA doesn't do this optimization, the plan is to implement that feature (outer-product/matvec fusion) first, then move on to Euclidean distance operations.
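
For concreteness, a tiny sketch of the two orderings in plain TF (not an XLA pass, just to show the intended rewrite):

```
import tensorflow as tf

n = 1000
x = tf.random.normal((n, 1))
y = tf.random.normal((n, 1))
v = tf.random.normal((n, 1))

# Memory-inefficient: materializes the n x n outer product first.
naive = (x @ tf.transpose(y)) @ v

# Reassociated: Yᵀ V is a 1x1 value, so no n x n matrix is ever formed.
efficient = x @ (tf.transpose(y) @ v)

print(tf.reduce_max(tf.abs(naive - efficient)).numpy())  # ~0 up to float error
```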

Questions: how do I get the HLO graph in TF2.0? For an example of an XLA implementation, should I look at `cudnn_conv_rewriter.cc`?

Thanks!



Artem Artemiev

Aug 4, 2019, 7:16:00 AM
to Sanjoy Das, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
+Sanjoy Das , here is a shared link to the colab

Artem Artemiev

Aug 5, 2019, 8:52:21 AM
to Sanjoy Das, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing
+Sanjoy Das, +Alexandre Passos I finally made the graphviz plots... In fact, the problem was that I didn't specify the xla_device for the computation. Is it assumed that I have to switch to xla_* devices manually? I guess it's just not clear that after setting the global JIT flag in the TensorFlow config, I still have to switch to an XLA-friendly device.

I got two diagrams: before and after XLA optimization, for the code I shared with you in the colab (PDFs attached to the email).

First, XLA clearly doesn't do the optimization I described earlier: reshuffling the operations to get a lower memory footprint. And I don't understand what kind of optimization the fusion provides: if the broadcasting part happens in memory (according to the diagram, the optimized version copies the input vectors 1000 times), then this is a very poor optimization in my opinion. The trick might still have an advantage when the result of the broadcast doesn't allocate a lot of memory, e.g. when the input vectors are small, but not for big inputs. I'd really like to hear what you (+Sanjoy Das, +Alexandre Passos) think about it.

after_optimization.pdf
before_optimization.pdf

Sanjoy Das

Aug 15, 2019, 10:51:56 AM
to Artem Artemiev, Alexandre Passos, Mehdi Amini, TensorFlow Community Testing, George Karpenkov
On Mon, Aug 5, 2019 at 5:52 AM Artem Artemiev <i...@artemav.com> wrote:
> +Sanjoy Das, +Alexandre Passos I finally made the graphviz plots... In fact, the problem was that I didn't specify the xla_device for the computation. Is it assumed that I have to switch to xla_* devices manually? I guess it's just not clear that after setting the global JIT flag in the TensorFlow config, I still have to switch to an XLA-friendly device.

This is a bit of a gotcha: to enable auto-clustering on XLA CPU you
also have to pass --tf_xla_cpu_global_jit in TF_XLA_FLAGS.

George Karpenkov recently updated the documentation on the tensorflow
website to be clearer about this:
https://www.tensorflow.org/xla#auto-clustering

> I got two diagrams: after and before XLA optimization for the code which I shared with you in the colab (pdf attached to the email).
>
> First, XLA clearly doesn't do the optimization I described earlier: reshuffling the operations to get a lower memory footprint. And I don't understand what kind of optimization the fusion provides: if the broadcasting part happens in memory (according to the diagram, the optimized version copies the input vectors 1000 times), then this is a very poor optimization in my opinion.

None of the intermediate values inside the fusion "box" writes to
memory (only the inputs and outputs of the fusion *as a whole* are in
memory), so we don't copy the input parameter 1000x. I agree with
you that that would be a very poor "optimization". :)

If you're okay reading LLVM IR then you can take a look at the
generated .ll files to check my assertion above. The XLA dump
directory should contain the pre and post optimization .ll files.

-- Sanjoy