Dynamic shapes


Bo Bob

unread,
Jan 6, 2021, 11:14:04 PM
to XLA development
Greetings xla-team,

Are there plans for dynamic shapes in XLA without recompilation at each change in the dynamic shape?

Thank you

Mehdi AMINI

unread,
Jan 6, 2021, 11:31:10 PM
to Bo Bob, XLA development
Hi,

I'm not aware of a direct plan inside the current XLA infrastructure, but folks at Alibaba have been recently sharing some progress on this topic using MLIR to compile HLO: https://llvm.discourse.group/t/updates-on-mlir-based-dynamic-shape-compiler/2384

Best,

-- 
Mehdi




Sanjoy Das

unread,
Jan 6, 2021, 11:44:39 PM
to Mehdi AMINI, George Karpenkov, Yunxing Dai, Bo Bob, XLA development
Hi Bo,

If your shapes have a natural upper bound, then XLA's "dynamic padder" might be relevant: it pads shapes up to a fixed upper bound. However, it isn't piped all the way through to TensorFlow CPU/GPU at this time.
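A user-space sketch of the same idea in plain NumPy (the `pad_to_bound` helper and the bound of 10 are illustrative, not part of XLA's API): the buffer gets a static shape, and the true size travels alongside it.

```python
import numpy as np

def pad_to_bound(x, bound):
    """Pad a 1-D array up to a static upper bound.

    Returns the fixed-shape padded buffer plus the dynamic true size,
    mirroring how the dynamic padder keeps a worst-case buffer shape.
    """
    padded = np.zeros(bound, dtype=x.dtype)
    padded[: x.shape[0]] = x
    return padded, x.shape[0]

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
padded, n = pad_to_bound(x, 10)  # static shape f32[10], dynamic size 3
```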


-- Sanjoy

Bo Bob

unread,
Jan 9, 2021, 9:21:28 PM
to XLA development
Are general dynamic shapes a fundamental limitation for XLA and TPUs or is it more a question of priorities?

What is the connection between XLA and MLIR?

Jacques Pienaar

unread,
Jan 11, 2021, 1:06:36 PM
to Bo Bob, XLA development
Hey Bo,

It is a bit of both. XLA's TPU backend is an important one, and for best performance it has focused on static shapes (and, in some cases, non-general dynamic shapes).

MLIR and XLA overlap to some degree, but have different scopes and approaches. We are starting to use MLIR codegen inside XLA (the initial focus is on the buffered memory form, although there is work on the tensor side too), as well as in HLO optimizations and in the lowering from TF to XLA (rollout is a work in progress).

-- Jacques

Chris Leary

unread,
Jan 11, 2021, 1:21:57 PM
to Jacques Pienaar, Yunxing Dai, Matthew Johnson, Bo Bob, XLA development
Hello Bo Bob,

Just to add some personal perspective: I think what we find in practice is that compilation caching amortizes itself quite well for most workloads of interest (ML inference and training are often long running, which are the primary things XLA was designed for).

"Dynamic padder" support has been developed by +Yunxing Dai & company to clip /outer/ loops so that XLA can still reason about buffer allocation using worst-case buffer sizes (so it can still do nice things like rematerialization, and yet specialize for important array bounds that cause big changes in performance).

Folks in JAX land have been working on masking as a JAX transform as well, which shares commonalities with dynamic padder support, but is implemented fully in Python userspace using single kernels (which is very useful for potentially "bucketing" dimensions with custom policies in Python userspace). JAX also allows users to specify custom JIT compilation caching policies, since it's all in Python userspace.

So, the high-level view (taking the long arc of the years since 2015): I think it's also a question of prioritization, given that code generation and compilation caching work quite well at specializing, and we often want to specialize for these big-iron workloads to get the highest possible performance. I mostly mention the JAX aspects because we've found that exposing the runtime policies such that they're easily tweaked by users when needed (say, if their workload is not fixed-size-and-runs-forever-supercomputing-shaped) is a very nice property: it separates the concerns of extremely good/fast code generation when things are known (handled by XLA) from the policies of how user functions map down to the code generation facilities (handled via JAX runtime and function transform facilities).
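One way to picture the user-space bucketing policy mentioned above (a hypothetical sketch, not a JAX API: `BUCKETS` and `bucketed_length` are made up for illustration): round each dynamic length up to the nearest bucket, so a JIT cache only ever sees a handful of distinct shapes.

```python
import bisect

# Hypothetical bucketing policy: pad dynamic lengths up to a small set of
# fixed sizes so the compilation cache holds at most len(BUCKETS) entries.
BUCKETS = [32, 64, 128, 256]

def bucketed_length(n):
    """Return the smallest bucket >= n; each bucket maps to one compiled program."""
    i = bisect.bisect_left(BUCKETS, n)
    if i == len(BUCKETS):
        raise ValueError(f"length {n} exceeds largest bucket {BUCKETS[-1]}")
    return BUCKETS[i]
```

The policy lives entirely in Python userspace, so a user whose workload skews toward particular lengths can swap in different buckets without touching the compiler.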

I know this list is primarily about XLA, but for things-using-XLA perspective +Matthew Johnson who may be able to tell us more about JAX masking transform / give some refs if folks are interested in that!

Cheers,

Chris Leary

Sanjoy Das

unread,
Jan 11, 2021, 2:02:09 PM
to Chris Leary, Jacques Pienaar, Yunxing Dai, Matthew Johnson, Bo Bob, XLA development
The problem I see with dynamic padding is that the padded dimensions don't have numpy semantics.  That is, if I add an f32[<=10] to another f32[<=10], this should implicitly broadcast, throw an error, or do an elementwise add at runtime, depending on the actual value of the dimension.  I don't think the dynamic padder does that today.

Padding is a great lowering strategy though, once all of the implicit broadcasting, error semantics etc. are resolved, we could pad to simplify codegen.
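What "numpy semantics over padded buffers" would mean can be sketched in plain NumPy (this is an illustration, not what the dynamic padder does): slice each buffer down to its true size first, and then NumPy's own rules decide between elementwise add, broadcast, and error.

```python
import numpy as np

def add_dynamic(x, nx, y, ny):
    """Add two padded buffers with true sizes nx and ny under numpy rules:
    elementwise if nx == ny, broadcast if one size is 1, error otherwise."""
    return x[:nx] + y[:ny]

xb = np.zeros(10, np.float32); xb[:3] = [1, 2, 3]  # f32[<=10], true size 3
yb = np.zeros(10, np.float32); yb[:1] = [10]       # f32[<=10], true size 1
z = add_dynamic(xb, 3, yb, 1)  # size 1 broadcasts against size 3
```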

-- Sanjoy

Chris Leary

unread,
Jan 11, 2021, 2:06:17 PM
to Sanjoy Das, Jacques Pienaar, Yunxing Dai, Matthew Johnson, Bo Bob, XLA development
Interesting point, Sanjoy. Isn't implicit broadcast what implicitly (har har) happens there? Errors don't exist unless we capture them explicitly in the HLO graph. You /could/ in that case stage GetDynamicDim(x) == GetDynamicDim(y) into the graph to get an error bit out as part of the output, basically staging numpy shape checking into the HLO graph. Since XLA is fatal-free, you have to place any error conditions you want to detect into the HLO computation.
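A minimal sketch of that staging idea in plain Python (the `add_with_error_bit` helper is hypothetical, standing in for an HLO computation): the shape check becomes part of the computation, and an "ok" bit comes out with the result instead of a fatal error.

```python
import numpy as np

def add_with_error_bit(x, nx, y, ny):
    """Fatal-free add over padded buffers: stage the dynamic-dimension
    check into the computation and return (result, ok) instead of raising."""
    ok = nx == ny          # stands in for GetDynamicDim(x) == GetDynamicDim(y)
    n = nx if ok else 0    # on mismatch, produce an empty (garbage-free) result
    return x[:n] + y[:n], ok

a = np.arange(4.0)
result, ok = add_with_error_bit(a, 3, a, 3)
```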

Cheers,

Chris Leary

Yunxing Dai

unread,
Jan 11, 2021, 2:17:16 PM
to Chris Leary, Sanjoy Das, Jacques Pienaar, Matthew Johnson, Bo Bob, XLA development
Hi.. 

On the dynamic padder side, I'll try to fix that issue soon (it's been tracked at b/165921482) by explicitly branching on the sizes of the dimensions and adding slices and explicit broadcasts; we can then use CSE to statically reduce the number of cases where broadcasting is not needed. I don't think that's a fundamental issue with the padding approach.

What worries me more nowadays are cases where the output bound is not inferable or is very large in theory. One example is running tf.gather on a ragged tensor (https://www.tensorflow.org/api_docs/python/tf/raw_ops/RaggedGather?hl=id), where the output can be O(number of elements * input bound), which is not very practical. It may be a bit easier in JAX, as we can just stage the computation out onto the host and trigger recompilation.
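A toy illustration of why that bound blows up (plain Python, not the TF op): the output size depends on the *values* of the indices, so a static bound must assume every index selects the longest possible row.

```python
# Toy ragged gather: rows have different lengths, and the output size
# depends on which rows the indices happen to pick at runtime.
rows = [[1, 2, 3], [4], [5, 6]]

def ragged_gather(rows, indices):
    """Concatenate the selected rows; output length is data-dependent."""
    return [v for i in indices for v in rows[i]]

out = ragged_gather(rows, [0, 0, 2])      # actual output size: 8
worst_case = len(indices_bound := [0, 0, 2]) * len(max(rows, key=len))
# worst case: 3 indices, each selecting the longest (length-3) row -> 9
```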

Sanjoy Das

unread,
Jan 11, 2021, 2:53:57 PM
to Yunxing Dai, Chris Leary, Jacques Pienaar, Matthew Johnson, Bo Bob, XLA development
On Mon, Jan 11, 2021 at 11:17 AM Yunxing Dai <yun...@google.com> wrote:
> Hi..
>
> On dynamic padder side I'll try to fix that issue soon (it's been tracked at b/165921482) by expliciting branching on the size of the dimensions and adding slice and explicit broadcast, we can then use a CSE to statically reduce the number of cases where broadcasting is not needed. I don't think that's the fundamental issue with this padding approach.

Yunxing, this is a public list, so folks don't have access to b/XXX.  However, IIUC the fix is not general; for instance (quoting from the bug), we won't do the right thing in:

t1 = tf.where([True, False, False])
t2 = tf.where([True, False, False, False])
return t1 + t2


(The bug says we'll throw an error.)
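The same situation can be sketched with NumPy's `flatnonzero` playing the role of `tf.where` (an analogy, not the actual XLA lowering): the runtime sizes happen to match, but the static upper bounds (<=3 vs. <=4) differ, which is what trips the bounded-shape path.

```python
import numpy as np

# flatnonzero's output size is data-dependent, like tf.where's:
t1 = np.flatnonzero([True, False, False])         # runtime size 1, bound <=3
t2 = np.flatnonzero([True, False, False, False])  # runtime size 1, bound <=4
total = t1 + t2  # fine at runtime: both actual sizes are 1
```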

It also isn't clear how we'll deal with side effecting instructions like infeed -- will all of those be wrapped in Conditional so that they don't execute when there is a shape error?

Finally, on GPUs we'll have to move to a combined host/device execution model before we can use this so that we don't have to do frequent host/device syncs.

-- Sanjoy

Chris Leary

unread,
Jan 11, 2021, 3:04:33 PM
to Sanjoy Das, Yunxing Dai, Jacques Pienaar, Matthew Johnson, Bo Bob, XLA development
Note that (in/out)feed are shapeless channels -- if we have a way to serialize the laid-out shape descriptions as a prefix on the transfer (or, perhaps if we got super fancy about it, only prefixing the pieces of info that were unknown at compile time), the device/host can be made to know how much data to receive, so it can be made to work with dynamic quantities of data. It could also support union types (algebraic datatypes) that way. I had that UXLA document (which unfortunately didn't make it out to open source) that covered some of this infeed/outfeed stuff in particular. But insofar as people are trying not to use infeed except in highest-possible-peak-of-performance-at-the-tippity-top scenarios, it's probably OK to just error on dynamic shapes that reach infeed/outfeed nodes?

- Chris Leary

Yunxing Dai

unread,
Jan 11, 2021, 3:08:00 PM
to Sanjoy Das, Chris Leary, Jacques Pienaar, Matthew Johnson, Bo Bob, XLA development
> (The bug says we'll throw an error.)

This currently throws a compilation error instead of a runtime error. We can do more to support the case where the bounds differ between binary operands (e.g., pad the shape with the smaller bound up to the larger bound). Given priorities, I'd prefer to do that when a real user hits an error like this.

> Finally, on GPUs we'll have to move to a combined host/device execution model before we can use this so that we don't have to do frequent host/device syncs.

What does this include? Seems like a good idea in general even without dynamic shapes. 

