Hello Everyone
Quite often we see the following JAX code used to create and apply a causal mask in the self-attention layer:
import jax
import jax.numpy as jnp
from jax import Array, random

key = random.PRNGKey(42)
qk = random.uniform(key, shape=(4, 4))

def apply_causal_mask(qk: Array):
    seq_len = qk.shape[-1]
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
    return jnp.where(mask, qk, -jnp.inf)
If we jit + lower + compile + as_text this function
print(jax.jit(apply_causal_mask).lower(qk).compile().as_text())
then we get the following HLO:
HloModule jit_apply_causal_mask, entry_computation_layout={(f32[4,4]{1,0})->f32[4,4]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}
  %constant.5 = f32[] constant(1)
  %broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
  %constant.3 = f32[] constant(0)
  %broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
  %compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}
  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.1, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}

ENTRY %main.26 (Arg_0.1: f32[4,4]) -> f32[4,4] {
  %Arg_0.1 = f32[4,4]{1,0} parameter(0), sharding={replicated}
  ROOT %fusion = f32[4,4]{1,0} fusion(f32[4,4]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}
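The one-liner above chains several stages of JAX's ahead-of-time API. As a minimal sketch, the stages can be separated so that the IR handed to XLA and the HLO produced after XLA's passes can be compared side by side. The variable names are illustrative only, and the exact text printed depends on the JAX/XLA version and backend.

```python
import jax
import jax.numpy as jnp
from jax import random


def apply_causal_mask(qk):
    # Same function as in the original post.
    seq_len = qk.shape[-1]
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
    return jnp.where(mask, qk, -jnp.inf)


qk = random.uniform(random.PRNGKey(42), shape=(4, 4))

# Stage 1: trace and lower; nothing has been optimized by XLA yet.
lowered = jax.jit(apply_causal_mask).lower(qk)
# Stage 2: run the XLA compiler pipeline for the current backend.
compiled = lowered.compile()

# Text of the IR before XLA's optimization passes.
unoptimized_text = lowered.as_text()
# Text of the HLO after XLA's passes (the kind of dump shown above).
optimized_text = compiled.as_text()

print(optimized_text)
```

Comparing the two texts shows which ops come from the JAX lowering and which survive XLA's own simplification.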
We noticed that fused_computation can be simplified.
In particular, the mask-related code can be reduced to just iota + iota + compare:
%fused_computation (param_0.1: f32[4,4]) -> f32[4,4] {
  %iota.5 = s32[4,4]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]" source_file="<stdin>" source_line=5}
  %iota.4 = s32[4,4]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]" source_file="<stdin>" source_line=5}
  %compare.2 = pred[4,4]{1,0} compare(s32[4,4]{1,0} %iota.5, s32[4,4]{1,0} %iota.4), direction=GE, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/ge" source_file="<stdin>" source_line=5}
  %param_0.1 = f32[4,4]{1,0} parameter(0)
  %constant.1 = f32[] constant(-inf)
  %broadcast.6 = f32[4,4]{1,0} broadcast(f32[] %constant.1), dimensions={}, metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]" source_file="<stdin>" source_line=8}
  ROOT %select.2 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %param_0.1, f32[4,4]{1,0} %broadcast.6), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(_where)/select_n" source_file="<stdin>" source_line=8}
}
As a result, the following ops can be removed:
%constant.5 = f32[] constant(1)
%broadcast.8 = f32[4,4]{1,0} broadcast(f32[] %constant.5), dimensions={}
%constant.3 = f32[] constant(0)
%broadcast.7 = f32[4,4]{1,0} broadcast(f32[] %constant.3), dimensions={}
%select.3 = f32[4,4]{1,0} select(pred[4,4]{1,0} %compare.2, f32[4,4]{1,0} %broadcast.8, f32[4,4]{1,0} %broadcast.7), metadata={op_name="jit(apply_causal_mask)/jit(main)/jit(tril)/select_n" source_file="<stdin>" source_line=5}
%compare.1 = pred[4,4]{1,0} compare(f32[4,4]{1,0} %select.3, f32[4,4]{1,0} %broadcast.7), direction=NE, metadata={op_name="jit(apply_causal_mask)/jit(main)/convert_element_type[new_dtype=bool weak_type=False]" source_file="<stdin>" source_line=5}
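To make the target of the simplification concrete, here is a sketch of a JAX function that builds the mask directly from two iotas and a compare, mirroring the simplified fused_computation. It illustrates the intended result at the source level, not the proposed XLA rewrite itself. The name apply_causal_mask_iota is hypothetical, and whether the compiled HLO matches the listing exactly depends on the JAX/XLA version.

```python
import jax
import jax.numpy as jnp
from jax import lax, random


def apply_causal_mask_tril(qk):
    # Pattern from the original post: float ones -> tril -> cast to bool.
    seq_len = qk.shape[-1]
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool')
    return jnp.where(mask, qk, -jnp.inf)


def apply_causal_mask_iota(qk):
    # Hypothetical equivalent: row index >= column index gives the
    # lower-triangular (causal) predicate with no float intermediate.
    seq_len = qk.shape[-1]
    rows = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 0)
    cols = lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 1)
    mask = rows >= cols
    return jnp.where(mask, qk, -jnp.inf)


qk = random.uniform(random.PRNGKey(42), shape=(4, 4))
out_tril = jax.jit(apply_causal_mask_tril)(qk)
out_iota = jax.jit(apply_causal_mask_iota)(qk)

# Both formulations should produce identical masked scores.
print(bool(jnp.array_equal(out_tril, out_iota)))
```

The two functions compute the same values, which is what makes the six listed ops redundant: they only convert the GE predicate to float and back to a predicate.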
What do you think about recognizing such a pattern and applying the described simplification to it?
This seems like a very common use case in LLMs.
Our team will be happy to work on it.
Alex
--
You received this message because you are subscribed to the Google Groups "OpenXLA Discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to openxla-discu...@openxla.org.
To view this discussion on the web visit https://groups.google.com/a/openxla.org/d/msgid/openxla-discuss/CAKKt98TXdGMOO8BJb9OHm1dzm_qh2BqfZ7NpWzeStQx6vB3uLg%40mail.gmail.com.
For more options, visit https://groups.google.com/a/openxla.org/d/optout.
dtype=bool in ones() or mistakenly applied .astype(bool) to the result of tril() instead of to ones()
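The variant mentioned in this reply, where the ones matrix is created as bool before tril() is applied, can be sketched as follows. The name apply_causal_mask_bool is hypothetical, and the HLO this version compiles to is not shown here and may vary by JAX/XLA version.

```python
import jax
import jax.numpy as jnp
from jax import random


def apply_causal_mask_bool(qk):
    # Hypothetical variant: the ones matrix is bool from the start,
    # so no float -> bool conversion follows tril().
    seq_len = qk.shape[-1]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    return jnp.where(mask, qk, -jnp.inf)


qk = random.uniform(random.PRNGKey(42), shape=(4, 4))
out = jax.jit(apply_causal_mask_bool)(qk)

# Dump the compiled HLO to compare against the listing in the original post.
print(jax.jit(apply_causal_mask_bool).lower(qk).compile().as_text())
```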