Problem statement
The "Sunsetting the MLIR-HLO repository" announcement stated that Google is planning to wind down the development effort and support of the MLIR-HLO repository. Over the years, this repository has become an important community asset, but it has also been hitting scalability issues - from technical, organizational and infrastructural perspectives.
In our experience, the primary usefulness of MLIR-HLO lies in providing direct access to: 1) the MHLO dialect that can serve as a portability layer between ML frameworks and ML compilers, 2) MHLO passes, e.g. hlo-legalize-to-linalg, which connect MHLO to the larger ecosystem. Both functions have been used successfully in multiple projects.
The first function has been successfully served by StableHLO over the last six months, and an increasing number of projects have been migrating from MHLO to StableHLO as the portability layer of choice.
However, the second function - which boils down to MHLO => Linalg lowerings - is still an open question. As Mehdi pointed out in "OpenXLA overall architecture & components", StableHLO was originally designed "decoupled for MHLO, which in turn is positioned as a compiler IR". This design decision allowed StableHLO to focus on just the matters of portability (which was already plenty challenging), but as a result its transformation story has been left undeveloped.
Proposed solution
Recently, Jakub from the IREE team has been exploring a potential solution to the second problem. openxla/iree#12678 goes into low-level details, but in a nutshell the idea was to take all MHLO-based passes that IREE uses, migrate them to use StableHLO and see what happens. This work is almost done, and it's been a success - now there is a proposal to sunset the MHLO-based input conversion pipeline in IREE.
Furthermore, folks from the Torch-MLIR project have recently been thinking about the second problem as well (llvm/torch-mlir#2177). Torch-MLIR is also using MHLO => Linalg passes from the MLIR-HLO repository, and while these passes are available in the XLA repository, depending on it is not very convenient logistically. Anush remarked "I would rather copy / fork if we have to than take an XLA dep because a few passes live there".
I believe that this recent community exploration suggests that there is a need for another repository to fill the role of MLIR-HLO and provide HLO => Linalg lowerings. This repository needs well-maintained CMake and Bazel builds, a process for regularly bumping LLVM revisions and a community-friendly development environment.
In principle, we can create a new repository for this purpose, decide on the scope, bootstrap the infrastructure and find folks who would be interested in maintaining all this. But we don't have to do any of that - we can just use openxla/stablehlo which satisfies all the requirements above, and this is what I would like to propose.
Questions & answers
Q1: Doesn't this proposal contradict the mission of StableHLO?
A1: Indeed, the original mission for StableHLO was focused on just compiler input, so that MHLO can focus on being a compiler IR. However, this mission was formulated almost two years ago, and a lot has been learned during that time, e.g. that: 1) MHLO has a lot of cruft, to the extent that it's unclear how to evolve it into an awesome compiler IR, 2) the idea of splitting StableHLO and MHLO comes at a significant maintenance cost, but it's unclear whether it really carries its weight.
Q2: StableHLO has compatibility guarantees, so how can it be a good transformation dialect?
A2: This hasn't been discussed much, but within the StableHLO project there are actually two different dialects - StableHLO which is the interface that producers and consumers are using, and VHLO which is where compatibility guarantees are provided. As Stella put it, "the `stablehlo` *dialect* is actually defined in terms of an evolution process that is much closer to, say, LLVM IR than it is to a serialization format (i.e. the `vhlo` dialect and corresponding passes/utilities for serialization are what arbitrate the wire-compatibility guarantees)".
Q3: One of the benefits of splitting StableHLO and MHLO is that MHLO can contain additional operations which are only relevant to compiler pipelines. Does this proposal mean that we'll start polluting the portability layer with these operations?
A3: This kind of pollution would indeed be undesirable, but the current design of MHLO being a copy of StableHLO + a few ops is not the only way to avoid it. There is an alternative design where StableHLO is used together with satellite dialects that add functionality rather than duplicate it. More specifically, ops which are currently in MHLO but not in StableHLO could go into a new dialect (which could be called `stablehlo_ext`, `xla`, etc.).
Q4: Perhaps we should first align on the overall OpenXLA architecture and only then decide on this proposal?
A4: Aligning on the overall architecture is very useful, but I don't think it has to be a blocker for making this particular decision. Providing StableHLO => Linalg lowerings in openxla/stablehlo will immediately resolve an acute issue that multiple projects are facing, so I would like to propose a bias for action. If a better place for these lowerings materializes in the future, these lowerings can be moved there - since the API will stay the same, the migration would be easy even if the implementation changes.
Q5: What is the relationship between this proposal and TCP?
A5: To quote the TCP RFC, "TCP’s mission is to be a mid-level dialect in MLIR that enables transformations that complement those in existing dialects, by supporting ops at a high-level of abstraction". This is a promising project, which I think should be part of a long-term discussion about transformation dialects. However, its lowering to Linalg doesn't yet have feature parity with the MHLO => Linalg lowering, so it cannot yet provide a solution for the short-term problem.
Q6: What does this proposal mean for MHLO?
A6: The MHLO dialect would remain as an implementation detail of the XLA compiler, which provides 1:1 parity with HLO and therefore a gateway to the wealth of functionality implemented with HLO. In the future, it may also dissolve into multiple dialects, but this is out of scope for this proposal. What happens to MHLO passes, e.g. the existing MHLO => Linalg lowering, would be up to their owners.
--
You received this message because you are subscribed to the Google Groups "OpenXLA Discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to openxla-discu...@openxla.org.
To view this discussion on the web visit https://groups.google.com/a/openxla.org/d/msgid/openxla-discuss/1e6a2702-24eb-49ac-8789-fa0dc90ad293n%40openxla.org.
For more options, visit https://groups.google.com/a/openxla.org/d/optout.
FYI, in Torch-MLIR we are using a little more than just the Linalg lowerings from mlir-hlo, and we would want to include those too:
https://github.com/llvm/torch-mlir/blob/959f4f48d51307dca38afdf63188583618e68378/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py#L43
I don't think I agree with the framing that this is in contradiction to StableHLO being an input dialect and this makes it more of a transform dialect. If we want StableHLO to be a useful format, there *has* to be a way to lower it to something else. Linalg is one of the reasonable paths someone might choose.
If we can make the lowerings sufficiently general, sharing makes sense. Even if it would require being opinionated for some lowerings, we could share the rest. I actually think it makes *more* sense to have lowerings target StableHLO because then they can be shared more easily given StableHLO's position as a common input format. I can easily imagine pipelines that take StableHLO->My Transform Dialect [do transforms]->StableHLO.
On Tue, May 30, 2023 at 2:07 PM 'Geoffrey Martin-Noble' via OpenXLA Discuss <openxla...@openxla.org> wrote:
> I don't think I agree with the framing that this is in contradiction to StableHLO being an input dialect and this makes it more of a transform dialect. If we want StableHLO to be a useful format, there *has* to be a way to lower it to something else. Linalg is one of the reasonable paths someone might choose.
From my understanding, this holds only if we assume that we would do **all** transformations at the Linalg layer. Otherwise, removing the extra MHLO layer is just a forcing function to perform transformations in StableHLO when you can't do them in Linalg.
> If we can make the lowerings sufficiently general, sharing makes sense. Even if it would require being opinionated for some lowerings, we could share the rest. I actually think it makes *more* sense to have lowerings target StableHLO because then they can be shared more easily given StableHLO's position as a common input format. I can easily imagine pipelines that take StableHLO->My Transform Dialect [do transforms]->StableHLO.
That requires StableHLO to be able to model everything that you would do in "My Transform Dialect", and to forbid any kind of "lowering"; otherwise you lose the ability to go back to StableHLO. I don't foresee this being either practical or desirable. I actually haven't drawn or seen a proposed architecture diagram (in OpenXLA or any MLIR-based compiler actually) that would be structured with such a flow: the natural flow instead goes in a single direction where you introduce more lowerings and specialization.
On Tue, May 30, 2023, 3:39 PM Mehdi AMINI <joke...@gmail.com> wrote:
> I actually haven't drawn or seen a proposed architecture diagram (in OpenXLA or any MLIR-based compiler actually) that would be structured with such a flow: the natural flow instead goes in a single direction where you introduce more lowerings and specialization.
+1 - the most natural flow is through dialects that represent a lowering of some element of the abstraction level.
I think in the present ecosystem there are exceptions to that, where we are doing lateral transformations between dialects at a similar abstraction level for the purpose of interop, but this isn't how you would build an actual compiler (beyond trying to adapt inputs to the compiler in some fashion).