Optimization to select(pred, 1, 0) -> convert(pred). Is it beneficial?

Alexander Pivovarov

Feb 13, 2024, 9:44:18 PM
to OpenXLA Discuss

Hi Everyone,

I'd like to check whether adding the following optimization to the HLO algebraic simplifier (algebraic_simplifier.cc) would be beneficial:

select(pred, 1, 0) -> convert(pred)                    (1)

select(pred, 1, 0) -> broadcast(convert(pred))         (2)

Optimization:

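If the on_true array satisfies isAll(1) and the on_false array satisfies isAll(0), then we can:

  1. replace select(pred, 1, 0) with convert(pred), converting to on_true's element type (if pred.shape == on_true.shape);
  2. otherwise, replace select(pred, 1, 0) with broadcast(convert(pred)). Since HLO select requires pred to be either a scalar or the same shape as on_true, this is the scalar-pred case.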
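A minimal JAX sketch checking that both rewrites preserve values. It assumes a float32 result type, and the helper names (`select_form`, `rewrite_1`, `rewrite_2`) are illustrative, not part of the proposal. Converting a boolean maps True to 1 and False to 0, which is what makes rewrite (1) exact; rewrite (2) covers the scalar-pred case.

```python
import jax.numpy as jnp
from jax import lax

# Assumed result element type for this sketch.
DTYPE = jnp.float32


def select_form(pred, shape):
    """Original pattern: select(pred, 1, 0) with splat constants of `shape`."""
    ones = jnp.ones(shape, dtype=DTYPE)
    zeros = jnp.zeros(shape, dtype=DTYPE)
    # lax.select accepts a pred that is a scalar or has the same shape as the cases.
    return lax.select(pred, ones, zeros)


def rewrite_1(pred):
    """Rewrite (1): convert(pred), valid when pred.shape == on_true.shape."""
    return lax.convert_element_type(pred, DTYPE)


def rewrite_2(pred, shape):
    """Rewrite (2): broadcast(convert(pred)) for a scalar pred."""
    return jnp.broadcast_to(lax.convert_element_type(pred, DTYPE), shape)


pred = jnp.array([[True, False], [False, True]])
print(select_form(pred, pred.shape))
print(rewrite_1(pred))

scalar_pred = jnp.array(True)
print(select_form(scalar_pred, (2, 3)))
print(rewrite_2(scalar_pred, (2, 3)))
```

This only shows the rewrites are value-equivalent. It says nothing about whether they are faster.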

The assumption is that a convert (or a convert plus a broadcast) is no more expensive than a select under most CPU/GPU/TPU performance models.

George Karpenkov

Feb 19, 2024, 3:33:03 AM
to Alexander Pivovarov, OpenXLA Discuss
I think it's better to start with microbenchmarks. It seems extremely unlikely to make any difference.
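A rough microbenchmark sketch in JAX along the lines suggested. The function names, array size, and iteration count are arbitrary choices for illustration. Under jit, XLA may already fuse either version into a single elementwise loop, so any measured difference could be within noise.

```python
import timeit

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def via_select(pred):
    # Original pattern: select(pred, 1, 0).
    ones = jnp.ones(pred.shape, jnp.float32)
    zeros = jnp.zeros(pred.shape, jnp.float32)
    return lax.select(pred, ones, zeros)


@jax.jit
def via_convert(pred):
    # Proposed rewrite (1): convert(pred).
    return lax.convert_element_type(pred, jnp.float32)


def bench(fn, x, number=50):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(x).block_until_ready()  # warm-up call triggers compilation
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=number)
    return total / number


key = jax.random.PRNGKey(0)
x = jax.random.bernoulli(key, 0.5, (1024, 1024))  # random boolean predicate
for name, fn in [("select", via_select), ("convert", via_convert)]:
    print(f"{name}: {bench(fn, x) * 1e6:.1f} us/call")
```

To check whether XLA already treats the two the same, you can compare `via_select.lower(x).compile().as_text()` with the optimized HLO for `via_convert`.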


Alexander Pivovarov

Feb 21, 2024, 2:40:02 AM
to George Karpenkov, OpenXLA Discuss
Hi George, I have an update on this. We are shifting our focus to a more complex pattern to simplify: "Causal Mask HLO for jnp.tril(jnp.ones()) can be simplified" (#9709). What do you think about this idea?

Thank you
Alex
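For context, here is a hedged JAX sketch of how the causal mask relates to the original pattern. It assumes `jnp.tril` masks by selecting against a row/column iota comparison, which is an assumption about its lowering and not a description of what #9709 proposes. Under that assumption, `jnp.tril(jnp.ones((n, n)))` has the form select(row >= col, 1, 0), so the same values can be built as convert(row >= col) without materializing ones/zeros constants. The helper names are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def causal_mask_tril(n, dtype=jnp.float32):
    # The pattern as written in user code.
    return jnp.tril(jnp.ones((n, n), dtype=dtype))


def causal_mask_convert(n, dtype=jnp.float32):
    # Same values as convert(row >= col): element (i, j) is 1 where j <= i.
    row = lax.broadcasted_iota(jnp.int32, (n, n), 0)  # row[i, j] == i
    col = lax.broadcasted_iota(jnp.int32, (n, n), 1)  # col[i, j] == j
    return lax.convert_element_type(row >= col, dtype)


print(causal_mask_convert(4))
```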