
The calculation of w in the SO3 to_quaternion() function may cause gradient explosion #670

@HiOnes


🐛 Bug

I ran into a gradient explosion during training. Running the training step under with torch.autograd.detect_anomaly(): produced the message "Function 'SqrtBackward0' returned nan values in its 0th output." My own code never calls torch.sqrt() or similar functions, so I suspected the problem was in Theseus's internal computations. I had also noticed the related fix in #661, so I checked the implementation of the to_quaternion() function in so3.py.
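A minimal sketch, independent of Theseus, of one way that anomaly message can arise: when a square root is evaluated at exactly 0 and the gradient flowing into it is also 0, SqrtBackward0 computes 0 / 0 = nan. The tensor x and the multiplication by 0.0 are purely illustrative assumptions, not taken from the report.

```python
import torch

# sqrt has an infinite derivative at 0. If the upstream gradient is also 0,
# SqrtBackward0 computes 0 / (2 * sqrt(0)) = 0 / 0 = nan, which anomaly
# detection reports as an error during backward().
x = torch.zeros(1, requires_grad=True)
message = ""
with torch.autograd.detect_anomaly():
    y = (x.sqrt() * 0.0).sum()
    try:
        y.backward()
    except RuntimeError as err:
        message = str(err)

print(message)
```

When the upstream gradient is nonzero, the same square root yields inf instead of nan, which is what the reproduction further down shows.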

In that function, the calculation of sine_half_theta already clamps the argument of the square root to a small eps:

sqrt_eps = _THESEUS_GLOBAL_PARAMS.get_eps("so3", "to_quaternion_sqrt", w.dtype)
sine_half_theta = (
    (0.5 * (1 - cosine_near_pi)).clamp(sqrt_eps, 1).sqrt().view(-1, 1)
)

However, w is also computed with a square root, and there no eps is applied:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(0, 4).sqrt()

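Here the argument is only clamped to [0, 4]. When it is at or close to 0, which happens for rotations by an angle near pi, the derivative of sqrt() is infinite or very large, so the backward pass produces inf or nan gradients.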
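To see when the argument reaches 0, write x for the clamped quantity and use the standard identity trace(R) = 1 + 2 cos(theta) for a rotation by angle theta. The derivation below covers only the w expression quoted above; the rotation used in the reproduction, diag(1, -1, -1), is a rotation by pi about the x axis, so x = 0 there:

```latex
x = 1 + R_{00} + R_{11} + R_{22} = 1 + (1 + 2\cos\theta) = 4\cos^2(\theta/2),
\qquad
w = \tfrac{1}{2}\sqrt{x} = \left|\cos(\theta/2)\right|,
\qquad
\frac{\partial w}{\partial x} = \frac{1}{4\sqrt{x}} \to \infty \quad \text{as } x \to 0^{+} \ (\theta \to \pi).
```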

Steps to Reproduce

Here is a minimal script that reproduces the bug:

import theseus as th
import torch
import torch.nn.functional as F

# 180-degree rotation about the x axis: 1 + trace(R) == 0, so w == 0
rot = torch.tensor([[1.0, 0.0, 0.0],
                    [0.0, -1.0, 0.0],
                    [0.0, 0.0, -1.0]], requires_grad=True).reshape(1, 3, 3)
rot.retain_grad()  # rot is a non-leaf (output of reshape), so keep its grad
rot_so3 = th.SO3(tensor=rot)
identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0]).reshape(1, 4)
err = F.mse_loss(rot_so3.to_quaternion(), identity_quat)
err.backward()
print(rot.grad)

The output is:

tensor([[[-inf, 0., 0.],
         [0., -inf, 0.],
         [0., 0., -inf]]])

If I add an eps (1e-6 in my test) as the lower clamp bound in the calculation of w:

w = 0.5 * (1 + self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2]).clamp(1e-6, 4).sqrt()

the gradient becomes finite:

tensor([[[-0.0625,  0.0000,  0.0000],
         [ 0.0000, -0.0625,  0.0000],
         [ 0.0000,  0.0000, -0.0625]]])
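To isolate the w term without Theseus, the sketch below copies only the quoted w expression into a hypothetical helper, w_component (not part of the Theseus API), and backpropagates the share of the MSE loss that w contributes, (w - 1)^2 / 4. With eps = 0 (current code) the diagonal gradient is -inf, matching the report; with eps = 1e-6 it is finite.

```python
import torch


def w_component(rot: torch.Tensor, eps: float) -> torch.Tensor:
    """Hypothetical helper mirroring only the w expression quoted from so3.py.

    rot has shape (batch, 3, 3). eps is the lower clamp bound: 0.0 reproduces
    the current code, a small positive value is the proposed fix.
    """
    trace_plus_one = 1 + rot[:, 0, 0] + rot[:, 1, 1] + rot[:, 2, 2]
    return 0.5 * trace_plus_one.clamp(eps, 4).sqrt()


def grad_of_w_loss(eps: float) -> torch.Tensor:
    # 180-degree rotation about x: 1 + trace(R) is exactly 0.
    rot = torch.tensor([[[1.0, 0.0, 0.0],
                         [0.0, -1.0, 0.0],
                         [0.0, 0.0, -1.0]]], requires_grad=True)
    w = w_component(rot, eps)
    # Contribution of w to F.mse_loss against the identity quaternion
    # (the mean is taken over 4 quaternion components).
    loss = ((w - 1.0) ** 2).sum() / 4
    loss.backward()
    return rot.grad


grad_orig = grad_of_w_loss(0.0)    # diagonal entries are -inf
grad_fixed = grad_of_w_loss(1e-6)  # all entries finite
print(grad_orig)
print(grad_fixed)
```

In this w-only sketch the patched gradient is exactly zero: the input 0 lies below the clamp's lower bound, and clamp passes no gradient outside [min, max]. The -0.0625 entries in the full reproduction therefore come from the other quaternion components, not from w.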

System Info

  • OS: Ubuntu 20.04
  • Python version: 3.8
  • CUDA version: 11.8
