The issue was in the EncoderLayer, where the residual connections were computed incorrectly. The corrected forward pass:
```python
def forward(self, x: torch.Tensor, src_pad_key=None):
    # Pre-norm block 1: normalize, self-attend, then add the residual.
    residual = x
    x = self.layer_norm1(x)
    if src_pad_key is not None:
        x = self.self_attn(x, src_pad_key=src_pad_key, use_self_attention=True)
    else:
        x = self.self_attn(x)
    x = x + residual  # avoid in-place += so autograd sees a fresh tensor

    # Pre-norm block 2: normalize, apply the MLP, then add the residual.
    residual = x
    x = self.layer_norm2(x)
    x = self.mlp(x)
    x = x + residual
    return x
```
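To make the pattern concrete, here is a minimal, self-contained sketch of the same pre-norm residual layout. It substitutes `nn.MultiheadAttention` and a small `nn.Sequential` MLP for the custom `self_attn`/`mlp` modules above (those stand-ins, and all dimensions, are assumptions for illustration):

```python
import torch
import torch.nn as nn

class PreNormEncoderLayer(nn.Module):
    """Sketch of the pre-norm residual pattern from the fix above."""

    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor, src_pad_mask=None):
        # Block 1: norm -> self-attention -> residual add.
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(x, x, x, key_padding_mask=src_pad_mask)
        x = x + residual
        # Block 2: norm -> MLP -> residual add.
        residual = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        return x + residual

layer = PreNormEncoderLayer()
out = layer(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```

The layer is shape-preserving, which is exactly what the residual adds require: `x + residual` only works if attention and MLP each return one vector per input position.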
Another change: we must always use self-attention (query = x) rather than pooled attention, since otherwise the calculations don't work with the image encoder.
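The shape difference is the whole story here. A quick sketch (using `nn.MultiheadAttention` as a stand-in; the pooled-query variant shown is an assumed illustration of what "pooled attention" means in this context):

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(1, 5, 8)  # (batch, seq_len, dim) -- e.g. 5 image patches

# Self-attention: query = key = value = x, so the output keeps one
# vector per position and can be added back to the residual stream.
self_out, _ = attn(x, x, x)
print(self_out.shape)  # torch.Size([1, 5, 8])

# Pooled attention: a single mean-pooled query collapses the sequence
# to one vector, so the per-patch residual add no longer lines up.
pooled_q = x.mean(dim=1, keepdim=True)  # (1, 1, 8)
pooled_out, _ = attn(pooled_q, x, x)
print(pooled_out.shape)  # torch.Size([1, 1, 8])
```

With a pooled query the output sequence length is 1, which cannot be added to the length-5 residual — hence self-attention is required inside the encoder layers.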
The results look like this:

```
Cat similarity: tensor([[25.4132]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[21.8544]], grad_fn=<MulBackward0>)
cosine cat/dog: 0.8438754677772522
```
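For reference, numbers of this shape typically come from a CLIP-style comparison: a scaled image-text logit plus a raw cosine between two text embeddings. A sketch with random stand-in embeddings (the `logit_scale` value and all tensors are assumptions, not the model's actual outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical stand-ins for the real image/text embeddings.
img = torch.randn(1, 512)
txt_cat = torch.randn(1, 512)
txt_dog = torch.randn(1, 512)

# CLIP-style logit: cosine similarity times a (learned) temperature scale;
# the grad_fn=<MulBackward0> in the output above comes from this multiply.
logit_scale = 100.0
cat_logit = logit_scale * F.cosine_similarity(img, txt_cat).unsqueeze(0)

# Raw cosine between the two text embeddings, like the cat/dog number.
cos_cat_dog = F.cosine_similarity(txt_cat, txt_dog).item()
```

A cat/dog cosine of ~0.84 between random stand-ins would be surprising, but between trained text embeddings of two semantically close words it is plausible.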