Dynamism RFC


Eugene Burmako

Jul 4, 2023, 9:50:13 PM
to OpenXLA Discuss
Hi everyone,

The current design of dynamism in MHLO and StableHLO has been practically useful. There are success stories of it connecting JAX, PyTorch and TensorFlow to a number of compilers, in a mix of research and production environments. This RFC aims to leverage existing support for dynamism in the StableHLO dialect, discuss improvements to the existing design and then formalize the improved design in the StableHLO opset.

The main challenge with writing this RFC was that it affects the entire opset. The current design involves a lot of corner cases, so it took about a year of practical evaluation by the author - primarily within JAX native serialization, but also correlating with other ongoing and prior projects - to distill the design into just a few general design principles.

Finally, I'd like to acknowledge Smit Hinsu's work on the Bounded Dynamism RFC from Q4 2022, which was superseded by this work. The representation for bounded dynamic types in the StableHLO dialect was designed and implemented by Smit, and Smit's proposal to allow bounded dynamic types everywhere is compatible with the more general proposal from this RFC to enable dynamism for all size-related program elements. Furthermore, Smit contributed the formal spec for get_dimension_size as well as the informal spec for set_dimension_size.

PTAL at the RFC: https://github.com/openxla/stablehlo/pull/1668.

Cheers,
Eugene

Michael Levesque-Dion

Mar 21, 2024, 4:09:55 PM
to OpenXLA Discuss, Eugene Burmako

Hi everyone,


Thank you for all the discussion on the RFC! We ended up tying off remaining feedback in early February: https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md.


Several of the proposals have already been implemented. Unranked tensors have been removed from StableHLO, which allowed a good amount of cleanup, and more cleanup is coming.


We're working on adding specs, reference implementations, and revisiting verifiers for dynamic ops, as well as considering what new dynamic ops are missing (P4).


While integrating P5’s "represent shape computations as StableHLO operations on variadic 0D tensors", we noticed that the returns are lower than we thought and the costs are high. Following up with frameworks teams targeting dynamic ops, we believe the purported benefits aren't worth the breakages to community code (MLIR, C++, Python). Therefore, we are proposing to continue using 1D tensors to represent shapes in dynamic ops, which e.g. facilitates interop with shape dialect operations and makes DRR patterns easier to write. I have sent a PR to update the RFC: https://github.com/openxla/stablehlo/pull/2116. I will keep the PR open for one week to allow feedback from the community.


Best,


Michael

Michael Levesque-Dion

Apr 1, 2024, 3:02:29 PM
to OpenXLA Discuss, Michael Levesque-Dion, Eugene Burmako
We got the following piece of feedback from Stella....@amd.com (the feedback was meant to be sent here, but somehow didn't make it):

Thank you for the information. The change of direction on P5 is unfortunate: PyTorch has adopted a very compact representation for this based on SymInts, and it interoperates well with the symbolic constraint equations that they emit at the graph level. We’re still working to fully leverage it, but it is inherently more analyzable – which is the primary production purpose for such expressions. It nicely produces symbolic relationships where it is easy to infer not just bounds but multiples/strides/offsets/etc. between dimensions (which commonly occur in nature). Given that the StableHLO representation is more ad hoc and harder to analyze, it seems like code generation algorithms will end up benefiting PyTorch/FX uses more (unless StableHLO itself invests in analysis/tooling to lift its representation to a similar level).
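
To make the "multiples/strides/offsets" point concrete, here is a minimal sketch in plain Python of a dimension size kept as an affine expression over a named symbol. The `SymDim` class and its methods are hypothetical names used only for illustration; this is not PyTorch's SymInt API, and nothing like it is claimed to exist in StableHLO.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymDim:
    """Hypothetical model of a dimension size: coeff * symbol + offset."""

    symbol: str
    coeff: int = 1
    offset: int = 0

    def __mul__(self, k: int) -> "SymDim":
        # Scaling a dimension scales both the coefficient and the offset.
        return SymDim(self.symbol, self.coeff * k, self.offset * k)

    def __add__(self, k: int) -> "SymDim":
        # Adding a constant only moves the offset.
        return SymDim(self.symbol, self.coeff, self.offset + k)

    def is_multiple_of(self, k: int) -> bool:
        # For a positive k, the size is divisible by k for every integer
        # value of the symbol exactly when k divides both parts.
        return self.coeff % k == 0 and self.offset % k == 0

    def evaluate(self, value: int) -> int:
        # Substitute a concrete size for the symbol.
        return self.coeff * value + self.offset


batch = SymDim("b")
padded = batch * 8 + 16                # represents 8*b + 16
print(padded.is_multiple_of(8))        # True
print((padded + 1).is_multiple_of(8))  # False
print(padded.evaluate(3))              # 40
```

In this toy model the divisibility fact is read directly off the expression, without first recovering it from a chain of tensor ops. That is the kind of analysis the feedback above describes as easier with a symbolic representation.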

My response (with some ideas from glea...@google.com):

Thank you for bringing that up!

In the code that I encountered while working on this change, I saw that it is common to use shape.shape_of to obtain a shape for use with dynamic ops. Such code would need to change to use one stablehlo.get_dimension_size op per dimension. It’s not clear to me why one or the other of these options would be harder to analyze, as there is a direct 1:1 mapping between them. Anecdotally, there are examples of shape analyses and transforms in XLA, which suggests using 1D tensors for shapes is not a hindrance: shape_component_analysis, symbolic_shape_optimization, broadcast_propagation. Also, representing shapes as 1D tensors is not an obstacle to shape refinement.

The main example I encountered where 0D tensors are used to represent shapes is in JAX codegen. “Materializing” the shape into a 1D tensor for use with dynamic ops requires one reshape op per dimension and a concatenate op. In some cases, there are computations before the shape is materialized, but not after (1, 2), and when we spoke to JAX maintainers they had no preference.
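
As a rough illustration of the two representations discussed above, the following plain-Python sketch models a 1D shape tensor as a list and 0D size values as integers. The helper names mirror the ops mentioned in this thread (get_dimension_size, reshape, concatenate), but they are hypothetical stand-ins written for this example, not the StableHLO or JAX APIs.

```python
from typing import List, Sequence, Tuple


def get_dimension_size(operand_shape: Sequence[int], dim: int) -> int:
    # Stand-in for one stablehlo.get_dimension_size: yields a single 0D size.
    return operand_shape[dim]


def sizes_as_0d(operand_shape: Sequence[int]) -> Tuple[int, ...]:
    # The variadic 0D form: one get_dimension_size per dimension.
    return tuple(
        get_dimension_size(operand_shape, d) for d in range(len(operand_shape))
    )


def reshape_to_1d(size: int) -> List[int]:
    # Stand-in for reshaping a 0D value into a 1-element 1D tensor.
    return [size]


def concatenate(parts: Sequence[List[int]]) -> List[int]:
    # Stand-in for concatenating 1D tensors along dimension 0.
    out: List[int] = []
    for part in parts:
        out.extend(part)
    return out


def materialize_shape(sizes: Sequence[int]) -> List[int]:
    # One reshape per dimension, then a single concatenate.
    return concatenate([reshape_to_1d(s) for s in sizes])


operand_shape = (4, 7, 3)
sizes = sizes_as_0d(operand_shape)   # 0D form
shape_1d = materialize_shape(sizes)  # 1D form
print(sizes)     # (4, 7, 3)
print(shape_1d)  # [4, 7, 3]
```

For a rank-3 operand this models three reshapes plus one concatenate. The two forms hold the same sizes element for element, which is the 1:1 mapping referred to above.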

I couldn’t find canonical documentation for the PyTorch approach. If you know of any good documents, would you mind sharing them?

Since SymInts were not on the radar when the RFC was written, I think it is reasonable to consider this as out of scope for the current RFC and to put this off for now.

Best,

Michael

Stella Laurenzo

Apr 1, 2024, 4:31:25 PM
to Michael Levesque-Dion, OpenXLA Discuss, Eugene Burmako
Apologies - my email is routed in some obtuse ways and I think the listserv only allows responses from some of them. My original response had only gone to Michael, and when he asked me to repost, I added it to my list to try to also roll up some more useful feedback (but hadn't gotten to it yet).



On Mon, Apr 1, 2024 at 12:02 PM 'Michael Levesque-Dion' via OpenXLA Discuss <openxla...@openxla.org> wrote:
I couldn’t find canonical documentation for the PyTorch approach. If you know of any good documents, would you mind sharing them?

There are many and also a lot of details. Suffice to say that FX uses an orthogonal, ranked, dependent-type based approach to encoding symbolic shape symbols for a graph. Then for the cases of feeding indices to ops that need to change their actual shape, there are dedicated "symint" ops and indexes are treated as indexes of a known rank. I'm not sure I could make an absolutist argument that it absolutely can't be done with tensors, but I will make a practical one: I'm not aware of any production compiler that does it that way. Indexes are so central to this kind of system that they almost always benefit from distinct handling and representations. In practice, we have found many programs where "getting back to indexes" from a conjoined representation is a high bar that, when an analysis gets it wrong, results in egregious performance. As with most things in compilers, I'm certainly not saying that it can't be done, but that it is hard to do well -- especially when compared to common practice. And that it is seldom a good idea to trade the most important values in your system for an overly generic representation.
 

Since SymInts were not on the radar when the RFC was written, I think it is reasonable to consider this as out of scope for the current RFC and to put this off for now.

Personally, I think that StableHLO is in a tough spot: it neither has a frontend that has been designed for dynamic shapes, nor does it have a backend compiler that has really put pressure on it (whereas the alternatives have both). Given that I've seen both of those ingredients be necessary to get to a good representation in two other situations now, I think I agree with you that StableHLO may not be the place to litigate that. However, with my other hat on where I literally see the information loss and disconnects from torch->stablehlo in the context of torch-mlir, it is hard to ignore completely.
 


Michael Levesque-Dion

Apr 5, 2024, 7:37:22 PM
to OpenXLA Discuss, Stella Laurenzo, OpenXLA Discuss, Michael Levesque-Dion

Hi Stella,


Thank you for sharing more of your thoughts! I have discussed this with Kevin / Jacques and we are grateful that you brought this to our attention.


You make a good case that symbolic integers (or some other equivalent representation) are more convenient to analyze/manipulate than shape tensors. I looked at some of the lowerings in torch-mlir and they look very similar to what JAX does: torch-mlir > linear.mlir, JAX example.


You make a good point that some of the most prominent StableHLO frontends and backends don't currently have a need for a different shape representation at the StableHLO level, and that without the right pressures we are unlikely to "get it right". For most of our users, currently this change is "just" a big breaking change. Of course, that doesn't mean this change won't be possible in the future once we have a better motivated use-case to help guide design. In the meantime we’re testing out some ideas to make large-blast-radius changes like this less disruptive.


Would you mind sharing what the "two other situations" you mention are? I'm assuming one of them is PyTorch? Also, do you know any examples of this in the MLIR ecosystem? I’m mostly wondering if existing MLIR “symbolic integer” support is graph analyses of int math, or something else?


Best,


Michael
