Limits to XLA variable de-capture


Joel Berkeley

Jun 14, 2026, 9:28:07 PM
to OpenXLA Discuss
Hi,

XLA can't do variable capture for higher-order functions, but it can lift values into the necessary scope when doing so is valid. For example, in the pseudo-code
```
main () {
   %0 tensor<1.0>
   %1 tensor<2.0>
   %2 map(%0) (x => x + %0) %1
   return %2
}
```
it can move %0 into the function `x => x + %0`, removing the variable capture. What are the limits on this? I've found that the following code is rejected, even though the same lifting looks straightforward to apply to it:
```
module @root {
  func.func @main() -> tensor<f64> {
    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %cst_1 = stablehlo.constant dense<3.000000e-01> : tensor<1xf64>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %0 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f64>) -> tensor<f64>
    %1 = stablehlo.reduce(%cst_1 init: %0) applies stablehlo.add across dimensions = [0] : (tensor<1xf64>, tensor<f64>) -> tensor<f64>
    %2 = stablehlo.while(%iterArg = %cst) : tensor<f64>
    cond {
      %3 = stablehlo.compare GT, %iterArg, %cst_0 : (tensor<f64>, tensor<f64>) -> tensor<i1>
      stablehlo.return %3 : tensor<i1>
    } do {
      %3 = "stablehlo.map"(%iterArg) <{dimensions = array<i64>}> ({
      ^bb0(%arg0: tensor<f64>):
        %4 = stablehlo.subtract %arg0, %1 : tensor<f64>
        stablehlo.return %4 : tensor<f64>
      }) : (tensor<f64>) -> tensor<f64>
      stablehlo.return %3 : tensor<f64>
    }
    return %2 : tensor<f64>
  }
}
```
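
To make the rewrite in question concrete, here is a minimal sketch of de-capture in plain Python. It only illustrates the idea from the first pseudo-code example and is not XLA or StableHLO API; the function names `map_with_capture`, `lifted_body`, and `map_lifted` are made up for illustration.

```python
from typing import List


def map_with_capture(xs: List[float], captured: float) -> List[float]:
    # The mapped function is a closure: it reads `captured` from the
    # enclosing scope, like `x => x + %0` in the pseudo-code.
    f = lambda x: x + captured
    return [f(x) for x in xs]


def lifted_body(x: float, captured: float) -> float:
    # After de-capture the mapped function is closed: every value it
    # uses arrives as an explicit parameter.
    return x + captured


def map_lifted(xs: List[float], captured: float) -> List[float]:
    # The caller now threads the formerly captured value through explicitly.
    return [lifted_body(x, captured) for x in xs]


if __name__ == "__main__":
    # Both forms compute the same result for the same inputs.
    print(map_with_capture([2.0], 1.0))
    print(map_lifted([2.0], 1.0))
```

In the rejected module, the analogous step would presumably be to thread `%1` from `@main` through the `while` body and into the `stablehlo.map` computation, which is two levels of nesting rather than one.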

Joel Berkeley

Jun 14, 2026, 9:38:30 PM
to OpenXLA Discuss, Joel Berkeley
This is at XLA commit 65f49e0e74ffdbfc9f475dec50607f35d368bd32.