Adding 'mode' as an input operand to the scatter/gather ops

Pushkar Ratnalikar

Aug 14, 2024, 3:35:00 AM
to OpenXLA Discuss
Dear StableHLO Community,

The current StableHLO spec for the scatter op (https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter) does not have "mode" as an input operand.

Frameworks like JAX let users specify a GatherScatterMode (https://github.com/google/jax/blob/main/jax/_src/lax/slicing.py#L247C7-L247C24), which determines the runtime out-of-bounds (OOB) behavior of the op.
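For readers less familiar with the modes, here is a minimal pure-Python sketch of the OOB behavior they select, assuming a 1-D, non-empty operand and a single scalar index (so each window is one element). It follows the JAX `GatherScatterMode` documentation: `CLIP` clamps the index into range, `FILL_OR_DROP` drops an out-of-bounds scatter update and returns a fill value for an out-of-bounds gather, and `PROMISE_IN_BOUNDS` means the caller guarantees valid indices. The `scatter_1d`/`gather_1d` helpers and the local enum are hypothetical illustrations, not JAX or StableHLO APIs, and they ignore multi-element windows.

```python
from enum import Enum


class GatherScatterMode(Enum):
    # Local stand-in mirroring the member names of jax.lax.GatherScatterMode;
    # the integer values here are arbitrary.
    CLIP = 0
    FILL_OR_DROP = 1
    PROMISE_IN_BOUNDS = 2


def scatter_1d(operand, index, update, mode):
    """Return a copy of operand with operand[index] replaced by update.

    Replacement matches the example's combiner, which returns %arg2 (the update).
    Assumes a non-empty operand.
    """
    result = list(operand)
    n = len(result)
    if mode is GatherScatterMode.CLIP:
        index = min(max(index, 0), n - 1)  # clamp into [0, n - 1]
    elif mode is GatherScatterMode.FILL_OR_DROP:
        if not 0 <= index < n:
            return result  # out-of-bounds update is discarded
    else:
        # PROMISE_IN_BOUNDS: the caller guarantees validity; the sketch just checks it.
        assert 0 <= index < n, "index was promised to be in bounds"
    result[index] = update
    return result


def gather_1d(operand, index, mode, fill_value=0.0):
    """Return operand[index] under the given OOB mode (non-empty operand assumed)."""
    n = len(operand)
    if mode is GatherScatterMode.CLIP:
        return operand[min(max(index, 0), n - 1)]
    if mode is GatherScatterMode.FILL_OR_DROP:
        return operand[index] if 0 <= index < n else fill_value
    assert 0 <= index < n, "index was promised to be in bounds"
    return operand[index]
```

With the index used in the example below (0 into a 4-element operand) all three modes agree; they only diverge once an index is out of range, which is exactly when the consumer of the lowered program needs to know the mode.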

Currently, when a JAX model is lowered to StableHLO, the mode does not appear on the stablehlo.scatter op itself; it survives only as text inside the location metadata (loc).

Example:
==========================================================================
#loc1 = loc("x")
#loc2 = loc("/home/ubuntu/tst/tst.py":43:0)
#loc3 = loc("/home/ubuntu/tst/tst.py":49:0)
#loc4 = loc("g"(#loc2))
#loc5 = loc("<module>"(#loc3))
#loc6 = loc(callsite(#loc4 at #loc5))
#loc8 = loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6))
module @jit_g attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<4xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{devices=[2]<=[2]}"} loc("x")) -> (tensor<4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc7)
    %2 = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc)
    %3 = "stablehlo.scatter"(%arg0, %1, %2) ({
    ^bb0(%arg1: tensor<f32> loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6)), %arg2: tensor<f32> loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6))):
      stablehlo.return %arg2 : tensor<f32> loc(#loc8)
    }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<4xf32>, tensor<1xi32>, tensor<f32>) -> tensor<4xf32> loc(#loc8)
    return %3 : tensor<4xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc7 = loc("jit(g)/jit(main)/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc6))

==========================================================================
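In the output above, `mode=GatherScatterMode.FILL_OR_DROP` appears only inside the `loc(...)` strings (`#loc8` and the block-argument locations), while the `stablehlo.scatter` attribute dictionary carries just `indices_are_sorted`, `scatter_dimension_numbers` and `unique_indices`. A consumer that needs the mode today would have to scrape it from debug text, roughly as sketched below. `mode_from_loc` is a hypothetical helper, and the regex assumes JAX keeps its current `mode=GatherScatterMode.<NAME>` formatting in location names.

```python
import re
from typing import Optional

# Matches the JAX-generated fragment, e.g. "mode=GatherScatterMode.FILL_OR_DROP".
_MODE_RE = re.compile(r"mode=GatherScatterMode\.([A-Z_]+)")


def mode_from_loc(loc_text: str) -> Optional[str]:
    """Hypothetical helper: recover the mode name from a loc(...) string, if present."""
    match = _MODE_RE.search(loc_text)
    return match.group(1) if match else None
```

This is fragile: as soon as locations are dropped or rewritten (for example by MLIR's `-strip-debuginfo` pass, which replaces them with `loc(unknown)`), the mode silently disappears.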

Can we add mode as an enum-typed input to the scatter op?

Thanks,
Pushkar