This code is taken from one of the PyTorch tutorials; I have only removed the non-essential parts so it doesn't error out and added some print statements. My question is: why do the two print statements produce slightly different results? Is ctx.saved_tensors a tuple with nothing in the second half of it? I am confused by the comma with nothing after it before the assignment operator.
import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors  # unpack the one-element tuple
        print("ctx ", ctx.saved_tensors)
        print("inputs ", input)
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


# tensor setup restored from the tutorial so the snippet runs
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

relu = MyReLU.apply
y_pred = relu(x.mm(w1)).mm(w2)
loss = (y_pred - y).pow(2).sum()
loss.backward()
Output
ctx (tensor([[-34.2381, 18.6334, 8.8368, ..., 13.7337, -31.5657, -11.8838],
[-25.5597, -6.2847, 9.9412, ..., -75.0621, 5.0451, -32.9348],
[-56.6591, -40.0830, 2.4311, ..., -2.8988, -18.9742, -74.0132],
...,
[ -6.4023, -30.3526, -73.9649, ..., 1.8587, -23.9617, -11.6951],
[ -3.6425, 34.5828, 27.7200, ..., -34.3878, -19.7250, 11.1960],
[ 16.0137, -24.0628, 14.4008, ..., -5.4443, 9.9499, -18.1259]],
grad_fn=<MmBackward>),)
inputs tensor([[-34.2381, 18.6334, 8.8368, ..., 13.7337, -31.5657, -11.8838],
[-25.5597, -6.2847, 9.9412, ..., -75.0621, 5.0451, -32.9348],
[-56.6591, -40.0830, 2.4311, ..., -2.8988, -18.9742, -74.0132],
...,
[ -6.4023, -30.3526, -73.9649, ..., 1.8587, -23.9617, -11.6951],
[ -3.6425, 34.5828, 27.7200, ..., -34.3878, -19.7250, 11.1960],
[ 16.0137, -24.0628, 14.4008, ..., -5.4443, 9.9499, -18.1259]],
grad_fn=<MmBackward>)
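The two prints differ only in packaging: ctx.saved_tensors is a one-element tuple, and input is the tensor unpacked from it. The same repr difference can be seen with a plain string standing in for the tensor (a minimal sketch, not PyTorch-specific):

```python
t = ("tensor",)    # stand-in for ctx.saved_tensors: a one-element tuple
print(repr(t))     # the trailing ",)" in the output marks a tuple
print(repr(t[0]))  # the element by itself, no tuple wrapper
```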
The comma is necessary because (x) is just x; a one-element tuple is written (x,), which is different from x. It is the comma, not the parentheses, that makes the tuple (compare f(2, 3, 4) with f((2, 3), 4)), and as a special case the empty tuple is just parentheses, since empty parentheses aren't otherwise a valid expression. So input, = ctx.saved_tensors unpacks the one-element tuple into input, which is why the first print shows a tuple wrapping the tensor and the second shows the tensor itself.
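A quick self-contained demonstration of the unpacking syntax (the names here are illustrative, not from the tutorial):

```python
t = (42,)             # a one-element tuple; the comma makes it a tuple
x, = t                # unpacking: requires t to have exactly one element
assert x == 42        # x is the element itself, not the tuple
assert (x,) == t      # wrapping it back reproduces the tuple

assert (42) == 42     # parentheses alone do nothing
assert (42,) != 42    # the comma creates a tuple
assert () == tuple()  # the empty tuple: the bare-parentheses special case
```

Writing (input,) = ctx.saved_tensors would do exactly the same thing; the bare trailing comma is just the terser spelling.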