I’m currently using PyTorch’s torch.autograd.functional.jacobian to compute per-sample, elementwise gradients of a model that outputs one scalar per input sample, w.r.t. its inputs. I need to keep create_graph=True because the resulting Jacobian entries must themselves require gradients (for further calculations).

Here’s a minimal example of what I’m doing:

import torch
from torch.autograd.functional import jacobian

def method_jac_strict(inputs, forward_fn):
    # inputs: (N, F)
    # forward_fn: (N, F) -> (N, 1)
    # returns: (N, F)

    # compute the full Jacobian of every output w.r.t. every input sample
    d = jacobian(forward_fn, inputs, create_graph=True, strict=True)  # (N, 1, N, F)
    d = d.squeeze(1)  # (N, N, F)

    # keep only the diagonal block (each output w.r.t. its own input sample)
    d = torch.einsum('iif->if', d)  # (N, F)
    return d
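
For reference, a toy call looks like this (the model and sizes here are placeholders, not my real setup):

import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(4, 8)
g = method_jac_strict(x, net)
print(g.shape, g.requires_grad)  # torch.Size([4, 8]) True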

A clarification: batch-sample dependencies

My model may include layers like BatchNorm, so samples in the batch aren’t truly independent. However, I only care about the “elementwise” gradients—i.e. treating each scalar output as if it only depended on its own input sample, and ignoring cross-sample terms.
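
To make that concrete, here is a toy check of my own (made-up model, not the real one) showing that the off-diagonal blocks of the full Jacobian are indeed nonzero once BatchNorm is involved:

import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.Linear(3, 1))  # training mode
x = torch.randn(5, 3)
J = jacobian(net, x).squeeze(1)       # (5, 5, 3): J[i, j] = d out_i / d x_j
print(bool(J[0, 1].abs().max() > 0))  # True: output 0 depends on sample 1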

Question

Is there a more efficient or idiomatic way in PyTorch to compute this elementwise gradient while preserving create_graph, without materializing the full (N, N, F) tensor and extracting its diagonal?

Any pointers to built-in functions or custom tricks (e.g. clever use of torch.einsum, custom autograd.Function, batching hacks, etc.) would be much appreciated!
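
For reference, the loop-based alternative I have in mind is roughly the following sketch (names are mine); it keeps the graph and never builds the (N, N, F) tensor, but still pays for N separate backward passes:

import torch

def method_grad_loop(inputs, forward_fn):
    # one autograd.grad call per sample; keep only the "own sample" row
    inputs = inputs.requires_grad_(True)
    out = forward_fn(inputs)  # (N, 1)
    rows = []
    for i in range(out.shape[0]):
        g, = torch.autograd.grad(out[i, 0], inputs, create_graph=True)  # (N, F)
        rows.append(g[i])  # d out_i / d x_i only
    return torch.stack(rows)  # (N, F); entries stay in the graph thanks to create_graph=True

It gives the same diagonal blocks as the Jacobian approach, so the question is really whether those N backward passes can be avoided or batched.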

1 Answer

You can just call backward on the raw output (with a vector of ones as the gradient) and read the result off x.grad. Here's a simple example:

import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

x = torch.randn(4, 8)
model = nn.Sequential(*[nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1)])

# method 1 - compute jacobian and take diagonal
j = jacobian(model, x, create_graph=True, strict=True)
out1 = torch.einsum('iif->if', j.squeeze())

# method 2 - set grad flag on inputs, run forward pass and backward
x.requires_grad = True
y = model(x)
y.backward(torch.ones_like(y))
out2 = x.grad

print(torch.allclose(out1, out2))
# True
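
If the gradient itself needs to require grad, a variant (sketch, reusing the model above) is to take the same vector-Jacobian product with torch.autograd.grad and create_graph=True instead of .backward(). Note that, like method 2, this still folds in cross-sample terms when a layer such as BatchNorm couples the batch:

# method 2b (sketch) - same VJP via autograd.grad, kept attached to the graph
x2 = torch.randn(4, 8, requires_grad=True)
y2 = model(x2)
out3, = torch.autograd.grad(y2, x2, grad_outputs=torch.ones_like(y2), create_graph=True)
print(out3.requires_grad)
# True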

2 Comments

Thank you for your answer. Unfortunately it does not help, because out2.requires_grad is False.
Also, changing the model to model = nn.Sequential(*[nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1)]) makes the non-diagonal elements nonzero, so torch.allclose(out1, out2) becomes False.
