Questions about the maturity and intended use of stablehlo_to_tf_saved_model


Hamza ELBARAGHI

Feb 24, 2026, 4:01:48 AM
to OpenXLA Discuss

Hi all,

I’ve recently started digging deeper into the StableHLO ecosystem, and I’m trying to better understand the role and maturity of the stablehlo_to_tf_saved_model utility.

I have a few questions regarding its intended use and real-world applicability:

  1. What was the original motivation for this tool?
    Since StableHLO is designed to be framework-independent and serve as a portable compiler IR, converting it back to a framework-specific representation (TensorFlow SavedModel) seems somewhat counterintuitive to me. I would have expected downstream usage to favor other MLIR dialects or compiler backends instead. Could someone clarify the primary use case this tool was designed for?

  2. How mature and widely used is this path?
    Has stablehlo_to_tf_saved_model been tested extensively beyond small or synthetic examples? (The only test I found uses a very simple model with two constants.) Is it used in production or in large-scale internal pipelines anywhere in the ecosystem?

  3. Scalability to large models (e.g., LLaMA 2 7B)
    Is it realistic to use this utility with very large models (multi-billion parameter scale)?
    In particular:

    • How are large constants/weights typically handled?

    • How are input_locations and state_dict expected to be configured for such models?

    • Is there tooling to automate this process, or is it currently a mostly manual setup?
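To make the configuration question concrete: for a model with thousands of weight tensors, input_locations and state_dict would presumably be generated rather than written by hand. Below is a minimal sketch of such a generator, assuming each argument of @main can be given a name that matches a key of a flat weight dictionary (for example, the keys of a PyTorch state_dict). The ("input_arg", position) / ("parameter", name) tuples and build_locations are hypothetical stand-ins for illustration, not the actual InputLocation API.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical encoding of one input location: ("input_arg", position) or
# ("parameter", name). It only mirrors the idea behind input_locations; it is
# not the InputLocation class from the StableHLO Python bindings.
Location = Tuple[str, object]


def build_locations(
    arg_names: Sequence[str], weights: Dict[str, List[float]]
) -> Tuple[List[Location], Dict[str, List[float]]]:
    """Classify each @main argument as a runtime input or a stored weight.

    Arguments whose name appears in `weights` become named parameters and
    their values go into the returned state dict. All other arguments become
    runtime inputs, numbered consecutively in the order they appear
    (an assumption about how input positions are counted).
    """
    locations: List[Location] = []
    state_dict: Dict[str, List[float]] = {}
    next_input = 0
    for name in arg_names:
        if name in weights:
            locations.append(("parameter", name))
            state_dict[name] = weights[name]
        else:
            locations.append(("input_arg", next_input))
            next_input += 1
    # Weights that match no argument usually indicate a naming mismatch.
    missing = set(weights) - set(arg_names)
    if missing:
        raise ValueError(f"weights with no matching argument: {sorted(missing)}")
    return locations, state_dict
```

For the documentation example, build_locations(["arg0", "bias"], {"bias": [1.0]}) yields one runtime input at position 0 and one parameter named "bias". The hard part at LLaMA scale is less this bookkeeping than getting stable, meaningful argument names out of the export in the first place.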

If anyone has experience using this tool with large transformer models (or can point to relevant discussions or code), I would greatly appreciate any references or insights.

Thanks in advance!
Hamza

Hamza ELBARAGHI

Feb 24, 2026, 4:32:57 AM
to OpenXLA Discuss, Hamza ELBARAGHI
Another limitation, I think, is that a model may hold its constants as dense<...> attributes rather than as parameters. The example in the documentation uses a parameter:

func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {
  %0 = stablehlo.add %arg0, %bias : tensor<1xf32>
  return %0 : tensor<1xf32>
}

How can we handle this with the stablehlo_to_tf_saved_model API? Or must we first implement a pass that turns the constants into parameters?
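One hedged way to picture "parameterizing the constants" is a rewrite that turns each sufficiently large constant op into a new function argument and moves its value into a state dict. The sketch below uses a toy IR: Op, Func, and lift_constants are all hypothetical and are not MLIR's Python bindings. A real MLIR pass would also have to update the function type and handle dense<...> attributes of arbitrary shape.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Op:
    """One op in a toy, hypothetical IR (not MLIR's Python API)."""
    result: str                          # SSA name this op defines
    kind: str                            # e.g. "constant", "add"
    operands: List[str] = field(default_factory=list)
    value: Optional[List[float]] = None  # flat payload for "constant" ops


@dataclass
class Func:
    args: List[str]
    body: List[Op]


def lift_constants(func: Func, min_elements: int = 1) -> Dict[str, List[float]]:
    """Turn constant ops with at least `min_elements` values into arguments.

    The function is rewritten in place; the returned dict maps each new
    argument name to the constant's value.
    """
    state_dict: Dict[str, List[float]] = {}
    kept: List[Op] = []
    for op in func.body:
        if op.kind == "constant" and len(op.value) >= min_elements:
            # Reuse the constant's SSA name as the new argument name, so every
            # existing use of it stays valid without renaming operands.
            func.args.append(op.result)
            state_dict[op.result] = op.value
        else:
            kept.append(op)
    func.body = kept
    return state_dict
```

Setting min_elements above 1 leaves small scalars (such as the 0.0 in a ReLU's maximum) inline, since typically only the large weight tensors benefit from being stored and swapped separately.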

Thanks!

Kevin Gleason

Feb 26, 2026, 2:04:06 PM
to Hamza ELBARAGHI, OpenXLA Discuss
Hello! Excellent questions :)

> What was the original motivation for this tool?
The background is this: StableHLO needs a vehicle/archive for serving, where the requirements are a bit different. Weights need to be stored in the archive, and oftentimes we want the weights to be easily swappable / queryable, so it's beneficial to have an archive with compute & resources separated. The OpenXLA ecosystem hasn't fully settled on a canonical solution for this: SavedModel & TF Serving is the well-ingrained historical solution, and Orbax is a more modern solution that aims to gradually sever TF dependencies.

These SavedModels directly embed the StableHLO module as bytecode and rely on TF ops only for the resource loading / orchestration bits, so it's less a dialect conversion than a "serving wrapper".
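The idea of keeping compute and resources separate can be pictured with a toy archive: the program is an opaque blob (standing in for the embedded StableHLO bytecode), while the weights live beside it and can be listed or swapped without touching the program. ServingArchive and its methods are hypothetical and do not reflect the actual SavedModel layout.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ServingArchive:
    """Toy archive that keeps compute and weights apart (hypothetical).

    `program` stands in for serialized StableHLO bytecode; `weights` stands
    in for separately stored variables. This is not the SavedModel format.
    """
    program: bytes
    weights: Dict[str, List[float]] = field(default_factory=dict)

    def swap_weight(self, name: str, new_value: List[float]) -> None:
        """Replace one weight in place; the program bytes are never touched."""
        if name not in self.weights:
            raise KeyError(f"unknown weight: {name}")
        if len(new_value) != len(self.weights[name]):
            raise ValueError("new value must keep the original size")
        self.weights[name] = list(new_value)

    def describe(self) -> Dict[str, int]:
        """Query weight names and sizes without looking at the program."""
        return {name: len(value) for name, value in self.weights.items()}
```

Because the program only refers to weights by name, swapping a checkpoint or inspecting which tensors an archive contains never requires re-exporting or re-parsing the compute graph.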

> How mature and widely used is this path?
The tool itself is not especially well tested or widely used, but the concept is heavily tested and used. This API was made to be a JAX-free drop-in for jax2tf, which is much more battle-tested (we serve a lot of JAX models, so users tend to favor the JAX API over the StableHLO API).

> Scalability to large models (e.g., LLaMA 2 7B)
We use StableHLO embedded in SavedModels for very large models, so it should scale well! I'm not sure of the answers to your other sub-questions, but I would recommend looking at the JAX ecosystem and trying some examples to see how the metadata looks.

Best,
Kevin


Hamza ELBARAGHI

Mar 3, 2026, 4:23:56 AM
to OpenXLA Discuss, Kevin Gleason, OpenXLA Discuss, Hamza ELBARAGHI

Hi Kevin,

I am sorry for the delay in my response.

Thanks a lot for the clarification. Knowing that this path is more of a serving wrapper than a dialect conversion is very important for me. The historical context around SavedModel and the pointer to jax2tf were very helpful too.

I’m particularly curious about the scalability aspect you mentioned. Is there any public reference or example of large models that have been exported to StableHLO?

More generally, I’m trying to better understand StableHLO’s role as a portability layer. Given its relatively small and stable op set, are higher-level constructs (attention, activation functions, etc.) always fully decomposed into primitives? Do you know of any emerging patterns for preserving some higher-level structure in the MLIR ecosystem? Something similar to what Intel's OpenVINO presents here: https://docs.openvino.ai/2026/documentation/openvino-ir-format/operation-sets/available-opsets/opset16.html, where they developed an MLIR dialect that preserves those high-level op sets: https://github.com/openvinotoolkit/npu_compiler/tree/develop/src/vpux_compiler/tblgen/vpux/compiler/dialect/IE/ops

I realize this diverges slightly from my initial question, but I'm trying to understand the broader architectural picture and the uses of StableHLO.

Thanks again!
Hamza

Jose Fonseca

Mar 3, 2026, 6:39:58 AM
to Hamza ELBARAGHI, OpenXLA Discuss, Kevin Gleason
> Given its relatively small and stable op set, the higher-level constructs (attention, activation functions, etc.) are always fully decomposed into primitives ?


Kevin Gleason

Mar 4, 2026, 4:50:21 PM
to Jose Fonseca, Hamza ELBARAGHI, OpenXLA Discuss
> Is there any public reference or example of large models that have been exported to StableHLO ?
The bottom line is that all JAX models are lowered to StableHLO: anything in a jax.jit wrapper goes through StableHLO.

I made a repo with some public large JAX models a while back, but didn't maintain it much. I also embedded constants in some of the models, which probably wasn't the best design choice: https://github.com/GleasonK/stablehlo-exports

> But what I've seen in practice when running inference is that these complex computations tend to be punched through StableHLO/HLO with Triton/Pallas kernels

+1 to this. Many of these abstractions get modeled as kernels (https://github.com/openxla/tokamax ragged_dot / sdpa). AFAIK this spans the ML ecosystem beyond StableHLO; e.g., in PyTorch, when you need more perf or something like a fusion, you rewrite it as a Triton kernel.

As for "higher-level ops": composites are definitely useful. The other thing to add is that for some higher-level ops that have generally agreed-upon decompositions, we add those to CHLO: https://openxla.org/stablehlo/generated/chlo

These ops are considered "optionally decomposable" -- CHLO->StableHLO decompositions exist for all of these ops, but backends can preserve the CHLO op if they support it. In XLA we support the following higher-level ops:
https://github.com/openxla/xla/blob/a2afad273932c832519f0debb0ddb111d2036577/xla/mlir_hlo/stablehlo_ext/transforms/chlo_preserve_high_level_ops.cpp#L230-L240
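The "optionally decomposable" contract can be sketched in plain Python: an op is kept when the backend has a dedicated implementation and rewritten into primitives otherwise, and both paths must compute the same function. The op names, the tan = sin / cos decomposition, and the lower helper are illustrative assumptions, not XLA's pass or the actual CHLO legalization.

```python
import math
from typing import Callable, Dict, Set, Tuple

UnaryFn = Callable[[float], float]

# Primitive ops every backend is assumed to support.
PRIMITIVES: Dict[str, UnaryFn] = {"sine": math.sin, "cosine": math.cos}

# Dedicated implementations for backends that keep the high-level op.
KERNELS: Dict[str, UnaryFn] = {"tan": math.tan}


def tan_decomposition(x: float) -> float:
    # Illustrative primitive-only form: tan(x) = sin(x) / cos(x).
    return PRIMITIVES["sine"](x) / PRIMITIVES["cosine"](x)


# Every optionally decomposable op must come with a decomposition.
DECOMPOSITIONS: Dict[str, UnaryFn] = {"tan": tan_decomposition}


def lower(op: str, backend_ops: Set[str]) -> Tuple[str, UnaryFn]:
    """Keep `op` if the backend supports it natively, else decompose it."""
    if op in backend_ops and op in KERNELS:
        return "preserved", KERNELS[op]
    if op in DECOMPOSITIONS:
        return "decomposed", DECOMPOSITIONS[op]
    raise NotImplementedError(f"no lowering for {op}")
```

A backend that lists "tan" in backend_ops gets the dedicated kernel; any other backend still compiles the program, just via primitives. That fallback is what keeps preserving the op optional rather than mandatory.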

> Given its relatively small and stable op set, the higher-level constructs (attention, activation functions, etc.) are always fully decomposed into primitives
That said, for the most part this is true. XLA deals with primitives and finds the best ways to fuse them. In cases where we cannot achieve performance parity (fusion not performant enough or pattern matching too fragile), kernels are used or new ops are introduced (consider chlo.ragged_dot as a prime example of a kernel that migrated to an XLA-supported op).
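For intuition on why chlo.ragged_dot earned a dedicated op, here is a plain-Python sketch of basic ragged-dot semantics as used in mixture-of-experts layers. Assumed here (a simplification): lhs has shape [m, k], there is one [k, n] matrix per group, and the group sizes sum to m. Because the group boundaries are runtime data, a static-shape, primitive-only version generally has to work around them (for example by masking), which is the kind of pattern that is hard for fusion to recover.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a [p, k] and a [k, n] matrix."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def ragged_dot(lhs: Matrix, rhs: List[Matrix], group_sizes: List[int]) -> Matrix:
    """Reference semantics for a ragged dot over the rows of `lhs`.

    Rows of `lhs` are split into consecutive groups of sizes `group_sizes`;
    group i is multiplied by its own matrix `rhs[i]` (e.g. one expert each).
    """
    if len(rhs) != len(group_sizes):
        raise ValueError("need exactly one rhs matrix per group")
    if sum(group_sizes) != len(lhs):
        raise ValueError("group sizes must cover every lhs row")
    out: Matrix = []
    start = 0
    for g, size in enumerate(group_sizes):
        # Slice this group's rows and multiply by the group's own matrix.
        out.extend(matmul(lhs[start:start + size], rhs[g]))
        start += size
    return out
```

With two groups of sizes [1, 2], the first row is multiplied by rhs[0] and the remaining two rows by rhs[1]; a dedicated op or kernel can do exactly this work, while a dense primitive-only rewrite may end up computing every group for every row.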


Best,
Kevin
