MultiHeadAttention layer with input tensor of rank 4

giuseppe...@gmail.com

Feb 7, 2024, 7:34:41 PM
to Keras-users
The MultiHeadAttention layer in Keras is expected to receive a rank-3 tensor as input. What if the input tensor is of rank 4? Is the layer supposed to work over the last dimension of the tensor?

 

Kwokleung Chan

Feb 15, 2024, 2:15:41 PM
to Keras-users
Yes. In general it performs 2D attention over axis 1 and axis 2.

For example, if x.shape is [B, M, N, C], where B is the batch size and C is the number of channels, then for each sample x[b] in the batch it computes all possible attentions of x[b, m1, n1, :] to x[b, m2, n2, :],
where 0 <= m1, m2 < M and 0 <= n1, n2 < N.
e.g.
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(
    num_heads=7, key_dim=2, attention_axes=(2, 3))
input_tensor = tf.keras.Input(shape=[5, 33, 44, 16])
output_tensor, weights = layer(input_tensor, input_tensor,
                               return_attention_scores=True)

print(output_tensor.shape)  # (None, 5, 33, 44, 16)
print(weights.shape)        # (None, 5, 7, 33, 44, 33, 44)

Here the attention weight is (33x44) x (33x44),
repeated for every x[b] (and, as the printed shape shows, for each of the 5 entries along axis 1 and each of the 7 heads).
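
To connect this back to the rank-4 case in the question: with the default attention_axes=None the layer attends over all axes except the batch and the last (feature) axis, so a rank-4 input behaves like the example above without the extra axis of size 5. A minimal sketch (the shapes in the comments follow the pattern above and are expected values, not copied from a run):

import tensorflow as tf

# Rank-4 input (batch, M, N, C): attention_axes is left at its default,
# so attention is taken over axes 1 and 2 (M and N), not the channel axis.
layer = tf.keras.layers.MultiHeadAttention(num_heads=7, key_dim=2)
x = tf.keras.Input(shape=[33, 44, 16])
y, w = layer(x, x, return_attention_scores=True)

print(y.shape)  # expected: (None, 33, 44, 16)
print(w.shape)  # expected: (None, 7, 33, 44, 33, 44)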

If you set attention_axes=3 instead, then for each sample x[b] in the batch it computes all possible attentions of x[b, m, n1, :] to x[b, m, n2, :],
where 0 <= m < M and 0 <= n1, n2 < N:

layer = tf.keras.layers.MultiHeadAttention(
    num_heads=7, key_dim=2, attention_axes=3)
input_tensor = tf.keras.Input(shape=[5, 33, 44, 16])
output_tensor, weights = layer(input_tensor, input_tensor,
                               return_attention_scores=True)

print(output_tensor.shape)  # (None, 5, 33, 44, 16)
print(weights.shape)        # (None, 5, 33, 7, 44, 44)
Here the attention weight is 44 x 44,
repeated for every x[b, m] (and for each head).
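
A quick way to see that the last axis of the weights is the key axis (an illustrative check with random data rather than a symbolic Input, reusing the attention_axes=3 layer from the snippet above): the scores come out of a softmax over the key attention axis, so they should sum to roughly 1 along it.

x = tf.random.normal([2, 5, 33, 44, 16])
_, w = layer(x, x, return_attention_scores=True)   # w.shape: (2, 5, 33, 7, 44, 44)

# Softmax normalizes over the key positions (the last axis of size 44), so the
# sum along that axis is ~1 for every (batch, axis-1, axis-2, head, query) slot.
print(tf.reduce_sum(w, axis=-1)[0, 0, 0, 0, :5])   # ~ [1. 1. 1. 1. 1.]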