I read that a function f is equivariant if f(P(x)) = P(f(x)), where P is a permutation.

To check what equivariance and permutation invariance mean in practice, I wrote the following code:
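To make the two definitions concrete, here is a minimal pure-Python sketch. The helper names (`permute`, `square_each`, `total`) are hypothetical and chosen only for illustration; they have nothing to do with PyTorch. An elementwise function is equivariant, while a sum over all elements is invariant.

```python
def permute(seq, perm):
    """Apply a permutation: output position i takes seq[perm[i]]."""
    return [seq[p] for p in perm]


def square_each(seq):
    """Acts on each element independently, so it should be equivariant."""
    return [v * v for v in seq]


def total(seq):
    """Collapses the sequence to one number, so it should be invariant."""
    return sum(seq)


x = [1, 2, 3, 4]
perm = [2, 0, 3, 1]  # an arbitrary permutation of the four positions

# Equivariance: f(P(x)) == P(f(x))
is_equivariant = square_each(permute(x, perm)) == permute(square_each(x), perm)

# Invariance: f(P(x)) == f(x)
is_invariant = total(permute(x, perm)) == total(x)

print(is_equivariant, is_invariant)  # True True
```

Integers are used here so that exact equality is safe; with floating-point tensors the same comparisons generally need a tolerance.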

import torch
import torch.nn as nn

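multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

# x0: row i is filled with the value i.
# x1: the rows of x0 cyclically shifted by one, so x1[i] == x0[(i+1) % 11].
x0 = torch.ones(11,32)
x1 = torch.ones(11,32)
for i in range(x0.size(0)):
    x0[i] *= i
    x1[i] *= (i+1) % x0.size(0)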

x = torch.cat(
    (x0.unsqueeze(0), x1.unsqueeze(0))
    )

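# [0] selects the attention output (shape 2 x 11 x 32); unpacking splits it along the batch dimension
y0, y1 = multihead_attn(x,x,x)[0]
# y0 and y1 already have shape (11, 32), so squeeze(0) leaves them unchanged
y0 = y0.squeeze(0)
y1 = y1.squeeze(0)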

Then, as a sanity check, torch.equal(x0[1], x1[0]) >>> True

But torch.equal(y0, y1) >>> False, so it doesn't seem to be permutation invariant,
and torch.equal(y0[1], y1[0]) >>> False, so it doesn't seem to be equivariant either.

So what am I doing wrong?

1 Answer

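I got the answer here: https://discuss.pytorch.org/t/multi-head-self-attention-in-transformer-is-permutation-invariant-or-equivariant-how-to-see-it-in-practice/221249/2

The correct check is

torch.allclose(y0[1], y1[0], atol=1e-6)

which evaluates to True. torch.equal requires exact equality, whereas here the two outputs agree only up to a small numerical difference, presumably floating-point rounding error.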
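As a more general check, the sketch below tests f(P(x)) against P(f(x)) for a random permutation of the whole sequence, rather than for a single row. The seed, the random input, and the tolerance `atol=1e-5` are my own assumptions and were picked only to make the example reproducible and robust; they are not taken from the linked thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
mha.eval()  # dropout defaults to 0.0, so this is only a precaution

x = torch.randn(1, 11, 32)   # (batch, sequence, embedding)
perm = torch.randperm(11)    # random permutation P of the sequence positions
x_perm = x[:, perm]          # P(x): the same rows in a different order

with torch.no_grad():
    y, _ = mha(x, x, x)                      # f(x)
    y_perm, _ = mha(x_perm, x_perm, x_perm)  # f(P(x))

# Equivariance: f(P(x)) should match P(f(x)) up to rounding error.
# atol=1e-5 is an assumed tolerance, slightly looser than the 1e-6 used above.
equivariant = torch.allclose(y_perm, y[:, perm], atol=1e-5)
max_abs_diff = (y_perm - y[:, perm]).abs().max().item()

print(equivariant, max_abs_diff)
```

Printing `max_abs_diff` shows how far apart the two results actually are, which helps when deciding on a tolerance.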
