🐛 Bug
I ran into exploding gradients during training. Using `with autograd.detect_anomaly():` I got the hint `Function 'SqrtBackward0' returned nan values in its 0th output.` However, I don't call functions like `torch.sqrt()` in my own code, so I suspected the bug lies in Theseus's internal calculations. Having noticed the related fix #661, I checked the implementation of the `to_quaternion()` function in `so3.py`.
The calculation of `sine_half_theta` already guards the `sqrt` with an eps; the relevant code is:

```python
sqrt_eps = _THESEUS_GLOBAL_PARAMS.get_eps("so3", "to_quaternion_sqrt", w.dtype)
sine_half_theta = (
    (0.5 * (1 - cosine_near_pi)).clamp(sqrt_eps, 1).sqrt().view(-1, 1)
)
```
However, another use of `sqrt` lies in the calculation of `w`:

```python
w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(0, 4).sqrt()
```
Here the argument is only clamped to `[0, 4]`. When the clamped value is exactly 0 (or very close to it), the backward pass fails, because the derivative of `sqrt(x)` is `1 / (2 * sqrt(x))`, which diverges as `x → 0`.
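A minimal sketch of this failure mode, independent of Theseus (the variable names here are purely illustrative):

```python
import torch

# Gradient of sqrt at 0: d/dx sqrt(x) = 1 / (2 * sqrt(x)) -> inf as x -> 0.
# clamp(0, 4) does not help, because the lower bound 0 is still reachable
# and clamp passes the gradient through at the boundary.
x = torch.tensor(0.0, requires_grad=True)
y = x.clamp(0, 4).sqrt()
y.backward()
print(x.grad)  # tensor(inf) -- non-finite gradient, as in the bug above
```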
Steps to Reproduce
I prepared a simple test script to reproduce this bug:
```python
import theseus as th
import torch
import torch.nn.functional as F

rot = torch.tensor([[1.0, 0.0, 0.0],
                    [0.0, -1.0, 0.0],
                    [0.0, 0.0, -1.0]], requires_grad=True).reshape(1, 3, 3)
rot_so3 = th.SO3(tensor=rot)
identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0]).reshape(1, 4)
err = F.mse_loss(rot_so3.to_quaternion(), identity_quat)
rot.retain_grad()
err.backward()
print(rot.grad)
```
The output will be:

```
tensor([[[-inf, 0., 0.],
         [0., -inf, 0.],
         [0., 0., -inf]]])
```
And if I add an eps (1e-6 in my test) to the calculation of `w`:

```python
w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(1e-6, 4).sqrt()
```
The grad will be:

```
tensor([[[-0.0625, 0.0000, 0.0000],
         [ 0.0000, -0.0625, 0.0000],
         [ 0.0000, 0.0000, -0.0625]]])
```
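The same standalone sketch as above, with a positive lower clamp bound, shows why the eps restores a finite gradient (1e-6 is just the value I used; a dtype-dependent eps like the existing `to_quaternion_sqrt` one would be more consistent with Theseus's style):

```python
import torch

eps = 1e-6  # illustrative value, not an official Theseus constant
x = torch.tensor(0.0, requires_grad=True)
y = x.clamp(eps, 4).sqrt()
y.backward()
# x is strictly below the lower clamp bound, so clamp's backward zeroes the
# incoming gradient instead of letting sqrt's 1/(2*sqrt(0)) blow up.
print(x.grad)  # tensor(0.)
```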
System Info
- OS : Ubuntu 20.04
- Python version: 3.8
- CUDA version: 11.8