I'm training a deep learning model where my original architecture involves an element-wise product between vectors. Specifically, the computation is straightforward:
# Original computation (element-wise product)
C = A * B  # shape: (batch, len, dim)
To enhance the expressive power of the model, I modified this computation by introducing dimension interactions through a dot product operation followed by reshaping. This way, each element in the output incorporates linear combinations across dimensions:
# Modified computation (dot product followed by reshape)
# N_q has shape (batch, len, dim)
scale = torch.sum(A * B, dim=-1, keepdim=True)  # shape: (batch, len, 1)
# (batch, len, 1, 1) @ (batch, len, 1, dim) -> (batch, len, 1, dim), then squeeze to (batch, len, dim)
C = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2)  # resulting shape: (batch, len, dim)
Despite increasing the model's complexity and theoretical representational capacity, the modified architecture doesn't perform as expected. The simpler model converged to a loss of approximately 0.8, but after introducing the dot product and reshaping step, the loss worsened to about 1.1.
Why does introducing dimension interactions via dot product and reshaping negatively impact the model's convergence and final loss?
You compute
C = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2)
but given that you are performing the matmul over unit dimensions, the resulting C is equal to scale * N_q, so I don't really see the point of doing the matmul and all the dimension operations.
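To see the equivalence concretely, here is a minimal sketch (the tensor sizes and random inputs are illustrative assumptions, not values from your model) comparing the matmul-over-unit-dimensions version with the plain broadcasted product scale * N_q:

import torch

batch, length, dim = 2, 5, 8  # illustrative sizes only
A = torch.randn(batch, length, dim)
B = torch.randn(batch, length, dim)
N_q = torch.randn(batch, length, dim)

scale = torch.sum(A * B, dim=-1, keepdim=True)  # (batch, len, 1)

# matmul over the unit dimensions: (batch, len, 1, 1) @ (batch, len, 1, dim) -> (batch, len, 1, dim)
C_matmul = torch.matmul(scale.unsqueeze(-1), N_q.unsqueeze(-2)).squeeze(-2)

# plain broadcasted multiplication by the per-position scalar
C_broadcast = scale * N_q

print(torch.allclose(C_matmul, C_broadcast))  # True

In other words, C is just N_q rescaled by a single scalar per position, so the extra matmul and dimension reshuffling add cost without changing what is computed.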