
I am trying to get a deeper understanding of how PyTorch's autograd works. I am unable to explain the following results:

import torch

def fn(a):
    b = torch.tensor(5, dtype=torch.float32, requires_grad=True)
    return a * b

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

The output is tensor(5.). But my question is: the variable b is created within the function, so it should be freed from memory after the function returns a*b, right? So when I call backward, how is the value of b still available for this computation? As far as I understand, each operation in PyTorch has a context object that tracks which tensors to save for the backward computation, and each tensor also carries a version counter, so if the version changes then backward should raise an error, right?
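If I inspect the graph node attached to output, I can see that b is kept alive by the graph itself. (This sketch relies on the internal, undocumented attributes grad_fn._saved_other and Tensor._version, so the exact names may differ across PyTorch releases.)

import torch

def fn(a):
    b = torch.tensor(5, dtype=torch.float32, requires_grad=True)
    return a * b

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)

# The MulBackward0 node holds its own reference to b, so b outlives fn's return.
print(output.grad_fn)               # <MulBackward0 object at 0x...>
print(output.grad_fn._saved_other)  # tensor(5., requires_grad=True), i.e. b
print(a._version)                   # 0, since a was never modified in place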

Now when I try to run the following code,

import torch

def fn(a):
    b = a**2
    for i in range(5):
        b *= b
    return b

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

I get the following error: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of MulBackward0, is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
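The version counter that the error message refers to can be watched directly. (Again a sketch using the internal _version attribute, so treat it as illustrative.)

import torch

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
b = a**2
print(b._version)  # 0
b *= b             # each in-place multiply bumps the version counter
print(b._version)  # 1
b *= b
print(b._version)  # 2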

But if I run the following code, there is no error:

import torch

def fn(a):
    b = a**2
    for i in range(2):
        b = b * b
    return b

def fn2(a):
    b = a**2
    c = a**2
    for i in range(2):
        c *= b
    return c

a = torch.tensor(5, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)

The output for this is:

tensor(625000.)
tensor(643750.)

So for a standard computation graph with quite a few variables in the same function, I am able to understand how the computation graph works. But when a variable is modified in place before backward is called, I have a lot of trouble understanding the results. Can someone explain?
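For reference, the printed numbers do line up with the closed-form derivatives, since fn computes a**8, fn2 computes a**6, and backward() accumulates into a.grad:

# fn squares a**2 twice -> a**8, so d/da = 8*a**7
# fn2 computes c = (a**2)**3 = a**6, so d/da = 6*a**5, added onto the existing a.grad
a = 5.0
print(8 * a**7)             # 625000.0 -> a.grad after the first backward()
print(8 * a**7 + 6 * a**5)  # 643750.0 -> a.grad after the second backward()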

1 Answer


Please note that b *= b is not the same as b = b * b.

It is perhaps confusing, but the underlying operations differ.

In the case of b *= b, an in-place operation takes place, which mutates a tensor that autograd has saved for the backward pass, hence the RuntimeError.

In the case of b = b * b, two tensor objects get multiplied and the resulting new object is assigned the name b. Thus there is no RuntimeError when you run it this way.

Here is an SO question on the underlying Python operation: The difference between x += y and x = x + y

Now, what is the difference between fn in the failing case and fn2 in the working case? The operation c *= b does not destroy the graph link from c to b. The operation c *= c (which is what b *= b does above) would make it impossible to have a graph connecting two distinct tensors via an operation.

Well, I cannot use tensors to showcase that here, because calling backward() raises the RuntimeError. So I'll demonstrate with Python lists.

>>> x = [1,2]
>>> y = [3]
>>> id(x), id(y)
(140192646516680, 140192646927112)
>>>
>>> x += y
>>> x, y
([1, 2, 3], [3])
>>> id(x), id(y)
(140192646516680, 140192646927112)

Notice that no new object is created. So it is not possible to trace from the output back to the initial variables: we cannot tell whether object 140192646516680 is an output or an input. How does one create a graph with that?

Consider the following alternate case:

>>> a = [1,2]
>>> b = [3]
>>>
>>> id(a), id(b)
(140192666168008, 140192666168264)
>>>
>>> a = a + b
>>> a, b
([1, 2, 3], [3])
>>> id(a), id(b)
(140192666168328, 140192666168264)
>>>

Notice that the new list a is in fact a new object, with id 140192666168328. Here we can trace that object 140192666168328 came from an addition between two other objects, 140192666168008 and 140192666168264. Thus a graph can be built dynamically, and gradients can be propagated back from the output to earlier layers.
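The same object-identity behavior shows up on tensors too, as long as backward() is never called on the in-place result. (A sketch; the in-place op is only permitted here because the tensors are non-leaf.)

import torch

a = torch.tensor(5., requires_grad=True)

b = a**2
old_id = id(b)
b = b * b               # out-of-place: a brand-new tensor object
print(id(b) == old_id)  # False

c = a**2
old_id = id(c)
c *= c                  # in-place: the same object is mutated
print(id(c) == old_id)  # True (c.backward() would now raise the RuntimeError)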

  • Hi. Thanks for the in-place clarification. Could you please elaborate more on the last point?
    – pranav
    Commented Jun 21, 2020 at 9:15
  • Expanded the answer to describe it more. Hope this helps. Commented Jun 21, 2020 at 9:41
  • How does c *= c differ from c *= b? It seems like in the example with lists, when you apply x += y, the ids of the objects do not change?
    – pranav
    Commented Jun 21, 2020 at 13:22
  • The difference is that c *= c sort of cannibalizes its own history, because the start and end of the graph edge are the same object, and thus the graph collapses. c *= b allows a graph link from one distinct object, c, to another object, b. Commented Jun 21, 2020 at 20:03
