--
You received this message because you are subscribed to the Google Groups "OpenXLA Discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to openxla-discu...@openxla.org.
To view this discussion on the web visit https://groups.google.com/a/openxla.org/d/msgid/openxla-discuss/09f8666a-10b4-4ce3-bc2f-6fda25172533n%40openxla.org.
For more options, visit https://groups.google.com/a/openxla.org/d/optout.
So if I’m understanding this correctly, scatter is a way for us to represent an in-place update, and the updated buffer isn’t explicitly returned by the main function. If that’s the case, I suppose that makes it much clearer that we don’t have any copies.
But whatever code-gen we have should recognize and handle that? And the input to the public function would include the mutable buffer that we use as the KV-cache.
I wasn’t able to see stablehlo.scatter under the spec’s ops, though there is some scattered info in the GitHub history.
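(For concreteness — an illustrative sketch, not code from the thread; the shapes and names are invented.) The pattern described above can be written in JAX roughly as follows: the cache comes in as an argument, the updated cache goes out as a result, and buffer donation gives the compiler permission to perform the write in place rather than copy:

```python
import jax
import jax.numpy as jnp

MAX_SEQ, D = 16, 4  # illustrative cache geometry

def write_kv(cache, new_row, pos):
    # Functionally returns a new cache with `new_row` written at row `pos`;
    # this lowers to a dynamic_update_slice. With the cache buffer donated,
    # the compiler is free to perform the write in place.
    return jax.lax.dynamic_update_slice(cache, new_row, (pos, 0))

# donate_argnums=0 donates the incoming cache buffer to the computation.
write_kv_jit = jax.jit(write_kv, donate_argnums=0)
```

Once the donated call runs, the original `cache` array is consumed; the caller threads the returned cache into the next step.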
stablehlo.scatter ... it's right here :) https://openxla.org/stablehlo/spec#scatter
dynamic_update_slice (https://openxla.org/stablehlo/spec#dynamic_update_slice) (with jax.lax, which I assume translates more or less directly to StableHLO) in https://github.com/google-deepmind/gemma/blob/2ea41628173cd88de9ab6963e628889faec86ff5/gemma/modules.py#L155 -- notice that the updated cache is returned here.
How do the stablehlo.scatter and dynamic_update_slice operations compare, for the purpose of sparsely updating large embedding tables (or KV-caches)? Can one expect them to be equally efficient when paired with the "donated arguments" (in the PJRT API)?
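(A hedged sketch for readers following along, with toy shapes and illustrative names.) The two ops cover different update patterns in JAX, and both are expressed functionally — the updated table is returned:

```python
import jax
import jax.numpy as jnp

def scatter_rows(table, rows, idx):
    # Rows at arbitrary, possibly non-contiguous indices: this indexed
    # update lowers to stablehlo.scatter.
    return table.at[idx].set(rows)

def slice_rows(table, rows, start):
    # One contiguous block of rows: this lowers to
    # stablehlo.dynamic_update_slice.
    return jax.lax.dynamic_update_slice(table, rows, (start, 0))
```

Donation (donate_argnums in jax.jit, or the corresponding donated-arguments mechanism in PJRT) applies the same way to either form; whether backends then compile both to equally efficient in-place code is exactly the open question here.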
Thanks for the follow-up Stella! Let me see if I parsed your reply correctly:
- You are referring mostly to IREE?
- The special treatment/management of persistent page tables (due to parallelization and update boundaries) was only done for stablehlo.scatter in IREE, but not for dynamic_update_slice. Did I understand that right? Actually, does IREE support dynamic_update_slice (or is it just a wrapper around scatter)?
- This only applies to the CPU implementation, but not necessarily to GPU (no persistent page tables there, I believe?), is that correct?
Thanks again Stella and Jacques for the replies and clarifications. I do appreciate the complexity of parallel updates in the generic case of sparse updates, and the need for a representation of these updates (since, done wrongly, they can be costly) -- a problem compounded on distributed models, I imagine.

My question (sorry, I now realize I may be hijacking the thread; I probably should have started another one) was more specifically about the XLA/PJRT implementation of Scatter vs DynamicUpdateSlice: will they yield more or less the same compilation (same fusions/optimizations) for an equivalent computation?

I'm assuming that in XLA/PJRT the way to make these sparse updates efficient is the "donate buffer" mechanism -- which allows them to be done in place -- is that correct? Or is there another way/mechanism in XLA/PJRT to do it efficiently?
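One way to probe the "same compilation?" question empirically (again a sketch; the op names below assume JAX's current StableHLO lowering) is to dump what each formulation lowers to:

```python
import jax
import jax.numpy as jnp

def scatter_form(table, row, i):
    # Dynamic single-row indexed update: carried as a scatter.
    return table.at[i].set(row)

def dus_form(table, row, i):
    # The same update expressed directly as a dynamic_update_slice.
    return jax.lax.dynamic_update_slice(table, row[None, :], (i, 0))

table = jnp.zeros((8, 4))
row = jnp.ones(4)

# Pre-optimization StableHLO text for each form:
scatter_ir = jax.jit(scatter_form).lower(table, row, jnp.int32(1)).as_text()
dus_ir = jax.jit(dus_form).lower(table, row, jnp.int32(1)).as_text()
```

Comparing `jax.jit(f, donate_argnums=0).lower(...).compile().as_text()` for both forms would then show whether XLA ends up with the same fusions and whether the donated input is aliased to the output.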