Yes, in general it performs a 2D attention over the two specified axes.
For example, if x.shape is [B, M, N, C], where B is the batch size and C is the number of channels, and the attention axes are the M and N dimensions,
then for each sample x[b] in the batch it computes all possible attentions of x[b, m1, n1, :] to x[b, m2, n2, :],
where 0 <= m1, m2 < M and 0 <= n1, n2 < N.
e.g.

layer = tf.keras.layers.MultiHeadAttention(
    num_heads=7, key_dim=2, attention_axes=(2, 3))
input_tensor = tf.keras.Input(shape=[5, 33, 44, 16])
output_tensor, weights = layer(
    input_tensor, input_tensor, return_attention_scores=True)
print(output_tensor.shape)  # (None, 5, 33, 44, 16)
print(weights.shape)        # (None, 5, 7, 33, 44, 33, 44)
Here the attention weights form a (33*44) x (33*44) matrix per head: the weights shape reads as (batch, 5, 7 heads, 33 query rows, 44 query columns, 33 key rows, 44 key columns).
Note that the input in this example is 5D (batch plus [5, 33, 44, 16]), so the size-5 axis 1 is not attended over; the 2D attention is repeated independently for every batch element and every index along that axis.
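To make the 2D case concrete, here is a minimal single-head NumPy sketch (not the Keras implementation, which is multi-head and applies learned query/key/value projections): attending jointly over two axes (M, N) is equivalent to flattening them into one sequence of length M*N and running ordinary scaled dot-product attention.

```python
import numpy as np

def attention_1d(q, k, v):
    # Ordinary scaled dot-product attention on sequences: q, k, v are [B, L, C].
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # [B, L, L]
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)                          # softmax over keys
    return w @ v                                              # [B, L, C]

B, M, N, C = 2, 3, 4, 8
rng = np.random.default_rng(0)
x = rng.normal(size=(B, M, N, C))

# "2D attention": every (m1, n1) position attends to every (m2, n2) position,
# which is just 1D attention over the flattened M*N sequence.
flat = x.reshape(B, M * N, C)
out = attention_1d(flat, flat, flat).reshape(B, M, N, C)
print(out.shape)  # (2, 3, 4, 8)
```

The (M*N) x (M*N) score matrix in this sketch corresponds to the trailing (33, 44, 33, 44) block of the Keras weights shape above, just kept unflattened there.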
If you instead set attention_axes=3, attention runs over the last spatial axis only:
for each sample x[b] in the batch and each row index m, it computes all possible attentions of x[b, m, n1, :] to x[b, m, n2, :],
where 0 <= m < M and 0 <= n1, n2 < N.
layer = tf.keras.layers.MultiHeadAttention(
    num_heads=7, key_dim=2, attention_axes=(3,))  # note: (3) is just the int 3; write (3,) or 3
input_tensor = tf.keras.Input(shape=[5, 33, 44, 16])
output_tensor, weights = layer(
    input_tensor, input_tensor, return_attention_scores=True)
print(output_tensor.shape)  # (None, 5, 33, 44, 16)
print(weights.shape)        # (None, 5, 33, 7, 44, 44)
Here each attention weight matrix is 44 x 44: the weights shape reads as (batch, 5, 33 non-attention axes, 7 heads, 44 query positions, 44 key positions),
so one matrix per head is computed and repeated independently for every x[b, m] slice (and for every index along the size-5 axis in this 5D example).
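A corresponding single-head NumPy sketch of the single-axis case (again an illustration, not the Keras implementation): each row x[b, m] is attended over independently, producing one N x N weight matrix per (b, m).

```python
import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

B, M, N, C = 2, 3, 4, 8
rng = np.random.default_rng(1)
x = rng.normal(size=(B, M, N, C))

# scores[b, m, n1, n2]: position n1 attends to n2 within the same row m only.
scores = np.einsum('bmic,bmjc->bmij', x, x) / np.sqrt(C)
w = softmax(scores, axis=-1)                 # [B, M, N, N], one (N x N) matrix per (b, m)
out = np.einsum('bmij,bmjc->bmic', w, x)     # [B, M, N, C]
print(out.shape)  # (2, 3, 4, 8)
```

The trailing (44, 44) block of the Keras weights shape above is exactly this per-row N x N matrix, with the non-attention axes (5 and 33) carried along in front.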