--
You received this message because you are subscribed to the Google Groups "OpenXLA Discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to openxla-discu...@openxla.org.
To view this discussion visit https://groups.google.com/a/openxla.org/d/msgid/openxla-discuss/5dd4d86f-49d4-455a-ad54-4d3cdb843eeen%40openxla.org.
For more options, visit https://groups.google.com/a/openxla.org/d/optout.
Thanks Abhinav, this is very helpful!
I dug further with Gemini, and it claims that the pattern-matching approach has been abandoned (?).
I'm trying to build the right mental model. Is the following an accurate picture of how specialized fusions (FlashAttention, PagedAttention, etc.) are handled today?
Edge/Mobile: Relies on stablehlo.composite ops with a standard name/signature (maintained by the on-device ML team), plus a decomposed StableHLO fallback.
CPU: No specialized attention custom calls and no pattern matching; it relies entirely on standard, unrolled StableHLO math.
GPU (CUDA): Relies on stablehlo.custom_call (e.g., "cudnn_fha"). Because the handlers for these calls are compiled into the CUDA PJRT plugin distributed with JAX, I can replicate these fusions by matching JAX's custom-call signatures.
TPU: Relies on custom_call payloads produced by, and tightly coupled to, Pallas/Mosaic. Because these are JAX-internal (not publicly specified as part of OpenXLA), other implementations (like mine, in Go) cannot access them without embedding Python/JAX.
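To check that I have the Edge/Mobile part right, here is how I picture the composite mechanism, as a minimal Python/NumPy sketch (rather than Go, for brevity). Everything in it is hypothetical: `Composite`, `lower`, the kernel table, and the name "example.attention" are my own stand-ins, not real StableHLO or XLA APIs. The one property I'm relying on is the StableHLO spec's statement that a composite can be replaced by its decomposition without changing program semantics; the "backend without a kernel falls back" case is my assumption about how a CPU-style consumer would treat it.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Tuple

import numpy as np

Array = np.ndarray
# Hypothetical: a backend's table of specialized kernels, keyed by
# (composite name, composite version).
KernelTable = Dict[Tuple[str, int], Callable[..., Array]]


@dataclass
class Composite:
    """Hypothetical stand-in for a stablehlo.composite op: a name, a version,
    attributes, and a decomposition (here a Python callable instead of a
    reference to a StableHLO function)."""
    name: str
    version: int
    decomposition: Callable[..., Array]
    attributes: Dict[str, object] = field(default_factory=dict)


def lower(op: Composite, kernels: KernelTable, *args: Array) -> Tuple[Array, str]:
    """Use a specialized kernel if this backend recognizes the composite's
    name and version; otherwise run the decomposition, which is assumed to
    have the same semantics."""
    kernel = kernels.get((op.name, op.version))
    if kernel is not None:
        return kernel(*args, **op.attributes), "kernel"
    return op.decomposition(*args, **op.attributes), "decomposition"


def attention_decomposition(q: Array, k: Array, v: Array, scale: float) -> Array:
    # Plain softmax(q @ k^T * scale) @ v, built from ordinary ops.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def fused_attention_kernel(q: Array, k: Array, v: Array, scale: float) -> Array:
    # Stand-in for a "fused" kernel: same math, but the softmax normalization
    # is deferred until after the multiply by v (similar in spirit to
    # FlashAttention-style kernels). Not a real FlashAttention implementation.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p @ v) / p.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
op = Composite("example.attention", 1, attention_decomposition, {"scale": 8 ** -0.5})

# A backend with no registered kernel (my understanding of CPU) falls back.
cpu_out, cpu_path = lower(op, {}, q, k, v)
# A backend that recognizes ("example.attention", 1) swaps in its kernel.
edge_out, edge_path = lower(op, {("example.attention", 1): fused_attention_kernel}, q, k, v)
print(cpu_path, edge_path, np.allclose(cpu_out, edge_out))
```

In this sketch the version is part of the lookup key, so a backend that only knows a different version of the same name silently takes the decomposition path: still correct, just without the speedup.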
Cheers,
Jan