[RFC] StableHLO Extensibility


Kevin Gleason

Jun 6, 2023, 10:20:11 PM
to OpenXLA Discuss, Eugene Burmako

Hi everyone,


As part of the ongoing review of the TopK RFC, we have come up with a sketch of a design that allows arbitrary operations to define StableHLO decompositions. PTAL at [RFC] StableHLO Extensibility. Feedback as comments on the doc is preferred, since this has been shared in a few locations.



Cheers,

Kevin & Eugene


Jakub Kuderski

Jun 7, 2023, 3:17:13 PM
to Kevin Gleason, OpenXLA Discuss, Eugene Burmako, silv...@google.com, Mehdi AMINI
I decided to move some long-form thoughts here so that we can avoid reviewing this RFC under the PR for the TopK RFC. The previous discussion is in this thread: https://github.com/openxla/stablehlo/pull/1593#discussion_r1220360994.

Overall, I'm very excited about this direction and think that TopK may provide a good case study here. I have similar concerns to those shared by Sean under that PR thread.

To make sure I understand this extensibility proposal, here's my attempt at re-stating it:
  • Extension ops are ops whose properties are defined by a known name + attributes + function arguments.
  • Every extension op comes with a decomposition in terms of a function with stablehlo ops. This decomposition can be considered a fallback in the absence of a better one that would involve other dialects. This allows portability of extension ops and universal serialization.
  • Decomposition merely implements an extension op, but does not define its semantics.
  • Other dialects can define their own preferred decomposition (the RFC calls this deserialization).
This leads to the question of who maintains the registry of extension ops. One choice would be to have a registry under the StableHLO umbrella, say `stablehlo_ext`, but there may be more, e.g., `vendor_ext`. One of the responsibilities of the registries would be to host the list of extension ops together with their semantics.
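To make the restatement above concrete, here is a minimal Python sketch of a per-namespace registry. Everything in it is hypothetical and invented for illustration: the `ExtensionOp` and `Registry` classes, and modeling decompositions as plain Python callables standing in for StableHLO-only functions. The registry stores the semantics, the default decomposition merely implements them, and an optional per-dialect override plays the role the RFC calls deserialization.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class ExtensionOp:
    # An extension op is fully described by its name, attributes, and operands.
    name: str
    attributes: Dict[str, Any]
    operands: Tuple[Any, ...]


@dataclass
class Registry:
    # Hypothetical per-namespace registry, e.g. "stablehlo_ext" or "vendor_ext".
    namespace: str
    semantics: Dict[str, str] = field(default_factory=dict)
    defaults: Dict[str, Callable[..., Any]] = field(default_factory=dict)

    def register(self, op_name: str, spec: str, default: Callable[..., Any]) -> str:
        qualified = f"{self.namespace}.{op_name}"
        self.semantics[qualified] = spec    # semantics live in the registry
        self.defaults[qualified] = default  # the decomposition only implements them
        return qualified

    def lower(self, op: ExtensionOp,
              overrides: Optional[Dict[str, Callable[..., Any]]] = None) -> Any:
        # Use a dialect-specific lowering ("deserialization") when one is supplied,
        # otherwise fall back to the portable default decomposition.
        impl = (overrides or {}).get(op.name) or self.defaults[op.name]
        return impl(*op.operands, **op.attributes)


# Usage: a default topk decomposition expressed as sort + slice.
reg = Registry("stablehlo_ext")
reg.register("topk", "Returns the k largest elements.",
             lambda x, k: sorted(x, reverse=True)[:k])
op = ExtensionOp("stablehlo_ext.topk", {"k": 2}, ([3, 1, 4, 1, 5],))
print(reg.lower(op))  # [5, 4]
```

In this sketch the choice of override is made by whoever calls `lower`, which is exactly the open question raised for topk below.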

Let's consider two hypothetical extension ops, `stablehlo_ext.topk` and `vendor_ext.print`:
  • topk has a reasonable but inefficient decomposition to stablehlo ops (iota + sort + slice), and many possible decompositions to other dialects with better performance, e.g., `mhlo.topk`, `linalg_ext.topk`.
    • Given two possible non-default decompositions/deserializations, who decides which one to pick?
  • print is not implementable in stablehlo, so it decomposes to an empty function. A decomposition to another dialect may end up writing to a buffer or stdout, and thus do something useful.
    • Even though the default decomposition is empty, front-end transforms should not DCE print statements.
    • print may be called with different types, e.g., tensor<i32> or tensor<f32>, so the decomposition would have to reference 2 functions with different names.
    • How do we encode the types of the arguments? Is it the vendor registry who has to come up with a name mangling scheme?
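The "iota + sort + slice" fallback for topk can be sketched in plain Python on a 1-D list. This is only an illustration of the idea: a real StableHLO decomposition would operate on tensors using `stablehlo.iota`, a multi-operand `stablehlo.sort`, and `stablehlo.slice`. The function name `topk_decomposition` and the tie-breaking toward lower indices are assumptions of this sketch, not something the thread specifies.

```python
from typing import List, Sequence, Tuple


def topk_decomposition(values: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Reference topk built only from iota, sort, and slice steps."""
    # iota: generate the index of every element.
    indices = list(range(len(values)))
    # sort: order (value, index) pairs by descending value. Python's sort is
    # stable, so equal values keep ascending index order.
    pairs = sorted(zip(values, indices), key=lambda p: -p[0])
    # slice: keep only the first k pairs.
    top = pairs[:k]
    return [v for v, _ in top], [i for _, i in top]


print(topk_decomposition([3, 1, 4, 1, 5], 3))  # ([5, 4, 3], [4, 2, 0])
```

Because the sort touches all n elements no matter how small k is, this fallback is correct but costs O(n log n), which is why dedicated lowerings such as `mhlo.topk` or `linalg_ext.topk` are attractive.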
These questions make me wonder whether we need to require the IR emitter to materialize decomposition functions eagerly, instead of exposing an interface method `FuncOp getDecomposition(OpBuilder* builder)`, since given an extension op we can trivially extract the callee. Similarly, isn't deserialization (i.e., non-default decomposition) the same as a custom rewrite pattern set for a given dialect?
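One way to picture the `getDecomposition` alternative is lazy materialization: rather than emitting a function for every type up front, a lowering asks for the decomposition on demand, and one specialization is cached per (op name, operand types). The `mangle` scheme and the `DecompositionCache` class below are hypothetical and exist only for illustration; the thread explicitly leaves the real mangling scheme open, and this one is not collision-free.

```python
import re
from typing import Callable, Dict, List, Sequence


def mangle(op_name: str, operand_types: Sequence[str]) -> str:
    # Hypothetical scheme: join name and types with "__" and replace every
    # character that is not a letter or digit with "_". Not collision-free.
    parts = [op_name, *operand_types]
    return "__".join(re.sub(r"[^A-Za-z0-9]", "_", p) for p in parts)


class DecompositionCache:
    """Materializes one decomposition function per (op, operand types), on demand."""

    def __init__(self, build: Callable[[str, Sequence[str]], List[str]]) -> None:
        # build(op_name, operand_types) returns the body ops of the function.
        self._build = build
        self.functions: Dict[str, List[str]] = {}

    def get_decomposition(self, op_name: str, operand_types: Sequence[str]) -> str:
        symbol = mangle(op_name, operand_types)
        if symbol not in self.functions:  # build lazily, reuse afterwards
            self.functions[symbol] = self._build(op_name, operand_types)
        return symbol


# vendor_ext.print has no StableHLO implementation, so its default body is empty.
cache = DecompositionCache(lambda name, types: [])
a = cache.get_decomposition("vendor_ext.print", ["tensor<i32>"])
b = cache.get_decomposition("vendor_ext.print", ["tensor<f32>"])
c = cache.get_decomposition("vendor_ext.print", ["tensor<i32>"])
print(a, b, a == c, len(cache.functions))
# vendor_ext_print__tensor_i32_ vendor_ext_print__tensor_f32_ True 2
```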

I think that if we squint a bit, CHLO already provides something similar to extension ops, together with default decomposition patterns to MHLO (and in IREE also to StableHLO). In downstream compilers, we employ these decompositions by adding them to a rewrite pattern set, together with custom decompositions that can be given a higher priority. What is missing from this mechanism to implement decomposition/deserialization for proxy ops?
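The pattern-set mechanism described above can be mimicked with a few lines of Python, loosely modeled on the benefit that MLIR rewrite patterns carry (patterns with a higher benefit are tried first). The `Pattern` class, the pattern names, and the `compiler` target key are hypothetical. In this sketch the choice between competing lowerings rests with whoever assembles the pattern set.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Pattern:
    name: str
    root_op: str                                # op this pattern rewrites
    benefit: int                                # higher benefit is tried first
    applies: Callable[[Dict[str, str]], bool]   # hypothetical target check


def select_pattern(patterns: List[Pattern], op_name: str,
                   target: Dict[str, str]) -> Optional[Pattern]:
    candidates = [p for p in patterns if p.root_op == op_name and p.applies(target)]
    # max() returns the first maximal element, so ties go to the earlier pattern.
    return max(candidates, key=lambda p: p.benefit, default=None)


patterns = [
    # Default decomposition: always applicable, lowest benefit.
    Pattern("topk_to_iota_sort_slice", "stablehlo_ext.topk", 1, lambda t: True),
    # Downstream compilers add their own, higher-benefit rewrites.
    Pattern("topk_to_mhlo_topk", "stablehlo_ext.topk", 10,
            lambda t: t.get("compiler") == "xla"),
    Pattern("topk_to_linalg_ext_topk", "stablehlo_ext.topk", 10,
            lambda t: t.get("compiler") == "iree"),
]
print(select_pattern(patterns, "stablehlo_ext.topk", {"compiler": "iree"}).name)
# topk_to_linalg_ext_topk
print(select_pattern(patterns, "stablehlo_ext.topk", {"compiler": "other"}).name)
# topk_to_iota_sort_slice
```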


Sean Silva

Jun 7, 2023, 4:31:14 PM
to Kevin Gleason, OpenXLA Discuss, Eugene Burmako
How would the func for the decomposition work for operations with variadic rank? As an example, how would we implement the decomposition of elementwise gelu, if it can be called on tensors of arbitrary (but static) rank?
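To make the question concrete: the math of an elementwise op like gelu is rank-agnostic, but a `func.func` decomposition carries one static type. In the Python sketch below, `gelu` recurses through nested lists of any depth (any rank), while the hypothetical helper `func_signature` produces an illustrative MLIR signature string showing that each distinct static shape would need its own specialized function symbol. The naming scheme is invented for illustration; the GELU formula is the standard exact, erf-based one.

```python
import math
from typing import List, Tuple, Union

Nested = Union[float, List["Nested"]]


def gelu_scalar(x: float) -> float:
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu(t: Nested) -> Nested:
    # The math is rank-agnostic: recurse through any nesting depth.
    if isinstance(t, list):
        return [gelu(e) for e in t]
    return gelu_scalar(t)


def func_signature(shape: Tuple[int, ...], dtype: str = "f32") -> str:
    # A func-based decomposition is typed, so every static shape needs its own symbol.
    dims = "x".join(str(d) for d in shape)
    ty = f"tensor<{dims}x{dtype}>" if shape else f"tensor<{dtype}>"
    name = "gelu_" + (f"{dims}x" if shape else "") + dtype
    return f"func.func private @{name}(%arg0: {ty}) -> {ty}"


print(gelu([[0.0, 1.0], [-1.0, 2.0]]))
print(func_signature((2, 2)))  # func.func private @gelu_2x2xf32(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32>
print(func_signature((8,)))    # func.func private @gelu_8xf32(%arg0: tensor<8xf32>) -> tensor<8xf32>
```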

-- Sean Silva

Kevin Gleason

Jun 9, 2023, 1:16:48 PM
to OpenXLA Discuss, Sean Silva, Eugene Burmako, Kevin Gleason

Hi everyone,


We’ve really appreciated the discussion in this RFC thus far! I have incorporated the feedback and added some rationale and design goals.


In this revision of the RFC we present two designs, both of which should provide the core functionality needed by the community, with slightly different levels of structure and features available to extension dialects. Would appreciate all feedback!



Best,

Kevin

Kevin Gleason

Mar 15, 2024, 4:41:17 PM
to OpenXLA Discuss, Kevin Gleason

Hi everyone!


I wanted to give a long-overdue update on the Extensibility RFC and the stablehlo.composite op. We've been iterating on the RFC doc a bit; notably, there's a frameworks update at the top, and plenty of discussion is captured in the Open Feedback section. All open feedback should now be addressed; a few comments have been left open to make sure the responses are seen.


The op implementation was recently merged at openxla/stablehlo#2024, with an optional inliner pass in openxla/stablehlo#2073. Michael (PR author) will be presenting a short overview on extensibility and the new op at the next OpenXLA Community Meeting on 3/26!
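For readers new to the op, here is a rough Python model of what `stablehlo.composite` carries and what inlining it means. The field names (`name`, `composite_attributes`, `decomposition`, `version`) follow the StableHLO specification as we understand it; the toy IR (ops as Python objects, a symbol table mapping function names to op lists) and the `inline_composites` helper are assumptions of this sketch, not the pass from openxla/stablehlo#2073. Operands and SSA value renaming are deliberately omitted. The idea shown is that a backend that recognizes a composite's name can keep it and lower it directly, while everything else can be replaced by the ops of its decomposition.

```python
from dataclasses import dataclass, field
from typing import Dict, FrozenSet, List, Union


@dataclass
class Op:
    name: str  # a regular StableHLO op, e.g. "stablehlo.add"


@dataclass
class CompositeOp:
    name: str                                       # hypothetical, e.g. "my_ext.fma"
    decomposition: str                              # symbol of the implementing function
    composite_attributes: Dict[str, object] = field(default_factory=dict)
    version: int = 0


Body = List[Union[Op, CompositeOp]]


def inline_composites(body: Body, symbols: Dict[str, Body],
                      keep: FrozenSet[str] = frozenset()) -> Body:
    """Replace each composite not listed in `keep` with its decomposition's ops."""
    out: Body = []
    for op in body:
        if isinstance(op, CompositeOp) and op.name not in keep:
            # Decompositions may contain composites themselves, so recurse.
            out.extend(inline_composites(symbols[op.decomposition], symbols, keep))
        else:
            out.append(op)
    return out


symbols = {"fma_impl": [Op("stablehlo.multiply"), Op("stablehlo.add")]}
program = [Op("stablehlo.constant"), CompositeOp("my_ext.fma", "fma_impl", version=1)]
print([op.name for op in inline_composites(program, symbols)])
# ['stablehlo.constant', 'stablehlo.multiply', 'stablehlo.add']
print([op.name for op in inline_composites(program, symbols, frozenset({"my_ext.fma"}))])
# ['stablehlo.constant', 'my_ext.fma']
```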


This is a great milestone, but there is still a bit of integration work to be done:

  1. Integrate with frameworks, starting with PyTorch/XLA and JAX, including JAX transformations.

  2. Design and implement the HLO plan for composites; the current suggestion is a new bit on kCustomCall.

  3. Figure out a glide path for reifying commonly-used composite operations as well-supported ops.


The StableHLO team will be immediately prioritizing a few of these items, and will follow up with relevant stakeholders to figure out timelines of the remaining items. Thanks again for all the discussion and feedback!



Best,

Kevin
