Causal Mask HLO for jnp.tril(jnp.ones()) can be simplified

Alexander Pivovarov

Feb 20, 2024, 11:25:33 PM
to OpenXLA Discuss

Hello Everyone

Quite often we see the following JAX code used to create and apply a causal mask in a self-attention layer.

import jax
import jax.numpy as jnp
from jax import Array, random

key = random.PRNGKey(42)
qk = random.uniform(key, shape=(4,4))

def apply_causal_mask(qk: Array):
  seq_len = qk.shape[-1]
  mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
  return jnp.where(mask, qk, -jnp.inf)
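For concreteness, here is what the function does on a small input (a quick check; the 3x3 example values are my own):

```python
import jax.numpy as jnp
from jax import Array

def apply_causal_mask(qk: Array):
  seq_len = qk.shape[-1]
  mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
  return jnp.where(mask, qk, -jnp.inf)

# Entries strictly above the diagonal are replaced with -inf,
# so position i can only attend to positions 0..i.
qk = jnp.arange(1.0, 10.0).reshape(3, 3)
out = apply_causal_mask(qk)
```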

If we jit + lower + compile + as_text this function

print(jax.jit(apply_causal_mask).lower(qk).compile().as_text())

then we get the following HLO:

HloModule jit_apply_causal_mask, entry_computation_layout={(f32[4,4]{1,0})->f32[4,4]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}

  %constant.5 = f32[] constant(1)
  %broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
  %constant.3 = f32[] constant(0)
  %broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
  %compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}

  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.1, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}

ENTRY %main.26 (Arg_0.1: f32[4,4]) -> f32[4,4] {
  %Arg_0.1 = f32[4,4]{1,0} parameter(0), sharding={replicated}
  ROOT %fusion = f32[4,4]{1,0} fusion(f32[4,4]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}

We noticed that fused_computation can be simplified.

In particular, the mask-related code can be simplified to just iota + iota + compare:

%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}

  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}

As a result, the following ops can be removed:

  %constant.5 = f32[] constant(1)
  %broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
  %constant.3 = f32[] constant(0)
  %broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
  %compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}
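In JAX terms, the simplified HLO corresponds to building the mask directly from two broadcasted iotas and a `>=` compare, without ever materializing the ones/zeros matrices. A sketch (the function name `apply_causal_mask_iota` is my own; it is meant to be equivalent to the original, not a drop-in from any library):

```python
import jax.numpy as jnp
from jax import lax, Array

def apply_causal_mask_iota(qk: Array):
  seq_len = qk.shape[-1]
  # Row/column index grids, matching the two iota ops in the HLO above.
  row = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 0)
  col = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 1)
  mask = row >= col  # compare(..., direction=GE)
  return jnp.where(mask, qk, -jnp.inf)

qk = jnp.arange(1.0, 10.0).reshape(3, 3)
out = apply_causal_mask_iota(qk)
```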

What do you think about recognizing such a pattern and applying the described simplification to it?

This seems like a very common pattern in LLMs.

Our team will be happy to work on it.


Alex

George Karpenkov

Mar 11, 2024, 5:43:59 AM
to Alexander Pivovarov, OpenXLA Discuss
Sorry for the late reply. If you do this simplification manually, how much of a performance gain do you see?


Peter Hawkins

Mar 11, 2024, 11:57:47 AM
to George Karpenkov, Alexander Pivovarov, OpenXLA Discuss
I just saw this also! (I was out of office when the original message was sent.)

Note that another option is likely to use `jnp.tri`, which I believe avoids the need to form the jnp.ones() matrix at all in the IR.

Although I'd agree that XLA should do this simplification.
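For example, a sketch of the `jnp.tri` variant (my own spelling of this suggestion):

```python
import jax.numpy as jnp
from jax import Array

def apply_causal_mask_tri(qk: Array):
  seq_len = qk.shape[-1]
  # jnp.tri builds the lower-triangular boolean mask directly,
  # so no ones() matrix needs to appear in the IR.
  mask = jnp.tri(seq_len, dtype=bool)
  return jnp.where(mask, qk, -jnp.inf)

qk = jnp.arange(1.0, 10.0).reshape(3, 3)
out = apply_causal_mask_tri(qk)
```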

Peter

Alexander Pivovarov

Mar 11, 2024, 4:38:50 PM
to Peter Hawkins, George Karpenkov, OpenXLA Discuss
Hi George, hi Peter

Our investigation showed that JAX produces optimal HLO out of the box for the following cases:
jnp.tri(seq_len, dtype=bool)
jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

However, model authors might use slightly different code in their models, e.g. they might forget to pass dtype=bool to ones(), or mistakenly apply .astype(bool) to the result of tril() instead of to ones().

The following PR recognizes the sub-optimal HLO pattern for the causal mask and simplifies it:
Ne(select(Ge(a, b), ones, zeros), zeros) -> Ge(a, b)
PR-9867 Causal mask suboptimal HLO simplification (Merged)
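The rewrite itself is a pure algebraic identity on the mask computation; a small NumPy check (illustrative only, not the actual XLA pass code):

```python
import numpy as np

# Ne(select(Ge(a, b), ones, zeros), zeros)  ->  Ge(a, b)
rng = np.random.default_rng(0)
a = rng.integers(0, 5, size=(4, 4))
b = rng.integers(0, 5, size=(4, 4))

ge = a >= b                           # Ge(a, b)
lhs = np.where(ge, 1.0, 0.0) != 0.0   # select(...) compared NE against zeros
assert np.array_equal(lhs, ge)
```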

Alex
