
I'm training a deep learning model whose original architecture uses an element-wise product between two tensors. The computation is straightforward:

# Original computation (element-wise product)
C = A * B  # shape: (batch, len, dim)

To increase the model's expressive power, I modified this computation to introduce interactions across the feature dimension: a dot product over dim followed by a reshape, so that each output element incorporates a linear combination across dimensions:

# Modified computation (dot product followed by reshape)
scale = torch.sum(A * B, dim=-1, keepdim=True)  # shape: (batch, len, 1)
C = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2)  # resulting shape: (batch, len, dim)
# Here, N_q.shape is (batch, len, dim)
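
For reference, here is a minimal, self-contained sketch of both computations (the tensor sizes are made up purely for illustration); it runs end-to-end and confirms the output shapes:

import torch

batch, length, dim = 4, 16, 32  # hypothetical sizes, for illustration only
A = torch.randn(batch, length, dim)
B = torch.randn(batch, length, dim)
N_q = torch.randn(batch, length, dim)

# Original computation: element-wise product
C_orig = A * B                                              # (batch, len, dim)

# Modified computation: dot product over dim, then matmul and reshape
scale = torch.sum(A * B, dim=-1, keepdim=True)              # (batch, len, 1)
C_mod = torch.matmul(scale.unsqueeze(-1),                   # (batch, len, 1, 1)
                     N_q.unsqueeze(-2)                      # (batch, len, 1, dim)
                     ).squeeze(-2)                          # (batch, len, dim)

print(C_orig.shape, C_mod.shape)  # torch.Size([4, 16, 32]) torch.Size([4, 16, 32])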

Despite increasing the model's complexity and theoretical representational capacity, the modified architecture doesn't perform as expected. The simpler model converged to a loss of approximately 0.8, but after introducing the dot product and reshaping step, the loss worsened to about 1.1.

Why does introducing dimension interactions via dot product and reshaping negatively impact the model's convergence and final loss?

  • Are you sure the code is doing what you think it's doing? You compute C = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2), but given that you are performing the matmul over unit dimensions, the resulting C is equal to scale * N_q, so I don't really see the point of doing the matmul and all the dimension operations.
    – Karl
    Commented Apr 17 at 0:30
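
Karl's observation can be checked numerically: because the matmul contracts over a dimension of size 1, the result is just the per-position scalar broadcast over N_q. A quick sketch (again with placeholder tensors) makes this concrete:

import torch

A, B, N_q = (torch.randn(4, 16, 32) for _ in range(3))   # placeholder tensors
scale = torch.sum(A * B, dim=-1, keepdim=True)            # (batch, len, 1)

C_matmul = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2)
C_broadcast = scale * N_q  # plain broadcasted multiplication

print(torch.allclose(C_matmul, C_broadcast))  # True

In other words, each output position is simply N_q scaled by a single scalar, so the matmul and the extra dimension operations do not add anything beyond scale * N_q.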
