StableHLO and KV-cache

Perry Gibson

Sep 17, 2024, 10:13:22 AM
to OpenXLA Discuss
I've been working with StableHLO and PyTorch, and I've been trying to reason about how the KV-caches for LLMs can be represented.

Often KV-caches are managed by the runtime, so perhaps it doesn't necessarily make sense to think of a StableHLO main func taking KV-cache as input and returning a new KV-cache.

That being said, a DNN-framework-level annotation of where the KV-caches are, and how they are configured, could be useful, and I'm unsure where else that sort of information would be stored if not somewhere in the MLIR.

As a demo, I've got two PyTorch models for single head attention, and export both to MLIR.
One does not have a KV-cache, the other uses the torchtune KV-cache module.

I have both the PyTorch models and export to StableHLO in this GitHub gist (note that I am using fixed data sizes for simplicity).

Observe how the KV-cache version takes two additional arguments and contains some scatter ops.

```mlir
    %21 = "stablehlo.scatter"(%arg12, %c, %20) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 1, 3], inserted_window_dims = [2], scatter_dims_to_operand_dims = [2], index_vector_dim = 1>, unique_indices = false}> ({
    ^bb0(%arg15: tensor<f32>, %arg16: tensor<f32>):
      stablehlo.return %arg16 : tensor<f32>
    }) : (tensor<5x1x10x8xf32>, tensor<10x1xi64>, tensor<5x1x10x8xf32>) -> tensor<5x1x10x8xf32>
```

However, note that both models in StableHLO just return a single tensor, so the KV-cache does not appear in the return value.
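To make the scatter excerpt concrete: with `inserted_window_dims = [2]` and `scatter_dims_to_operand_dims = [2]`, each of the 10 indices selects a position along dimension 2 of the 5x1x10x8 operand, and the update computation (which returns its second argument) overwrites that slot. A minimal sketch of the same semantics in JAX follows. The names `cache`, `positions`, and `new_kv` are illustrative stand-ins for `%arg12`, `%c`, and `%20`, and the index values are an assumption, since the constant's contents are not shown in the excerpt.

```python
import jax.numpy as jnp
import numpy as np

# Illustrative stand-ins for %arg12 (operand), %c (indices), %20 (updates).
cache = jnp.zeros((5, 1, 10, 8), dtype=jnp.float32)
positions = jnp.arange(10)[::-1]  # assumed values; the MLIR uses a 10x1 tensor
new_kv = jnp.arange(5 * 1 * 10 * 8, dtype=jnp.float32).reshape(5, 1, 10, 8)

# Overwrite semantics: slot positions[i] along dim 2 receives new_kv[:, :, i, :].
updated = cache.at[:, :, positions, :].set(new_kv)

# Reference implementation as an explicit loop over the 10 scatter indices.
expected = np.zeros((5, 1, 10, 8), dtype=np.float32)
for i, p in enumerate(np.asarray(positions)):
    expected[:, :, p, :] = np.asarray(new_kv)[:, :, i, :]

print(np.allclose(np.asarray(updated), expected))
```

Note that `updated` is a new value; nothing in the sketch returns it to a caller, which mirrors the situation in the exported module.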

I'm curious whether this KV-cache model is "well-formed", or whether it just so happens that the export worked.

I am actively looking for other work in StableHLO, IREE, etc. on statefulness, though I haven't found anything yet.

Stella Laurenzo

Sep 17, 2024, 12:33:49 PM
to Perry Gibson, OpenXLA Discuss
That doesn't look "well formed" for a pure StableHLO module. Since StableHLO is purely value-tensor based, mutations like scatter must be returned and managed externally somehow.

In IREE, when we do KV caches (whether traditional or page table based), it is often from Torch directly, where the generated stubs that wrap the value-SSA inner module handle the loads/stores/synchronization. The result is that the public function exported to the user takes a mutable buffer and mutations are properly tied so that everything is in-place from end to end. While in theory, this same treatment could be done by hand or systematically for StableHLO, we only have that ABI implemented for direct use from PyTorch: FX models mutation at the boundary in the way I describe so it is a fairly natural fit/transformation.

It's pretty important that KV caches be actually in-place and sparsely accessed in device memory. Even a single copy eliminates any benefit. I'm not sure how to do that outside of torch.export, which models those concerns (I think everything else basically leaves it as an exercise for whatever is wrapping it, which is hard to reason about).

Perry Gibson

Sep 17, 2024, 12:43:48 PM
to Stella Laurenzo, OpenXLA Discuss

So if I’m understanding this, then scatter is a way for us to represent an in-place update, and this isn’t explicitly returned by the main function. If that’s the case, I suppose that makes it much more clear that we don’t have any copies.

But whatever code-gen we have should recognize and handle that? And the input to the public function would include the mutable buffer that we use as KV-cache.

I wasn’t able to see stablehlo.scatter under the spec’s ops, though there is some scattered info in the GitHub history.

Stella Laurenzo

Sep 17, 2024, 1:59:56 PM
to Perry Gibson, OpenXLA Discuss
Yes, in an SSA-value oriented representation like StableHLO, such mutation ops are modeled as producing a value that can be thought of as a copy of the original with updates applied.
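A tiny JAX sketch of that value semantics, with made-up shapes and values: the indexed update yields a new array, and the original value is still observable afterwards.

```python
import jax.numpy as jnp

original = jnp.zeros((4,), dtype=jnp.float32)

# The update produces a new value; `original` itself is not modified.
updated = original.at[2].set(7.0)

print(original)  # still all zeros
print(updated)   # zeros except for index 2
```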

The middle end of the compiler has to do buffer alias tracking to ensure that such things are truly in-place and then, of course, codegen has to handle it. For common ops on such an in-place cache, this typically either involves a custom op to perform the non-contiguous computation or a compiler that can do gather-fusion. The boundary of the public function is where it gets very system specific.  What I'm describing is how IREE models it with respect to mutable buffers, based on metadata that PyTorch provides about which inputs are mutated and tied to which output chain from the raw FX function.

Han Qi

Sep 17, 2024, 3:44:20 PM
to Perry Gibson, Stella Laurenzo, OpenXLA Discuss
Echoing what Stella has said: you want to pass the KVCache explicitly, and you want either to have the model return the updated KVCache, or to have the model return the slice to update (and do the insertion by hand outside of StableHLO).

In the former case, the StableHLO graph itself does not define how it should be executed. So in the case of the XLA runtime, it can use the input KVCache buffer to hold the updated data if you explicitly mark that buffer as a "donated buffer". (In JAX this is done by passing the `donate_argnums` argument to `jax.jit`.)
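A minimal sketch of the first option in JAX, assuming a toy single-key cache of shape (max_seq, head_dim). The function `decode_step` and its arguments are hypothetical, and the matrix-vector product is only a placeholder for attention. Whether the update actually happens in place is up to the backend; donation only permits it.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical decode step: the cache is an explicit input and an explicit
# output, and argument 0 is donated so its buffer may be reused for the result.
@partial(jax.jit, donate_argnums=(0,))
def decode_step(k_cache, new_k, pos):
    # k_cache: (max_seq, head_dim), new_k: (head_dim,), pos: scalar int32
    k_cache = k_cache.at[pos].set(new_k)
    scores = k_cache @ new_k  # toy placeholder for attention over the cache
    return scores, k_cache


k_cache = jnp.zeros((16, 8), dtype=jnp.float32)
new_k = jnp.ones((8,), dtype=jnp.float32)

# A donated input must not be reused after the call, so rebind the name to
# the returned cache.
scores, k_cache = decode_step(k_cache, new_k, jnp.int32(3))
```

Rebinding `k_cache` to the returned value is the important part of the calling convention: the old array may be invalidated by the donation.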

Gunhyun Park

Sep 18, 2024, 1:14:46 PM
to Han Qi, Perry Gibson, Stella Laurenzo, OpenXLA Discuss
Re: I wasn't able to see stablehlo.scatter ... it's right here :) https://openxla.org/stablehlo/spec#scatter

Jan Pfeifer

Sep 20, 2024, 10:47:34 AM
to Gunhyun Park, Han Qi, Perry Gibson, Stella Laurenzo, OpenXLA Discuss
Google's Gemma uses dynamic_update_slice (see https://openxla.org/stablehlo/spec#dynamic_update_slice), via jax.lax, which I assume translates more or less directly to StableHLO, in https://github.com/google-deepmind/gemma/blob/2ea41628173cd88de9ab6963e628889faec86ff5/gemma/modules.py#L155. Notice that the updated cache is returned here.

Is anything done differently in PJRT plugins for the stablehlo.scatter and dynamic_update_slice operations, for the purpose of sparsely updating large embedding tables (or KV-caches)? Can one expect them to be equally efficient when paired with "donated arguments" (in the PJRT API)?
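For a contiguous write, the two ops compute the same value, which is what makes the comparison meaningful. A small JAX sketch with illustrative shapes and names is below; it only checks that the results agree and says nothing about how either op is compiled or how fast it runs.

```python
import jax
import jax.numpy as jnp

cache = jnp.zeros((2, 16, 4), dtype=jnp.float32)  # (batch, max_seq, head_dim)
update = jnp.ones((2, 3, 4), dtype=jnp.float32)   # three new positions
start = 5

# Contiguous write expressed as dynamic_update_slice.
via_dus = jax.lax.dynamic_update_slice(cache, update, (0, start, 0))

# The same write expressed as an indexed (scatter-style) update.
positions = start + jnp.arange(3)
via_scatter = cache.at[:, positions, :].set(update)

print(bool(jnp.array_equal(via_dus, via_scatter)))
```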

PS: I'm doing the exact same here, in a Go port of the Gemma v2 model.

Stella Laurenzo

Sep 20, 2024, 11:00:19 AM
to Jan Pfeifer, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
For one-shot caches like this, I would expect this to work out (I'm not intimately familiar with the internals of XLA to say for certain, but I expect it is working by design). When you get to more complicated arrangements like persistent page tables, we've required more detailed work to sever the dependency chain at the runtime boundary since in those cases, you want to be able to have multiple invocations mutating the page-table concurrently (with accounting to make sure they don't overlap vs implicit producer-consumer chaining in the runtime). Not sure how that manifests in the "donated arguments" model, but you need both in-order and hog-wild access for generality.

Jan Pfeifer

Sep 20, 2024, 11:21:49 AM
to Stella Laurenzo, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
Thanks for the follow up Stella! Let me see if I parsed your reply correctly:
  • You are referring mostly to IREE?
  • The special treatment/management of persistent page tables (due to parallelization and update boundaries) was only done for stablehlo.scatter in IREE, but not for dynamic_update_slice. Did I understand that right? Actually, does IREE support dynamic_update_slice (or is it just a wrapper around scatter)?
  • This only applies to the CPU implementation, but not necessarily to GPU (no persistent page tables there, I believe?), is that correct?
Thanks again!

Jacques Pienaar

Sep 20, 2024, 11:43:27 AM
to Jan Pfeifer, Stella Laurenzo, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
I think it's a more general discussion: representation of in place mutation. One form is outside StableHLO, one form requires analysis, and the other is explicit. This is independent of IREE (or CPU/GPU).

If I'm following, the question being discussed at the moment is: if you have two functions that both make fine-grained modifications to a buffer, but donation is at buffer granularity, is PJRT allowed to run both in parallel? If everything one does is purely sequential, then the claim is that all three ways above should handle it. But if one needs to donate the buffer in the call, that would seem to imply ownership transfer and an inability to have parallel writes, even where they may be safe for the model.

In the original question, it would seem one has to thread the KV cache as a tensor through all of the compute for XLA to enable in-place updates. That would result in serialization due to the dependency on the whole tensor, but then buffer donation should suffice.
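The serialization can be seen in a small JAX sketch, assuming a toy cache of shape (max_seq, head_dim) and hypothetical names throughout. Because the cache is threaded as a single value, each step consumes the whole tensor produced by the previous step, so the steps form a chain even though each one touches a different row.

```python
import jax
import jax.numpy as jnp


def fill_cache(cache, new_rows):
    """Write new_rows[i] into cache[i], threading the cache through every step."""

    def body(i, carried_cache):
        # Step i takes the full cache value produced by step i - 1 as input.
        return carried_cache.at[i].set(new_rows[i])

    return jax.lax.fori_loop(0, new_rows.shape[0], body, cache)


cache = jnp.zeros((16, 8), dtype=jnp.float32)
new_rows = jnp.arange(24, dtype=jnp.float32).reshape(3, 8)
filled = fill_cache(cache, new_rows)
```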

-- Jacques 

Stella Laurenzo

Sep 20, 2024, 12:04:52 PM
to Jan Pfeifer, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
On Fri, Sep 20, 2024 at 8:21 AM Jan Pfeifer <pfe...@gmail.com> wrote:
Thanks for the follow up Stella! Let me see if I parsed your reply correctly:
  • You are referring mostly to IREE?

I was trying to not respond in a system specific way but was noting that the use case you were referring to is for sequential access to a one-shot KV cache buffer, whereas there is another use case where a large device-resident page table is shared for concurrent access and KV cache "slices" are managed externally. In that case, you want everything processing the request to be able to mutate the buffer concurrently. I don't know how XLA handles this but I know we've had to do work on other systems to get around the implicit synchronization that is often induced by the typical dependency chain. I was primarily advising to watch out if approaching this case, as it is likely a different question that needs an answer. Sorry - wasn't trying to distract but just put a road sign down the highway to watch out for.
 
  • The special treatment/management of persistent page tables (due to parallelization and update boundaries) was only done for stablehlo.scatter in IREE, but not for dynamic_update_slice. Did I understand that right? Actually, does IREE support dynamic_update_slice (or is it just a wrapper around scatter)?
It's usually a chain of scatters in principle. In optimized implementations, typically, the attention mechanism operates in-place on the non-contiguous KV cache pages as well. This can be modeled in multiple ways (some implementations use custom ops for this, and in others, the compiler can fuse the gather/scatters with the compute for you). 
  • This only applies to the CPU implementation, but not necessarily to GPU (no persistent page tables there, I believe?), is that correct?
It applies anywhere you want to use a larger paged cache. Typically shows up in batch/throughput serving systems whether on GPU or CPU.
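A rough sketch of the paged arrangement, under assumed shapes and names (`pages`, `page_table`, and `PAGE_SIZE` are all hypothetical): a shared pool of fixed-size pages holds the KV data, and each request owns a list of page ids. A write touches one slot of one page, and a read gathers the request's pages. This only shows the indexing; the concurrency and synchronization concerns raised in this thread are not modeled.

```python
import jax.numpy as jnp

PAGE_SIZE, NUM_PAGES, HEAD_DIM = 4, 8, 2

# Shared pool of KV pages.
pages = jnp.zeros((NUM_PAGES, PAGE_SIZE, HEAD_DIM), dtype=jnp.float32)

# Page ids owned by one request, in logical order.
page_table = jnp.array([5, 2])


def write_token(pages, page_table, token_pos, kv):
    """Write `kv` for logical position `token_pos` into the request's pages."""
    page_id = page_table[token_pos // PAGE_SIZE]
    slot = token_pos % PAGE_SIZE
    return pages.at[page_id, slot].set(kv)


def read_tokens(pages, page_table, length):
    """Gather the request's pages and flatten them to (length, HEAD_DIM)."""
    gathered = pages[page_table]  # (num_owned_pages, PAGE_SIZE, HEAD_DIM)
    return gathered.reshape(-1, HEAD_DIM)[:length]


# Logical position 5 lands in the request's second page (id 2), slot 1.
pages = write_token(pages, page_table, 5, jnp.array([1.0, 2.0]))
kv = read_tokens(pages, page_table, 6)
```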

Jan Pfeifer

Sep 20, 2024, 12:31:26 PM
to Stella Laurenzo, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
Thanks again Stella and Jacques for the replies and clarifications.

I do appreciate the complexity of parallel updates in the generic case of sparse updates, and the need for a representation of these updates (since if done wrongly they can be costly). I imagine the problem is compounded in distributed models.

My question (sorry, now I realize I may be hijacking the thread, I probably should have started another one) was more specific to the XLA/PJRT implementation of Scatter vs DynamicUpdateSlice: will they yield more or less the same compilation (same fusions/optimizations) for an equivalent computation?

I'm assuming that in XLA/PJRT the way to make these sparse updates efficient is the "donate buffer" mechanism, which allows them to be done in place. Is that correct? Or is there another way/mechanism in XLA/PJRT to do it efficiently?


Peter Hawkins

Sep 20, 2024, 12:45:13 PM
to Jan Pfeifer, Stella Laurenzo, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
Yes. The way this is done at the moment, anyway, is that you tell XLA you'd like to alias the input and output, and you donate the buffer at the PJRT level.

In XLA's buffer assignment, it always prefers to alias scatter and dynamic-update-slice inputs and outputs, although that doesn't rule out other sources of copies.

Peter
 

Jan Pfeifer

Sep 21, 2024, 5:27:04 AM
to Peter Hawkins, Stella Laurenzo, Gunhyun Park, Han Qi, Perry Gibson, OpenXLA Discuss
Thanks for the more specific reply Peter.

You mentioned "you tell XLA you'd like to alias the input and output". How do you do that?

Currently I'm simply using the "donate buffer" mechanism (using this configuration in the PJRT API), and I assume PJRT does the aliasing automatically. Do I need to explicitly set this "alias of input to output" some other way in the API?
