Hi,
As a follow-up to the community meeting this week, I wanted to start a discussion about the two dialects we seem to be building in parallel:
- A tile dialect (tpp for us, vmvx in IREE)
- A micro-kernel dispatch dialect (xsmm for us, ukernel for IREE)
Our design is more evolved, since this is our core research, but it's also very much biased towards the library we created (libxsmm), so it's probably not representative of all the other libraries that IREE and other MLIR users need to target.
Our dialect documents are a little outdated but they still give the general gist of what we're trying to do:
Note: for TPP on tensors we'll use DPS (destination-passing style) to match other dialects.
Our vision is that we need to separate the layers into a list of canonical representations, from ingress to hardware dispatch, across multiple MLIR compilers, to focus on appropriate semantics at the right level.
We don't want to maintain our own list, and it doesn't help if IREE has its own either: we can't depend on it (or on any other non-upstream project) without being stuck with a single framework, and upstreaming duplicate dialects would fragment the overall MLIR design rather than unify it.
We'd like to bring this to the wider MLIR community for discussion. But first, since we've been working with IREE and our needs are closely aligned (judging from the feedback at this week's meeting), we'd like to draw on IREE's experience to bring to LLVM an RFC that has a better chance of resonating with all the other groups.
A typical pipeline would be:
- Ingress: HLO, Torch, TOSA, ...
- High-Level: Linalg + NamedOps(*) [1]
- Pass: Block/Tile/Fuse @ tensor
- Tile level: { tpp, vmvx } -> Tile(*) @ tensor [2]
- Bufferization
- Tile level: { tpp, vmvx } -> Tile(*) @ memref
- Pass: Combine/Fuse/Reorder/Strength-Reduce @ memref
- Micro-kernel level: { xsmm, ukernel } -> UKernel(*)
- Pass: Hoisting/DCE/CSE/Canonicalization
- Lowering: SPIRV, LLVM, etc.
(*) Those are the places where we think there's scope for new dialects.
[1] This is probably the TCP dialect?
[2] We currently have tpp @ tensor level to make some passes simpler (not depend on address analysis for tile op fusion). This isn't mandatory in our design, but it is an important part of it.
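To make the levels concrete, here is a rough sketch of one computation at each stage. Only the `linalg.matmul` form is real upstream syntax; the `tpp` and `xsmm` op names and shapes below are illustrative assumptions, not the actual dialect definitions:

```mlir
// High level: a named op on whole tensors (actual upstream linalg syntax).
%0 = linalg.matmul ins(%A, %B : tensor<128x512xf32>, tensor<512x256xf32>)
                   outs(%C : tensor<128x256xf32>) -> tensor<128x256xf32>

// Tile level (hypothetical tpp syntax): the same computation on a
// compiler-chosen tile, still on tensors and in DPS form.
%1 = tpp.matmul ins(%a, %b : tensor<32x64xf32>, tensor<64x32xf32>)
                outs(%c : tensor<32x32xf32>) -> tensor<32x32xf32>

// Micro-kernel level (hypothetical xsmm syntax), after bufferization:
// a dispatch/invoke pair mirroring libxsmm's two-step API.
%k = xsmm.matmul.dispatch [32, 32, 64] : i64
xsmm.matmul.invoke(%k, %a, %b, %c)
    : (i64, memref<32x64xf32>, memref<64x32xf32>, memref<32x32xf32>) -> ()
```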
Basically, once the IR reaches micro-kernel dispatch, it's really hard to do fusion, grouping, accumulation reordering, etc., so we need an intermediate level between linalg and ukernel. This is our TPP dialect: simply operations whose "data type" is a tile.
The size of the tile and the order in which these ops are called are up to the compiler (and to what the library supports).
How to bundle these calls into a macro-kernel depends on the device. On CPUs, one can use OpenMP or a smart scheduler. On GPUs, one can fuse into a single kernel and dispatch to the device.
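For instance, on CPU the macro-kernel can just be the loop nest the compiler emits around the tile ops. A minimal sketch, assuming a 32x32 tile choice and a hypothetical `tpp.matmul` tile op (the loop and subview syntax are real upstream ops):

```mlir
// The compiler (not the library) picks the 32x32 tile and materializes
// the loop nest; each iteration becomes one micro-kernel call later on.
scf.parallel (%i, %j) = (%c0, %c0) to (%c128, %c256) step (%c32, %c32) {
  %a = memref.subview %A[%i, 0] [32, 512] [1, 1]
      : memref<128x512xf32> to memref<32x512xf32, strided<[512, 1], offset: ?>>
  %b = memref.subview %B[0, %j] [512, 32] [1, 1]
      : memref<512x256xf32> to memref<512x32xf32, strided<[256, 1], offset: ?>>
  %c = memref.subview %C[%i, %j] [32, 32] [1, 1]
      : memref<128x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
  // Hypothetical tile op on memrefs; the K-loop can live inside the
  // kernel (e.g. a BRGEMM) or be another scf loop here.
  tpp.matmul ins(%a, %b) outs(%c)
  scf.reduce
}
```

An OpenMP-style runtime can then parallelize the outer loop, while a GPU path would instead fuse this region into a single kernel and dispatch it to the device.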
Here, high-level, tile-level and ukernel-level ops can co-exist at any given time (modulo bufferization issues), and the lower passes simply ignore ops that are not their input ops.
This offers a lot of flexibility:
- A GPU lowering that doesn't support tile micro-kernels can lower the high-level named ops directly into calls (every subsequent pass ignores them).
- A CPU lowering that uses hand-crafted kernels (e.g. OneDNN) can do the same.
- An accelerator device that has MLIR compiler passes can let the framework compile to tile ops, then run their passes and lower to micro-kernel calls on their own.
- A CPU lowering that receives profile/trace/super-optimizer information can generate the loops as instructed, lowering to micro-kernel calls for each tile.
- While at the tile/tensor level, it's easier to see whether inputs and outputs are the same, or unused, and make decisions about in-place vs. out-of-place execution, kernel fusion, etc.
- Compilers that target multiple libraries (e.g. OneDNN and XSMM) can bundle/split tile calls into kernel calls and vice versa, picking the optimal macro/micro kernel call before it gets to function calls.
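To illustrate the aliasing point above: with DPS at tensor level, in-place decisions fall out of simple use-def checks rather than alias analysis. Again with hypothetical tpp syntax:

```mlir
// %acc is the destination and has no other uses after this op, so
// bufferization can update it in place; if %acc were read again below,
// an out-of-place buffer (a copy) would be needed instead.
%sum = tpp.add ins(%x : tensor<32x32xf32>)
               outs(%acc : tensor<32x32xf32>) -> tensor<32x32xf32>
```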
In summary, we want to find a flexible path through MLIR, using dialects at the right level, where all compilers can rely on the semantics of the ops to do the transforms at the right level, and not have to rely on scalar evolution, alias analysis, liveness analysis, etc. to find optimal lowering patterns.
We also want to allow compilers to combine with other compilers/frameworks and "talk" to each other through these dialects with strong semantics, transforming the IR without needing to "understand" some third-party dialect.
Thanks!
Renato