FlashAttention fusions from StableHLO primitives?


Jan Pfeifer

Jun 15, 2026, 3:13:24 AM
to OpenXLA Discuss
hi all!

I maintain a Go ML framework that uses XLA as its main backend, and I was wondering what the current state and future roadmap (if any) are for automated FlashAttention-style fusions within XLA (StableHLO -> PJRT) across the CPU, GPU (FlashAttention v2/v3 variants), and TPU backends.

Is there anything on that front? If not, is the expectation that it will always require making custom calls to platform-specific code (Pallas, CUDA, etc.)?

many thanks!
Jan



Abhinav Gunjal

Jun 19, 2026, 12:49:59 PM
to Jan Pfeifer, OpenXLA Discuss
Hi Jan, here is the current state:
  • Some folks (on the on-device ML team) have used composites (stablehlo.composite) to represent attention in a more portable way that includes a fallback.
  • Server workloads have generally taken the kernel approach, as you've noted.
  • There is some backend support for fusing the constructs in FlashAttention, but it relies on pattern matching, so YMMV. I believe these fusions are generally tailored to what JAX emits for these constructs.
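
To illustrate why pattern-matched fusions can be brittle, here is a toy Go sketch of a structural matcher over a hypothetical mini-IR. The `Op` type, the opcode strings, and the single `softmax` op are all illustrative assumptions: real HLO spells softmax out as reduce/subtract/exp/divide, and XLA's actual matchers are considerably more elaborate. A matcher written for one spelling of attention misses a mathematically equivalent graph that applies the scale to Q before the first dot.

```go
package main

import "fmt"

// Op is a hypothetical, minimal stand-in for an HLO instruction: an opcode
// plus its operands. Real XLA HLO is far richer than this.
type Op struct {
	Code     string
	Operands []*Op
}

func param(name string) *Op        { return &Op{Code: "parameter:" + name} }
func op(code string, xs ...*Op) *Op { return &Op{Code: code, Operands: xs} }

// isAttention reports whether root has exactly the shape
// dot(softmax(multiply(dot(Q, K), scale)), V).
func isAttention(root *Op) bool {
	if root.Code != "dot" || len(root.Operands) != 2 {
		return false
	}
	sm := root.Operands[0]
	if sm.Code != "softmax" || len(sm.Operands) != 1 {
		return false
	}
	mul := sm.Operands[0]
	if mul.Code != "multiply" || len(mul.Operands) != 2 {
		return false
	}
	return mul.Operands[0].Code == "dot"
}

func main() {
	q, k, v, scale := param("q"), param("k"), param("v"), param("scale")
	// The spelling the matcher was written for.
	a := op("dot", op("softmax", op("multiply", op("dot", q, k), scale)), v)
	// Mathematically equivalent: scale Q before the first dot.
	b := op("dot", op("softmax", op("dot", op("multiply", q, scale), k)), v)
	fmt.Println(isAttention(a), isAttention(b)) // true false
}
```

In this toy setting, a frontend other than JAX would only hit the fusion if it happened to emit the same op ordering the matcher expects.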

-Abhinav

Jan Pfeifer

Jun 20, 2026, 2:53:07 AM
to Abhinav Gunjal, OpenXLA Discuss

Thanks, Abhinav, this is very helpful!

I probed further with Gemini, and it tells me that the pattern-matching approach has been abandoned (?).

I'm trying to build the right mental model: is the following accurate about specialized fusions (FlashAttention, PagedAttention, etc.) today?

  • Edge/Mobile: Rely on stablehlo.composite with a standard name/signature (maintained by the on-device ML team) and a decomposed StableHLO fallback.

  • CPU: No specialized attention custom calls or pattern matching. It relies entirely on standard, unrolled StableHLO math.

  • GPU (CUDA): Rely on stablehlo.custom_call (e.g., "cudnn_fha"). Because these handlers are compiled into the CUDA PJRT plugin distributed by JAX, I can replicate these fusions by matching JAX's custom-call signatures.

  • TPU: Rely on custom_call payloads compiled by, and tightly coupled with, Pallas/Mosaic. Because these are JAX-internal (not publicly specified as part of OpenXLA), other implementations (like my Go one) cannot access them without embedding Python/JAX.

Cheers,
Jan
