Hi folks - I would like to reach a consensus that we should take some form of dependency on torch-mlir and integrate its dialects and conversion pipelines directly into the IREE compiler.
For the past two months, we (Nod) have been executing on a proof of concept that folks may have seen:
SHARK-Turbine. This project's primary goal is "taming" the interface between PyTorch and IREE, providing a tight integration with all PyTorch modalities (torch.compile, eager execution, and AOT approaches). Further, it has been exploring how close we are to the torch-mlir vision of rebasing most of the integration onto upstream PyTorch concepts, mostly centered on Dynamo.
Current Status
I'm happy to say that as of last week, we passed a milestone where we consider this proof of concept to be a success, and it is now time to determine what to do with it and position it for maximum effect. Namely, we have determined that:
- Interop at the FX/Python level for graph extraction is viable without bridging through any of the legacy modalities (TorchScript, etc.). Proving this is important for integration because it allows for a much simplified approach to something like torch-mlir, requiring no native-code dependency on PyTorch itself in order to bridge to a compiler backend. The entirety of the graph-level interop now lives in one Python file and bridges directly via MLIR/IREE's Python API.
- The Dynamo export path to FX is ready to replace the legacy (and tricky to integrate) approaches to whole-graph extraction. In addition, Dynamo's native dynamic-shape support mates well with IREE's own model, and it appears ready to handle challenging models (we have been testing with inference-optimized LLMs).
- Implementing a low-level export interface like iree-jax is now possible directly on PyTorch, producing similar levels of capability for assembling complicated programs/exports.
- I have been syncing a subset of torch-mlir (just the dialects and conversions) against IREE's LLVM version for O(months), and I have not hit *any* integration hurdles for this subset (with the exception of one case of needing to tweak a warning flag). The integration surface is quite small (see here and here). I expect it could be simplified further as we begin to shed the legacy, pre-Dynamo paths in torch-mlir proper.
- We considered interop via StableHLO and may consider it further in the future, but in a survey, it was not suitable for any of the workloads which our customers care about -- primarily due to immaturity with respect to custom ops, type support, extensibility, and convoluted conversion paths requiring third party dependencies that do not appear factored for our use. The StableHLO ecosystem continues to evolve, and we are free to make project-level decisions to embrace it more as the situation improves.
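To make the FX/Dynamo interop above concrete: torch.compile hands a registered backend the whole captured FX graph, which is exactly the hand-off point where a bridge like SHARK-Turbine's can import into the torch dialect. The sketch below only records the graph and falls back to eager execution (the backend name and bookkeeping are illustrative, not the actual SHARK-Turbine code):

```python
import torch


def f(x):
    return torch.nn.functional.relu(x) + 1.0


# A torch.compile backend receives the captured torch.fx.GraphModule plus
# example inputs. In a real integration, this is where the graph would be
# imported into MLIR via IREE's Python API; here we just record it.
captured = []


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    captured.append(gm)
    return gm.forward  # fall back to eager execution of the traced graph


compiled = torch.compile(f, backend=inspect_backend)
out = compiled(torch.ones(4))  # triggers capture on first call
```

Note that no native-code bridge to PyTorch is involved: the backend is a plain Python callable, and everything it receives is an in-memory FX graph.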
Benefits
We are happy to continue evolving the native compiler integration within the SHARK-Turbine project, but splitting the compiler in this way forks the ecosystem. Minimally, it creates a lot of redundant work around CI, but more importantly, it forces us into an untenable integration-testing story: we either base integration tests on imports into internal dialects that are unstable, or accept that we cannot offer artifact stability across dependency versions. Ideally, I would like a robust integration test suite in upstream IREE that includes generated IR from PyTorch, testing a variety of models and modalities. Basing this on the torch+surrounding dialects, which we control and can handle upgrades for, gives us a reasonable path forward.
Approach
The SHARK-Turbine prototype already factors the PyTorch frontend as a compiler plugin. It should be a simple matter of adding it to a directory like compiler/plugins/fe/pytorch (parallel to the compiler/plugins/target tree) and setting some CMake flags. We would also need to add a submodule dependency on torch-mlir. We can consider other methods of managing this dependency over time, but this seems low enough overhead and in line with how we manage StableHLO, Jax's native frontend layer.
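For concreteness, the wiring might look roughly like the following. This is a config sketch only; the submodule path and the CMake option name are assumptions for illustration, not settled choices:

```shell
# Vendor torch-mlir (dialects + conversions only are consumed by the build);
# the third_party/ path here is illustrative.
git submodule add https://github.com/llvm/torch-mlir.git third_party/torch-mlir

# Enable the (hypothetical) frontend plugin flag when configuring IREE.
cmake -B build -GNinja -S . \
  -DIREE_FRONTEND_PYTORCH_PLUGIN=ON
```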
Doing this now also has the benefit that the code is mostly build scaffolding and has only had commits from a small set of individuals who have all signed the OpenXLA CLA -- meaning that it is in a state where it can easily be contributed directly. We are planning many enhancements to this layer of the system, and it will become increasingly hard to keep this in a clean state with respect to authorship.
We would continue to develop SHARK-Turbine, taking a direct dependency on the upstream compiler's APIs rather than on a private build with our plugin enabled.
Making this contribution would also eliminate the in-tree fork of the TMTensor dialect and supporting infrastructure, since, to the extent that those are still needed, they would be drawn directly from their source and included as part of the plugin.
There are parallel discussions to be had on the torch-mlir side about simplifying the codebase in a post-Dynamo world. These are orthogonal concerns to this RFC, since the approach used already presumes that state and excludes the legacy pieces of the project.
Timeline
If there are no objections, I would like to make this contribution this week and simplify the state of the world so that we can contribute more tests to upstream.
Comments?
- Stella