I am implementing Soft Actor-Critic (SAC) and I am confused about the policy update step.
What I want is:
When I update the policy (actor), I do not want the parameters of the Q-networks (critics) to be changed.
However, I still need the gradient signal from the Q-networks in order to update the policy correctly.
I originally used:
with torch.no_grad():
    q_pi_1 = self.q1(obs, new_action)
    q_pi_2 = self.q2(obs, new_action)
because I thought this would keep the policy update from changing the Q-networks.
But now I realize this may be wrong, because torch.no_grad() disables all gradient tracking, including the gradient of Q(s,a) with respect to the action a.
And that gradient is exactly what the policy needs during the SAC actor update.
Am I right or wrong?
My solution for this problem (I don't know if it really is a problem) would be the code below. With it, I think the Q-network parameters will not change, but the gradient with respect to the action will still be available. Or am I misunderstanding how gradients should be handled here?
# freeze the critic parameters so the actor update cannot change them
for p in self.q1.parameters():
    p.requires_grad = False
for p in self.q2.parameters():
    p.requires_grad = False

# actor (policy) loss uses the q1/q2 forward pass here, without torch.no_grad(),
# so the gradient of Q(s, a) with respect to the action can still flow into the policy

# unfreeze the critics again before their own update
for p in self.q1.parameters():
    p.requires_grad = True
for p in self.q2.parameters():
    p.requires_grad = True