
Include token as part of the input/output tuple in all-gather and reduce-scatter


Jeff Huynh

Dec 11, 2023, 3:01:06 PM
to OpenXLA Discuss
Currently, reduce-scatter/all-gather in XLA use a different token-propagation mechanism than all-reduce. The token is used to ensure ordering between collective communication (CC) ops. Whereas all-reduce passes the token as an XLA token type through the XLA all-reduce op, reduce-scatter/all-gather use a 0.0 float token that is added to the input and then multiplied with the output, which keeps the token from being eliminated during dead-code elimination (DCE) (correct me if I am wrong here). These extra adds and multiplies are unnecessary computation and slow things down.
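To make the overhead concrete, here is a toy model in plain C++ of the add/multiply workaround, assuming the pseudo-token is a float that always holds 0.0. It is not XLA or torch-xla code, and the helper names `GetInputWithToken` and `GetNewToken` are hypothetical. Adding the token to the input leaves the data unchanged but costs one add per element; multiplying the token by one output element leaves it at 0.0 but makes the next token depend on this collective's result.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model only: plain C++ arithmetic, not XLA or torch-xla code.
// Assumption: the pseudo-token is an ordinary float that always holds 0.0f.

// Hypothetical helper: fold the token into a collective's input.
// Adding 0.0f leaves every element unchanged, but the result now depends on
// the token, at the cost of one extra add per element.
std::vector<float> GetInputWithToken(const std::vector<float>& input,
                                     float token) {
  std::vector<float> out(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    out[i] = input[i] + token;
  }
  return out;
}

// Hypothetical helper: derive the next token from a collective's output.
// Multiplying the zero token by one output element keeps it at 0.0f for finite
// outputs, but the new token now depends on this collective's result, so
// feeding it into the next collective chains the two by data dependency.
float GetNewToken(float token, const std::vector<float>& output) {
  assert(!output.empty());
  return token * output[0];
}
```

For finite data both helpers are numerical no-ops, which is the point of the trick and also why it is pure overhead. One caveat visible even in this sketch: if the chosen output element is infinite or NaN, IEEE arithmetic makes the product NaN rather than 0.0.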

This OpenXLA PR, https://github.com/openxla/xla/pull/7677, attempts to address the inconsistency. The proposal is to add a chained token that matches all-reduce's use of the token data type.

Rahul Joshi

Dec 11, 2023, 3:10:59 PM
to Jeff Huynh, OpenXLA Discuss
Hi Jeff,

There was some discussion about this in an earlier version of the PR. As far as I am aware, OpenXLA currently does not support ordering any collectives using tokens. Can you point out where we allow chaining all-reduces using tokens?

Also, what is the use case for ordering collectives? It seems to be trying to enforce a particular schedule for collectives, but the motivation is not clear. Knowing it would help drive the design discussion here.

Thanks
Rahul



Jeff Huynh

Dec 13, 2023, 4:45:39 PM
to OpenXLA Discuss, Rahul Joshi, Jeff Huynh
Thanks Rahul,

While XLA doesn't explicitly support a token in all-reduce yet, the comment below states that supporting token-typed data is a future goal:

      if (!inst->shape().IsArray()) {
        // We currently do not change tuple-shaped all-reduce.
        // Until XLA will support Token fed AllReduce(), the PyTorch client code
        // uses a fake data token (constant) which relies on this pass to not
        // optimize out (being fed within a tuple input).
        continue;
      }
Furthermore, this if statement exists to accommodate the all-reduce use case in PyTorch-XLA, which includes the token value as part of the input/output tuple/list:

  // TODO: We use pseudo-tokens ATM, which are real values. This need to be
  // switched to use the real XLA Token once support has been added to XLA
  // AllReduce().
  xla::XlaOp chained_token = token;
  ReduceContext redux = GetReduceContext(operands);
  std::vector<xla::XlaOp> result(operands.size());
  for (auto& type_ctx : redux.contexts) {
    xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
    type_ctx.second.ops.push_back(token_op);
    type_ctx.second.operand_shapes.push_back(
        ShapeHelper::ShapeOfXlaOp(token_op));
    ...
  }

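The pseudo-token in the torch-xla snippet above can be modeled without XLA: the token is appended as one more element of each replica's all-reduce operand tuple, so it is reduced across replicas alongside the real data. The following plain C++ sketch assumes a sum reduction and scalar tuple elements; `Tuple`, `AppendToken`, and `AllReduceSum` are hypothetical names, not torch-xla APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model, not torch-xla code. Assumptions: each tuple element is a scalar,
// the reduction is a sum, and the pseudo-token is a float that holds 0.0f.
using Tuple = std::vector<float>;

// Hypothetical helper: append the pseudo-token as the last tuple element,
// mirroring how the quoted snippet pushes token_op onto the operand list.
Tuple AppendToken(Tuple operands, float token) {
  operands.push_back(token);
  return operands;
}

// Element-wise sum over all replicas' tuples. A real all-reduce delivers this
// result to every replica; the toy just returns one copy.
Tuple AllReduceSum(const std::vector<Tuple>& per_replica) {
  assert(!per_replica.empty());
  Tuple result(per_replica.front().size(), 0.0f);
  for (const Tuple& t : per_replica) {
    assert(t.size() == result.size());
    for (std::size_t i = 0; i < t.size(); ++i) {
      result[i] += t[i];
    }
  }
  return result;
}
```

Because every replica contributes 0.0, the token element of the result stays 0.0, and because it is an output of the same all-reduce, anything that consumes it keeps that all-reduce alive. This is also why the quoted XLA pass has to leave tuple-shaped all-reduce alone.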
Jeff Huynh

Dec 15, 2023, 12:49:48 AM
to OpenXLA Discuss, Jeff Huynh, Rahul Joshi
Ordering is needed so that, if the two data-parallel workers happen to get different graphs, compilation still preserves the order of the CC operations and no deadlock caused by mismatched CC operations occurs at runtime.
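The deadlock scenario can be sketched as a toy simulation, under the simplifying assumption that every collective is blocking and completes only when both workers issue the same collective as their next one. This is not how any particular XLA runtime is implemented; `FirstStall` is a hypothetical helper that reports the step at which the two workers' collective sequences diverge.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy model (assumption, not any real XLA runtime): each collective is
// blocking and completes only when both workers issue that same collective
// as their next one. Returns the index of the first step at which the workers
// stall, or std::nullopt if every collective completes.
std::optional<std::size_t> FirstStall(const std::vector<std::string>& worker_a,
                                      const std::vector<std::string>& worker_b) {
  const std::size_t steps = std::max(worker_a.size(), worker_b.size());
  for (std::size_t i = 0; i < steps; ++i) {
    // A worker with no collective left never joins its peer's pending one.
    if (i >= worker_a.size() || i >= worker_b.size()) return i;
    // Each worker waits inside a different collective: neither can finish.
    if (worker_a[i] != worker_b[i]) return i;
  }
  return std::nullopt;
}
```

If worker B's compiled graph schedules the reduce-scatter before the all-gather, both workers stall at step 0, each waiting for a peer that is blocked in a different collective. The intent of a token chain is that each collective consumes the token produced by the previous one, so every worker's compiled graph has to keep the same relative order.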

Jeff Huynh

Feb 13, 2025, 11:57:58 AM
to OpenXLA Discuss, Jeff Huynh, Rahul Joshi
Hi Rahul,

I would like to revisit this issue: all-reduce supports tuple operands but reduce-scatter/all-gather do not, so token management is inconsistent between these CC ops. For example, in torch-xla all-reduce we are able to pass the token inside the all-reduce op:

%all-reduce.82 = (bf16[1,2048,8192]{2,1,0}, bf16[]) all-reduce(bf16[1,2048,8192]{2,1,0} %multiply.74, bf16[] %p7.66), replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15},{16,17,18,19,20,21,22,23},{24,25,26,27,28,29,30,31}}, to_apply=%AddComputation.78

However, we cannot do the same for all-gather because we are missing https://github.com/openxla/xla/pull/7677:

%all-gather.135 = bf16[2048,1,8192]{2,1,0} all-gather(bf16[256,1,8192]{2,1,0} %add.130), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15},{16,17,18,19,20,21,22,23},{24,25,26,27,28,29,30,31}}, dimensions={0}
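As a quick check on the all-gather HLO above, its output shape follows from the replica groups: each of the 4 groups has 8 participants, and operands are concatenated along `dimensions={0}`, so the leading dimension grows from 256 to 256 * 8 = 2048. A small C++ sketch of that shape rule, with hypothetical helper names and assuming all replica groups have the same size (as they do here):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;

// All-gather concatenates each participant's operand along the gathered
// dimension, so that dimension grows by the group size; the rest are unchanged.
Shape AllGatherOutputShape(Shape operand_shape, std::size_t gather_dim,
                           std::int64_t group_size) {
  assert(gather_dim < operand_shape.size());
  operand_shape[gather_dim] *= group_size;
  return operand_shape;
}

// Participants per group, assuming every replica group has the same size.
std::int64_t GroupSize(const std::vector<std::vector<std::int64_t>>& replica_groups) {
  assert(!replica_groups.empty());
  return static_cast<std::int64_t>(replica_groups.front().size());
}
```

Feeding the operand shape {256, 1, 8192}, dimension 0, and a group size of 8 gives {2048, 1, 8192}, matching bf16[2048,1,8192] in the HLO.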

Will you be able to reopen this case?

Jeff
