
Addition of more high-level Ops to StableHLO


Pushkar Ratnalikar

Mar 14, 2024, 4:35:17 AM
to OpenXLA Discuss
Dear StableHLO Community,

I work with the AWS Neuron team.

I am reaching out to discuss a potential enhancement to StableHLO: the addition of high-level ops for commonly used ML operations such as Softmax. This would allow custom lowering of something like jax.nn.softmax directly to a stablehlo.softmax. While we can lower it to a custom-call for a specific target, we feel that a high-level op would allow for better optimization in our compilation flow.
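
To make this concrete, here is a rough sketch of the two flavors in the IR (the stablehlo.softmax op below is hypothetical, and the custom-call target name is made up for illustration):

    // Hypothetical high-level op, one possible target for jax.nn.softmax
    // (written in generic MLIR form since no such op exists today):
    %0 = "stablehlo.softmax"(%arg0) {dimension = -1 : i64}
        : (tensor<8x128xf32>) -> tensor<8x128xf32>

    // What we can do today: an opaque, target-specific custom-call.
    %1 = "stablehlo.custom_call"(%arg0) {call_target_name = "my_target_softmax"}
        : (tensor<8x128xf32>) -> tensor<8x128xf32>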

I understand that StableHLO is inspired heavily by HLO/MHLO and this would be a new op over what exists there.
I would like to gather your thoughts, insights, and feedback.

Thank you for your time,
Pushkar

Michael Levesque-Dion

Mar 14, 2024, 12:14:35 PM
to OpenXLA Discuss, Pushkar Ratnalikar
Hi Pushkar,

Thank you for reaching out!

We recently added a new "composite" op to StableHLO: https://openxla.org/stablehlo/spec#composite. This allows high-level ops to be modeled in stablehlo programs using a name and a decomposition. The name can be used to identify what kind of op is being modeled, and the decomposition can be used to inline the op when necessary (e.g. some backends may be able to handle softmax, but others may only handle lower-level ops). Composites can be legalized to function calls using the `stablehlo-legalize-composite-to-call` pass, and the composite op can also be tagged with a version and additional attributes.
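
For example, a softmax could be modeled roughly like this (an approximate sketch; the shapes, the namespace, and the elided decomposition body are illustrative, and the spec linked above has the authoritative syntax):

    // The composite carries a name, attributes, a version, and a reference to a
    // decomposition function expressed in ordinary StableHLO.
    %0 = "stablehlo.composite"(%arg0) {
      name = "my_namespace.softmax",
      composite_attributes = {dimension = -1 : i64},
      decomposition = @softmax_decomposition,
      version = 1 : i32
    } : (tensor<8x128xf32>) -> tensor<8x128xf32>

    // Backends that understand "my_namespace.softmax" can match on the name;
    // others inline the decomposition, which is built from existing ops
    // (stablehlo.reduce for the max/sum, stablehlo.subtract,
    // stablehlo.exponential, stablehlo.divide, plus broadcasts).
    func.func private @softmax_decomposition(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
      // exp(x - max(x)) / sum(exp(x - max(x))) ... (body elided)
    }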

Does the composite op work for your use case? Please let us know if you have any feedback! Figuring out what to do with successful composites is an area we'd like to discuss with the community in the near future.

(For more details on the origins of the composite op, please see the extensibility RFC: https://docs.google.com/document/d/1bSyyLA-p1F7KjZgjo563F1WFsPwcZc4eaH5WyQfbsi0/edit)

Best,

Michael

Stella Laurenzo

Mar 14, 2024, 12:22:46 PM
to Michael Levesque-Dion, OpenXLA Discuss, Pushkar Ratnalikar
+1 stablehlo should have softmax.


Pushkar Ratnalikar

Mar 15, 2024, 2:32:42 AM
to Stella Laurenzo, Michael Levesque-Dion, OpenXLA Discuss

Hi Michael,

Thanks for the information and the linked document. It is very helpful. It does seem to fulfill many of our goals.

From the document, it feels like stablehlo.composite is a pathway for extending the op-set. It allows users to model new ops, with the framework still being responsible for generating the composites.

Other future work includes figuring out the "exact glide path from an experimental composite into a well-supported higher-level abstraction in the OpenXLA ecosystem".

Does the above mean that any proposed StableHLO op would spend time as a stablehlo.composite first, with the `stablehlo-legalize-composite-to-call` pass doing the target-specific legalization (inlining vs. making it a custom-call) on conversion to HLO/MHLO, before being considered as an op in StableHLO?

OTOH, for the evolution part of StableHLO, the document says "Include popular compositions and their specs as a part of the StableHLO repository. This makes it easier for producers and consumers to support these operations." and "this could mean development of a community-driven dialect to hold these decompositions and their specs."

So, new ops are less likely to be added to StableHLO itself?

@Stella, I agree.

Thanks,
Pushkar
Kevin Gleason

Mar 15, 2024, 12:29:26 PM
to OpenXLA Discuss, Pushkar Ratnalikar, Michael Levesque-Dion, OpenXLA Discuss, Stella Laurenzo
Hello! Thanks for getting this conversation started. I've been thinking about this a bit lately and am interested in everyone's thoughts.

> Does the above mean that any proposed StableHLO op would spend time as a stablehlo.composite first
> So, new ops are less likely to be added to StableHLO itself?

New ops do not necessarily need to be composites first, but for higher level ops / exploration to build motivation to standardize, composites and custom_calls are a useful tool. We also have interest in composites from users who think that the abstractions that need to be accelerated are temporary and shouldn't be standardized, following this huggingface mentality. Of course, this does not apply to all abstractions; there are still core building blocks.

I agree that we need to start standardizing high-level ops. What I'm not sure of is if StableHLO is the opset where these should exist. Today StableHLO is (mostly) a basis opset - to draw a PyTorch analogy we fall somewhere between CoreATen and PrimsIR level of abstraction. There are StableHLO consumers that both do and don't want these higher level operations, meaning we would need decompositions regardless - I don't think an approach of "all StableHLO consumers must support all higher-level ops" is feasible. With that, I see two potential paths forward, and am interested in thoughts on both:

1. Add higher level ops to StableHLO, add expander passes from StableHLO-->StableHLO.
2. Add higher level ops to CHLO (or analogous opset), preserve through serialization / compilation.

(1) In my opinion, this approach muddies the spec to a certain extent, by mixing in things that can be decomposed and things that can't. If developing a backend, it is less clear what must be supported, vs what is optional to accelerate or decompose. HLO has taken this approach with some of its ops, but I think this approach makes more sense in a compiler IR that lives at the edge of the ecosystem, than an input IR that lives more centrally.

(2) This approach has a standardized higher-level opset which can be optionally decomposed (back to the previous analogy, something like ATen). If this is developed centrally, it will be bundled with all frameworks depending on StableHLO and frameworks can target these directly, frameworks like JAX/TF already target CHLO for example. Currently CHLO is not stable, but I think adding spec/serialization support gradually for useful abstractions could be a viable / flexible path forward for the ecosystem. This feels like a very reasonable place for something like Softmax to live.
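
As a sketch of what (2) could look like (chlo.softmax below is hypothetical; existing CHLO ops like chlo.erf and chlo.top_k already follow this pattern, and the pass name is approximate):

    // The framework emits the high-level op directly:
    %0 = "chlo.softmax"(%arg0) {dimension = -1 : i64}
        : (tensor<8x128xf32>) -> tensor<8x128xf32>

    // Backends that accelerate softmax keep it; everyone else runs a
    // decomposition pass (something like --chlo-legalize-to-stablehlo) to
    // expand it into basis StableHLO ops before compilation.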

Let me know your thoughts!


Best,
Kevin

Mehdi AMINI

Mar 15, 2024, 7:43:06 PM
to Kevin Gleason, OpenXLA Discuss, Pushkar Ratnalikar, Michael Levesque-Dion, Stella Laurenzo

On Fri, Mar 15, 2024 at 9:29 AM 'Kevin Gleason' via OpenXLA Discuss <openxla...@openxla.org> wrote:
> (2) This approach has a standardized higher-level opset which can be optionally decomposed (back to the previous analogy, something like ATen). If this is developed centrally, it will be bundled with all frameworks depending on StableHLO and frameworks can target these directly, frameworks like JAX/TF already target CHLO for example. Currently CHLO is not stable, but I think adding spec/serialization support gradually for useful abstractions could be a viable / flexible path forward for the ecosystem. This feels like a very reasonable place for something like Softmax to live.

What is the future of CHLO in a world with proper composite op support? Can all CHLO ops just be provided as composite ops?
This would seem in line with the stated

> Include popular compositions and their specs as a part of the StableHLO repository. This makes it easier for producers and consumers to support these operations.

which is roughly what CHLO has achieved historically.


Pushkar Ratnalikar

Mar 19, 2024, 3:48:18 AM
to Kevin Gleason, OpenXLA Discuss, Michael Levesque-Dion, Stella Laurenzo

Hi Kevin,

Thanks for your detailed response.


On Fri, Mar 15, 2024 at 9:29 AM Kevin Gleason <glea...@google.com> wrote:
> (1) In my opinion, this approach muddies the spec to a certain extent, by mixing in things that can be decomposed and things that can't. If developing a backend, it is less clear what must be supported, vs what is optional to accelerate or decompose. HLO has taken this approach with some of its ops, but I think this approach makes more sense in a compiler IR that lives at the edge of the ecosystem, than an input IR that lives more centrally.


I agree with the general philosophy.

> (2) This approach has a standardized higher-level opset which can be optionally decomposed (back to the previous analogy, something like ATen). If this is developed centrally, it will be bundled with all frameworks depending on StableHLO and frameworks can target these directly, frameworks like JAX/TF already target CHLO for example. Currently CHLO is not stable, but I think adding spec/serialization support gradually for useful abstractions could be a viable / flexible path forward for the ecosystem. This feels like a very reasonable place for something like Softmax to live.


Adding the new ops to CHLO is an option. But for some compilation flows which don't already use CHLO, it means using an additional bridge IR. And, like you mentioned, CHLO is not stable.

> Let me know your thoughts!


Thanks again. I would like to try lowering to composites for now and see how it goes. 
Thanks,
Pushkar
 

xq Dan

Sep 13, 2024, 10:53:58 PM
to OpenXLA Discuss, Pushkar Ratnalikar, OpenXLA Discuss, Michael Levesque-Dion, Stella Laurenzo, Kevin Gleason
Hi all,

I'd like to participate in the discussion on the extensibility issues of StableHLO. I think the issues here include:
  1. Ensuring the Minimalism and Extensibility of StableHLO: If we add too many complex operators, it will undermine the very purpose of StableHLO.
  2. From the Perspective of Graph Compilation: Expanding high-level coarse-grained operators into fine-grained operators can make fusion algorithms more challenging and make it harder for performance gains to generalize.
  3. Advantages of Fine-Grained Operators for Graph Compilation: Expanding into fine-grained operators can provide greater optimization space. For example, batch normalization can be simplified arithmetically during inference.
  4. From the Hardware Backend Perspective: For certain operators, there are hand-written implementations. If these are expanded into StableHLO operators within AI frameworks, it will be difficult to ensure performance.

Among these, the impact of Issue 2 is quite significant as it affects the architecture design of XLA and determines the upper limit of the model’s performance.

To address Issue 4, one solution is to support custom call operators. When converting high-level operators to StableHLO, users can choose to designate some operators as custom calls. This way, the graph compiler will treat these as black-box operators.
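
For example (a sketch; the call target name is made up), the high-level operator would be carried through as an opaque call that the graph compiler does not try to look into:

    // Black-box operator: the compiler treats it as opaque, and the backend
    // maps the call target to its hand-written kernel.
    %0 = "stablehlo.custom_call"(%arg0) {
      call_target_name = "vendor_softmax",
      backend_config = ""
    } : (tensor<8x128xf32>) -> tensor<8x128xf32>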

To address Issue 2, when lowering to StableHLO, we can retain the boundaries of these high-level operators. Drawing an analogy to traditional IR representations in compilers, these operator boundaries are akin to function abstractions, where the internals of the function are composed of StableHLO. This would transform the IR representation of the entire model into a combination of inter-procedural and intra-procedural optimizations.

As the XLA team noted in their 2019 C4ML slides, the most effective information for operator fusion often comes from the user side. For instance, identifying an operator as a softmax or batch normalization is more straightforward than relying on a complex heuristic algorithm to recover these operator boundaries.


Therefore, I believe that when expanding coarse-grained high-level operators into StableHLO, it is necessary to carry operator boundary information. I'm not sure whether this is the design purpose of the composite op. When there is an efficient implementation for a certain operator on the backend, it can be designated as a custom call, allowing the graph compiler to treat it as a black-box operator.

With the inclusion of operator boundary information, composite operators still exist but become more transparent, turning into many white-box operators. XLA can then redesign the operator fusion algorithm for this type of IR representation. White-box operators can be fused into larger white-box operators or simplified arithmetically, and white-box operators can also be fused with black-box operators. Of course, this will require a more efficient operator compiler to ensure that the generated fused operators can be automatically codegened.


Thanks

Xiaoqiang

