Dear StableHLO Community,
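Currently, when a JAX model is lowered to StableHLO, the scatter `mode` (e.g. `GatherScatterMode.FILL_OR_DROP`) is not represented on the `stablehlo.scatter` op itself; it survives only as part of the location metadata string.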
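Example (note `mode=GatherScatterMode.FILL_OR_DROP` inside `#loc8`, while the op's attribute dictionary has no `mode`):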
==========================================================================
// Location aliases emitted by the JAX lowering.
// Named locations ("x", "g", "<module>") wrap file:line:col locations and are
// combined into a call-site location below.
#loc1 = loc("x")
#loc2 = loc("/home/ubuntu/tst/tst.py":43:0)
#loc3 = loc("/home/ubuntu/tst/tst.py":49:0)
#loc4 = loc("g"(#loc2))
#loc5 = loc("<module>"(#loc3))
// Call site: `g` (tst.py:43) as called from `<module>` (tst.py:49).
#loc6 = loc(callsite(#loc4 at #loc5))
// Location attached to the scatter op. The scatter parameters, including
// `mode=GatherScatterMode.FILL_OR_DROP`, appear here only as part of the
// location name string, not as attributes of stablehlo.scatter.
#loc8 = loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6))
// Lowered module for `jit(g)` with mhlo.num_partitions = 2; the single
// argument carries the sharding annotation {devices=[2]<=[2]}.
module @jit_g attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{devices=[2]<=[2]}"} loc("x")) -> (tensor<4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// Scatter index: the scalar 0 broadcast to a 1-element index tensor.
%0 = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<i32>) -> tensor<1xi32> loc(#loc7)
// Scalar update value 1.0.
%2 = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc)
// Scatter %2 into %arg0 at index %1.
%3 = "stablehlo.scatter"(%arg0, %1, %2) ({
// Update region: returns %arg2 (the update value), i.e. the existing element
// is overwritten.
^bb0(%arg1: tensor<f32> loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6)), %arg2: tensor<f32> loc("jit(g)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc6))):
stablehlo.return %arg2 : tensor<f32> loc(#loc8)
// Note: the attribute dictionary below has no `mode` entry; FILL_OR_DROP is
// visible only inside the location string referenced by loc(#loc8).
}) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<4xf32>, tensor<1xi32>, tensor<f32>) -> tensor<4xf32> loc(#loc8)
return %3 : tensor<4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
// Aliases that are referenced inside the module but defined after it:
// #loc marks ops with no source location; #loc7 is the broadcast_in_dim location.
#loc = loc(unknown)
#loc7 = loc("jit(g)/jit(main)/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc6))
==========================================================================
Could we add `mode` to the `stablehlo.scatter` op as an enum-typed input, so that it is no longer carried only in the location metadata?
Thanks,
Pushkar