Cost of xla::Reshape


Joel

Jun 20, 2022, 8:10:24 AM
to XLA development
Hi XLA devs,

How does xla::Reshape work? Does it copy and rearrange the data (sounds expensive), or just modify some "metadata" table of where each value is in memory (sounds cheap)?

I ask because I want to apply a function with shapes s -> s' over data with shape (leading + s) to get data with shape (leading + s'). I'm considering a number of approaches, but it would help a lot to know whether it's expensive to first flatten leading dimensions to ([product leading] + s), do some other stuff, then reshape the output again to ([product leading] + s').
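
For concreteness, here's a rough sketch of the pattern I have in mind, in JAX-style pseudocode for brevity (do_stuff is a placeholder for the batched "other stuff"; s and s' are whatever shapes the function works on):

  import numpy as np
  import jax.numpy as jnp

  def apply_over_leading(do_stuff, x, s, s_prime):
      # x has shape leading + s; do_stuff maps (n,) + s -> (n,) + s'.
      leading = x.shape[:x.ndim - len(s)]
      n = int(np.prod(leading))
      flat = jnp.reshape(x, (n,) + tuple(s))        # flatten the leading dims
      out = do_stuff(flat)                          # the "other stuff"
      return jnp.reshape(out, tuple(leading) + tuple(s_prime))  # reshape back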

Thanks,
Joel

George Karpenkov

Jun 20, 2022, 10:09:15 AM
to Joel, XLA development
hi Joel,

It depends on whether it can fuse the reshape or not, but normally it should be free.

> whether it's expensive to first flatten leading dimensions to ([product leading] + s), do some other stuff, then reshape the output again to ([product leading] + s').

That should be fine* (*depends on what exactly XLA ends up fusing)




Joel

Jun 20, 2022, 11:17:10 AM
to XLA development
Thanks, that's really useful to know. Is there some way to find out whether an operation is being fused, e.g. by inspecting the compiled code or something similar?

In my current case, I'm running an xla::While loop over the ([product leading] + s) data.

George Karpenkov

Jun 20, 2022, 12:04:57 PM
to Joel, XLA development
You can look at after-optimizations HLO. While-loops on e.g. GPU tend to be very slow and unfused.
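
As a rough sketch (assuming a JAX front end here; --xla_dump_to is a plain XLA flag that works with any client, and the exact JAX lowering API has moved around a bit between versions):

  import os
  # Ask XLA to dump HLO for every compiled module, including the
  # *after_optimizations* text files; set before JAX initializes its backend.
  os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

  import jax
  import jax.numpy as jnp

  def f(x):
      flat = jnp.reshape(x, (-1, x.shape[-1]))   # flatten leading dims
      y = jnp.tanh(flat) * 2.0                   # some other stuff
      return jnp.reshape(y, x.shape)             # reshape back

  # Or grab the optimized HLO for one function directly; fused ops show up
  # as `fusion` computations in the text.
  print(jax.jit(f).lower(jnp.ones((4, 3, 5))).compile().as_text())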

Joel

Jun 22, 2022, 9:01:46 AM
to XLA development
> While-loops on e.g. GPU tend to be very slow

That's curious, because I thought while loops were intended for use in optimizers, e.g. Adam. Perhaps there's just no alternative in that scenario?

Leary, Chris

Jun 27, 2022, 9:48:05 PM
to Joel, XLA development
Hi Joel,

I think maybe the key clarification here is that XLA's while is not in the "inner loop" of an optimizer function like 'adam' -- what George is referring to is the fact you can put a "while" loop around something like a neural net training step so that you could say things in XLA like "run this neural net step until the loss is < N" or something like "until eval accuracy is > N".

If you look at what is IMO a very simple/clean model builder like stax inside of JAX, you can see this snippet:


  opt_state = opt_init(init_params)
  for i in range(num_steps):  # this could be an XLA while loop instead
    opt_state = update(i, opt_state, next(batches))
  trained_params = get_params(opt_state)  

In this example you first ask the optimizer "hey, can I get your initialization state?", then you pump the neural net step function num_steps times as you shovel minibatches in. Each of these steps takes your optimizer state to optimizer state', which corresponds to an update of the parameters according to the optimizer function plus any aux data for your optimizer. The optimizer is usually a pretty simple set of HLOs operating in vector form on the parameters; e.g. you can see the multi-dimensional array ops in the case of Adam here: https://github.com/google/jax/blob/main/jax/example_libraries/optimizers.py#L414
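
Paraphrasing that link, the per-parameter Adam update is roughly the following elementwise array math (step_size shown as a constant here rather than a schedule), which is what lowers to a handful of vector HLOs:

  import jax.numpy as jnp

  def adam_update(i, g, x, m, v, step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8):
      m = (1 - b1) * g + b1 * m                              # first moment estimate
      v = (1 - b2) * jnp.square(g) + b2 * v                  # second moment estimate
      mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))   # bias correction
      vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
      x = x - step_size * mhat / (jnp.sqrt(vhat) + eps)      # parameter update
      return x, m, v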

Here's where XLA "while" can come in -- the Python-implemented `for` in the snippet above could be replaced with an XLA while loop using the kind of coarse-grained condition I mentioned. On GPU the while loop implementation "just" pumps the set of kernels that make up the "body" of the while onto a stream: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/while_thunk.cc#L75 Though hypothetically exposing more computation to the XLA high-level optimizer could lead to better decisions in optimization and buffer allocation, launching from the host side as we see in the above snippet is usually a good default.
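
As a sketch of what that replacement could look like (assuming a JAX front end; update_fn, loss_fn, and get_params are jittable placeholders, and the batch is staged on-device rather than streamed via infeed):

  import jax
  from jax import lax

  def train_until(opt_state, batch, target_loss):
      def cond(carry):
          i, state = carry
          return loss_fn(get_params(state), batch) > target_loss

      def body(carry):
          i, state = carry
          return i + 1, update_fn(i, state, batch)

      # Lowers to a single XLA while loop with a coarse-grained condition.
      _, opt_state = lax.while_loop(cond, body, (0, opt_state))
      return opt_state

  trained_state = jax.jit(train_until)(opt_state, batch, 0.05)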

On a device like TPU where the execution stays resident on the device (instead of needing the host to pump it via a data-parallel kernel-launch-like mechanism), there's the ability to consume/produce I/O from the host within the while loop, i.e. the device can pop the host-given "infeed" of a minibatch, run the XLA while body, and only terminate when the condition is met, without needing the host to hand-hold the device via kernel launches.

There's a class of models where this doesn't matter that much; e.g. trapping to the host to ask "should I keep going", having it check a word in device memory, reply "yes please do", and shovel the next minibatch in could be a) much lower latency than the update() execution time and b) more convenient. But there are at least some times when that's not the case, and XLA wanted to handle those -- it's effectively a way of turning the XLA device computation, which could otherwise be a simple function, into a co-routine that can run indefinitely, with the ability to do (pipelined) I/O to it from the outside world. (That being said, launching computations and doing I/O transfers can also be pipelined in schemes independent of a coroutine-and-queues style model, but this packages it all together in a way that the compiler can reason about the memory requirements.)

Here's an example in JAX of infeeding to a loop in a device function that feeds back out, just as a proof of concept for what I'm talking about: https://github.com/google/jax/blob/de464fcf22c8bf7a2931182f8095bc01530df9fe/tests/infeed_test.py#L105 Like in the Python snippet above, though, you could use host code to implement the while loop and launch XLA computations in a row from the host. I think de facto that's the preferred way when you're not sensitive to the latency we're talking about, since then you get the full capability of the host to do whatever you want in between your coarse-grained XLA computations.

Random / deep "could be cool research" aside: I expect a persistent-kernel / infeed mechanism could be implemented on GPU devices as well and shave off some host round trips, but kernel launches are also an opportunity for the GPU to "reconfigure" itself for residency -- shared memory amounts, 2D register file access pattern (registers per thread), threads per block, blocks in the grid -- so it might need something fancy like a "control" kernel (which could e.g. wait for infeed data to be present) that did sub-launches via a "dynamic parallelism" style feature. That being said, it'd probably be hard with the number of libraries we use on GPUs, host guidance of the kernel-launch process is a pervasive assumption in the ecosystem, and I don't think dynamic parallelism has a lower-latency kernel launch than the host; it'd just be shaving off the round trips and any jank for the host to see kernel completion / dispatch. However, if people know of any existing refs to work that builds this kind of system and provides their insights, I'd be interested to read up!

Cheers,

- Leary

Joel

Jun 30, 2022, 10:09:00 AM
to XLA development
Thanks for the thorough response. I'll need some time to digest it.