Issues with HLO Broadcast when broadcast_dimensions are in decreasing order


Alexander Pivovarov

Jun 17, 2024, 7:05:53 PM
to OpenXLA Discuss
Hello Everyone

I ran into the following issues with the HLO Broadcast operation when broadcast_dimensions are in decreasing order (swapped):

Issue 1 - JAX does not support it; broadcast_dimensions must be strictly increasing:

import jax
import jax.numpy as jnp
x=jnp.arange(9*12).reshape((9,12))
y=jax.lax.broadcast_in_dim(x, (1,2,12,9), (3,2))

TypeError: broadcast_in_dim broadcast_dimensions must be strictly increasing; got broadcast_dimensions (3, 2)
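A possible workaround, assuming the intent of dims (3, 2) is "input dim 0 goes to output dim 3, input dim 1 goes to output dim 2": transpose the operand first, so the broadcast_dimensions passed to JAX can be strictly increasing. This is only a sketch of one way to express that mapping, not an officially recommended pattern.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(9 * 12).reshape((9, 12))

# Assumed intent of broadcast_dimensions=(3, 2): input dim 0 (size 9) -> output dim 3,
# input dim 1 (size 12) -> output dim 2. Swapping the operand axes first lets us
# express the same mapping with strictly increasing dims (2, 3), which JAX accepts.
y = jax.lax.broadcast_in_dim(jnp.transpose(x), (1, 2, 12, 9), (2, 3))

print(y.shape)  # (1, 2, 12, 9)
```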


Issue 2 - run_hlo_module validation fails because the CPU and Interpreter outputs differ:
the Interpreter output is transposed, while the CPU output is just reshaped.

ENTRY %main.3 {
  %Arg_0.1 = s32[9,12]{1,0} parameter(0)
  ROOT %broadcast.2 = s32[1,2,12,9]{3,2,1,0} broadcast(%Arg_0.1), dimensions={3,2}
}

./bazel-bin/xla/tools/run_hlo_module \
  --input_format=hlo \
  --platform=CPU a.hlotxt \
  --print_literals=true


1/1 runs failed.
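The disagreement can be reproduced with a small NumPy reference model. It assumes the mapping "operand dimension i goes to output dimension dimensions[i]"; under that reading, dimensions={3,2} yields a transposed result, which is what the Interpreter produced. The "just reshaped" result is modeled as a row-major reshape of the operand to (12, 9), matching the CPU output described above; that is an assumption about the CPU behavior, not something taken from its code.

```python
import numpy as np

def broadcast_in_dim(x, shape, dims):
    """Reference model: operand dim i maps to output dim dims[i]; other output dims are broadcast."""
    x = np.asarray(x)
    order = np.argsort(dims)  # operand axes, ordered by their target output dim
    y = np.transpose(x, order)
    # Insert size-1 axes at every output dim not named in dims, then broadcast.
    y = np.expand_dims(y, tuple(d for d in range(len(shape)) if d not in dims))
    return np.broadcast_to(y, shape)

x = np.arange(9 * 12).reshape(9, 12)
mapped = broadcast_in_dim(x, (1, 2, 12, 9), (3, 2))          # dim-mapping reading
reshaped = np.broadcast_to(x.reshape(12, 9), (1, 2, 12, 9))  # "just reshaped" reading

print(np.array_equal(mapped[0, 0], x.T))  # True: output holds the transposed input
print(np.array_equal(mapped, reshaped))   # False: the two readings disagree
```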


Questions:
1. Is it a valid hlo when broadcast_dimensions are in decreasing order?
2. Should we add additional validation to xla/service/hlo_verifier.cc to check that broadcast->dimensions() are in increasing order?
3. If decreasing order is actually allowed, what should the output be: the transposed input or just the reshaped input?

Benjamin Chetioui

Jun 18, 2024, 4:12:21 AM
to Alexander Pivovarov, OpenXLA Discuss
Hi Alexander,

1. Is it a valid hlo when broadcast_dimensions are in decreasing order?
It shouldn't be! As you can see on the page describing the semantics of broadcasting, XLA expects that "The order of broadcast dimensions must be strictly increasing.", matching the semantics of JAX.

2. Should we add additional validation to xla/service/hlo_verifier.cc to check that broadcast->dimensions() are in increasing order?
Seems like the right course of action to me.
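For illustration, the proposed check boils down to a strictly-increasing test on the dimensions list. This sketch is in Python rather than the C++ of xla/service/hlo_verifier.cc; the function name is hypothetical, and the error text only mimics the JAX message above.

```python
def check_broadcast_dimensions(dims):
    """Hypothetical verifier check: reject broadcast dims that are not strictly increasing."""
    for prev, cur in zip(dims, dims[1:]):
        if prev >= cur:  # also rejects duplicates such as [2, 2]
            raise ValueError(
                f"broadcast dimensions must be strictly increasing; got {tuple(dims)}"
            )

check_broadcast_dimensions([2, 3])     # passes
# check_broadcast_dimensions([3, 2])   # would raise ValueError
```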

Thanks!
Benjamin


Alexander Pivovarov

Jun 18, 2024, 11:52:50 AM
to Benjamin Chetioui, OpenXLA Discuss
Thank you, Benjamin!

One additional question:

xla/mlir_hlo/mhlo/IR/hlo_ops.cc has a pattern, eliminateBroadcastInDimTranspose, that folds transpose(broadcast_in_dim(X)) => broadcast_in_dim(X).
This pattern is registered in getCanonicalizationPatterns, so it runs in the MLIR Canonicalizer pass.

Example:
Input mhlo:
%124 = "mhlo.reshape"(%123) : (tensor<1x1x96x128xf32>) -> tensor<96x128xf32>
%125 = "mhlo.broadcast_in_dim"(%124) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<96x128xf32>) -> tensor<1x2x96x128xf32>
%126 = "mhlo.transpose"(%125) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x2x96x128xf32>) -> tensor<1x2x128x96xf32>
%127 = "mhlo.reshape"(%126) : (tensor<1x2x128x96xf32>) -> tensor<2x128x96xf32>
 
Output mhlo:
%127 = "mhlo.reshape"(%126) : (tensor<1x1x96x128xf32>) -> tensor<96x128xf32>
%128 = "mhlo.broadcast_in_dim"(%127) {broadcast_dimensions = dense<[3, 2]> : tensor<2xi64>} : (tensor<96x128xf32>) -> tensor<1x2x128x96xf32>
%129 = "mhlo.reshape"(%128) : (tensor<1x2x128x96xf32>) -> tensor<2x128x96xf32>

Once applied, this pattern removes the transpose op and produces a broadcast_in_dim with decreasing broadcast_dimensions.
What should we do with this pattern, then? Remove it?
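The dims in the example output follow from composing the two index maps: if transpose output dim j reads input dim perm[j], then broadcast output dim d ends up at position inv_perm[d], so the folded dims are inv_perm[dims[i]]. The sketch below checks this numerically with a NumPy reference model of broadcast_in_dim (assuming operand dim i maps to output dim dims[i]); the helper names are hypothetical and this is not the MHLO implementation.

```python
import numpy as np

def broadcast_in_dim(x, shape, dims):
    """Reference model: operand dim i maps to output dim dims[i]; other dims are broadcast."""
    y = np.transpose(x, np.argsort(dims))
    y = np.expand_dims(y, tuple(d for d in range(len(shape)) if d not in dims))
    return np.broadcast_to(y, shape)

def fold_transpose_into_broadcast(dims, perm):
    """Hypothetical helper: transpose(broadcast_in_dim(x, dims), perm) == broadcast_in_dim(x, new_dims)."""
    inv_perm = [0] * len(perm)
    for j, p in enumerate(perm):
        inv_perm[p] = j  # transpose output dim j reads input dim p
    return tuple(inv_perm[d] for d in dims)

new_dims = fold_transpose_into_broadcast((2, 3), (0, 1, 3, 2))
print(new_dims)  # (3, 2): the decreasing dims from the output mhlo

x = np.arange(96 * 128).reshape(96, 128)
before = np.transpose(broadcast_in_dim(x, (1, 2, 96, 128), (2, 3)), (0, 1, 3, 2))
after = broadcast_in_dim(x, (1, 2, 128, 96), new_dims)
print(np.array_equal(before, after))  # True
```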

Alexander Pivovarov

Jun 18, 2024, 11:53:36 PM
to Benjamin Chetioui, OpenXLA Discuss
hlo_verifier's HandleBroadcast already has a check for broadcast dimension order, gated by VerifyBroadcastDimensionsOrder, but it is false by default.

XLA/HLO also has a pass called broadcast_canonicalizer (BroadcastCanonicalizer).
It ensures that the dimensions of every broadcast operation are sorted in ascending order:
it sorts the broadcast dims, then inserts a transpose after the broadcast to get the original shape back.
The pass is used in both the CPU and GPU compilers.
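The sort-then-transpose rewrite can be modeled the same way. Assuming operand dim i maps to output dim dims[i], broadcasting with the sorted dims into a permuted shape and then transposing with perm (where perm[dims[i]] = sorted_dims[i]) reproduces the original result. The helper names are hypothetical, and this only models the rewrite as described here, not the pass's actual code.

```python
import numpy as np

def broadcast_in_dim(x, shape, dims):
    """Reference model: operand dim i maps to output dim dims[i]; other dims are broadcast."""
    y = np.transpose(x, np.argsort(dims))
    y = np.expand_dims(y, tuple(d for d in range(len(shape)) if d not in dims))
    return np.broadcast_to(y, shape)

def canonicalize_broadcast(shape, dims):
    """Hypothetical helper returning (new_shape, sorted_dims, perm) such that
    transpose(broadcast_in_dim(x, new_shape, sorted_dims), perm) == broadcast_in_dim(x, shape, dims)."""
    sorted_dims = tuple(sorted(dims))
    perm = list(range(len(shape)))
    for d, s in zip(dims, sorted_dims):
        perm[d] = s  # output dim d must read dim s of the sorted broadcast
    new_shape = [0] * len(shape)
    for j, p in enumerate(perm):
        new_shape[p] = shape[j]  # transpose output dim j has the size of input dim perm[j]
    return tuple(new_shape), sorted_dims, tuple(perm)

x = np.arange(9 * 12).reshape(9, 12)
new_shape, sorted_dims, perm = canonicalize_broadcast((1, 2, 12, 9), (3, 2))
print(new_shape, sorted_dims, perm)  # (1, 2, 9, 12) (2, 3) (0, 1, 3, 2)
canonical = np.transpose(broadcast_in_dim(x, new_shape, sorted_dims), perm)
print(np.array_equal(canonical, broadcast_in_dim(x, (1, 2, 12, 9), (3, 2))))  # True
```

This is the inverse of the MHLO fold: the canonicalizer introduces the transpose that eliminateBroadcastInDimTranspose removes.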

However, the MHLO eliminateBroadcastInDimTranspose pattern does the opposite: it removes the transpose op and leaves the broadcast dimensions unordered.
It seems that what is canonical for XLA/HLO is not canonical for MHLO, and vice versa…

Kevin Gleason

Jun 20, 2024, 12:51:33 PM
to OpenXLA Discuss, Alexander Pivovarov, Benjamin Chetioui, liuyuanqi...@bytedance.com
cc liuyuanqi...@bytedance.com, who authored the MHLO canonicalization in https://github.com/tensorflow/tensorflow/pull/53532

I'm interested in whether broadcast being able to represent a transpose is useful in the compilation pipeline at ByteDance, i.e., does this pattern map well to some lower-level dialect?


Best,
Kevin