I read that a function f is permutation equivariant if f(P(x)) = P(f(x)), where P is a permutation (and, as I understand it, permutation invariant if f(P(x)) = f(x)).
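To make sure I have the two definitions straight, here is a toy sketch in plain Python. The helper names `permute`, `total`, and `square_each` are my own, purely for illustration: a sum should be permutation invariant, and an elementwise function should be permutation equivariant.

```python
def permute(seq, perm):
    """Apply a permutation: output position i takes seq[perm[i]]."""
    return [seq[p] for p in perm]

def total(seq):
    # The order of the elements does not matter for a sum.
    return sum(seq)

def square_each(seq):
    # Acts on each element independently of its position.
    return [v * v for v in seq]

x = [3, 1, 4, 1, 5]
perm = [1, 2, 3, 4, 0]  # cyclic shift by one position

# Invariance: f(P(x)) == f(x)
invariant_ok = total(permute(x, perm)) == total(x)
# Equivariance: f(P(x)) == P(f(x))
equivariant_ok = square_each(permute(x, perm)) == permute(square_each(x), perm)
```

Both flags should come out true here, whereas `square_each` is not invariant, because its output changes order along with the input.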
So, to check what equivariance and permutation invariance mean for an attention layer, I wrote the following code:
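```python
import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

# x0: row i is filled with the value i.
# x1: row i is filled with (i + 1) % 11, i.e. x0 cyclically shifted by one row.
x0 = torch.ones(11, 32)
x1 = torch.ones(11, 32)
for i in range(x0.size(0)):
    x0[i] *= i
    x1[i] *= (i + 1) % x0.size(0)

# Stack both sequences into one batch of shape (2, 11, 32).
x = torch.cat((x0.unsqueeze(0), x1.unsqueeze(0)))

# Self-attention; [0] selects the attention output (index 1 holds the weights).
y0, y1 = multihead_attn(x, x, x)[0]
# Unpacking already yields tensors of shape (11, 32), so these squeezes are no-ops.
y0 = y0.squeeze(0)
y1 = y1.squeeze(0)
```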
Then, to check: `torch.equal(x0[1], x1[0])` returns `True`, as expected,
but `torch.equal(y0, y1)` returns `False`, so the layer doesn't seem to be permutation invariant,
and `torch.equal(y0[1], y1[0])` returns `False`, so it doesn't seem to be equivariant either.
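For completeness, here is a tolerance-based variant of the same check, as a sketch. It rests on two assumptions I have not confirmed. First, the mismatch above might only be floating-point rounding, since `torch.equal` requires exact equality. Second, because `x1` is `x0` shifted by one row, the right comparison should be `y1` against `y0` shifted the same way. The seed and the tolerances are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

n = 11
# Row i of x0 is filled with i; x1 is x0 cyclically shifted: x1[i] == x0[(i + 1) % n].
x0 = torch.arange(n, dtype=torch.float32).unsqueeze(1).repeat(1, 32)
x1 = torch.roll(x0, shifts=-1, dims=0)
x = torch.stack((x0, x1))  # shape (2, n, 32)

with torch.no_grad():
    y0, y1 = attn(x, x, x)[0]

# Equivariance would mean y1 is y0 with the same shift applied.
expected_y1 = torch.roll(y0, shifts=-1, dims=0)
equivariant = torch.allclose(y1, expected_y1, rtol=1e-4, atol=1e-4)

# Invariance would mean the outputs match without any shift.
invariant = torch.allclose(y0, y1, rtol=1e-4, atol=1e-4)

print("equivariant up to tolerance:", equivariant)
print("invariant up to tolerance:", invariant)
```

If self-attention without positional encodings is equivariant but not invariant, I would expect the first check to pass and the second to fail.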
So what am I doing wrong?