Hi everyone,
We are working on a pull request to add experimental ragged ops to HLO, using optionally decomposable chlo.ragged_xyz ops that map to the new opcodes. This will let us explore ragged concepts in hardware-independent ways, starting with ragged_dot.
We'll focus our experiments on XLA:GPU and XLA:TPU. The CHLO fallback implementation will decompose the ragged operation into a regular dot product over padded, dense tensors.
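For intuition, here is a minimal sketch of what such a dense decomposition can look like, written against the documented semantics of lax.ragged_dot (m lhs rows split into g contiguous groups, each contracted with its group's rhs slice). The helper name ragged_dot_reference and the mask-based formulation are our own illustration; the actual CHLO decomposition in the PR may pad each group out rather than mask it.

```python
import jax.numpy as jnp

def ragged_dot_reference(lhs, rhs, group_sizes):
    """Illustrative dense decomposition of a ragged dot.

    lhs:         [m, k]    rows grouped contiguously by group_sizes
    rhs:         [g, k, n] one contracting matrix per group
    group_sizes: [g]       number of lhs rows in each group (sums to m)
    returns:     [m, n]
    """
    m, _ = lhs.shape
    # Exclusive cumulative sum: start row of each group.
    starts = jnp.cumsum(group_sizes) - group_sizes
    row_ids = jnp.arange(m)
    # mask[i, j] is True when lhs row i belongs to group j.
    mask = (row_ids[:, None] >= starts[None, :]) & (
        row_ids[:, None] < (starts + group_sizes)[None, :]
    )
    # Dense einsum over all groups; the mask zeroes out contributions
    # from rows outside each group, matching the ragged contraction.
    return jnp.einsum("mk,gkn,mg->mn", lhs, rhs, mask.astype(lhs.dtype))
```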
What to Expect:
Performance: Current JAX users of lax.ragged_dot shouldn't see any performance surprises (a minimal usage example follows this list).
Optimization Potential: This allows us to explore compiler optimizations for ragged operations, and the API for the op may change as we refine this exploration. If the experiment proves successful, we will aim to standardize the op.
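For reference, the existing JAX entry point looks like this; shapes and values below are illustrative only:

```python
import jax
import jax.numpy as jnp

# 8 lhs rows split into 3 contiguous groups of sizes [3, 2, 3];
# each group is contracted against the matching rhs slice.
lhs = jnp.ones((8, 4), dtype=jnp.float32)
rhs = jnp.ones((3, 4, 5), dtype=jnp.float32)
group_sizes = jnp.array([3, 2, 3], dtype=jnp.int32)

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)  # shape (8, 5)
```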
As we gain more confidence and gather results, we'll share updates and further details with the community.
Cheers,