I’m currently using PyTorch’s torch.autograd.functional.jacobian to compute per-sample, elementwise gradients of a scalar-valued model output w.r.t. its inputs. I need to keep create_graph=True because I want the resulting Jacobian entries to themselves require gradients (for further calculations).
Here’s a minimal example of what I’m doing:
import torch
from torch.autograd.functional import jacobian
def method_jac_strict(inputs, forward_fn):
    # inputs: (N, F)
    # forward_fn: (N, F) -> (N, 1)
    # returns: (N, F)
    # compute the full Jacobian of the batched output w.r.t. the batched input:
    d = jacobian(forward_fn, inputs, create_graph=True, strict=True)  # (N, 1, N, F)
    d = d.squeeze(1)  # drop the singleton output dim -> (N, N, F)
    # extract only the diagonal block (each output w.r.t. its own input sample): (N, F)
    d = torch.einsum('iif->if', d)
    return d
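For concreteness, here is a toy call matching the shapes above (the model, sizes, and downstream loss are placeholders I made up just for this post, reusing the import and method_jac_strict from above):

torch.manual_seed(0)
N, F = 8, 4
model = torch.nn.Sequential(
    torch.nn.Linear(F, 16),
    torch.nn.BatchNorm1d(16),  # introduces cross-sample dependencies in training mode
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)
inputs = torch.randn(N, F, requires_grad=True)
d = method_jac_strict(inputs, model)  # (N, F); entries require grad thanks to create_graph=True
loss = d.pow(2).sum()                 # stand-in for the "further calculations"
loss.backward()                       # second-order backward through the Jacobian entries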
A clarification: batch-sample dependencies
My model may include layers like BatchNorm, so the samples in a batch aren’t truly independent (each output can depend on the other samples’ inputs). However, I only care about the “elementwise” gradients, i.e. treating each scalar output as if it depended only on its own input sample and ignoring the cross-sample terms.
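Concretely, the quantity I want is d[i, f] = ∂ out[i] / ∂ inputs[i, f], i.e. the diagonal block of the full Jacobian. As a slow reference for those semantics (not something I'd use in practice; the function name is just for this post), one autograd.grad call per sample gives:

import torch

def method_jac_loop(inputs, forward_fn):
    # inputs: (N, F) with requires_grad=True; forward_fn: (N, F) -> (N, 1)
    out = forward_fn(inputs)  # (N, 1)
    rows = []
    for i in range(out.shape[0]):
        # gradient of one scalar output w.r.t. the whole batch of inputs: (N, F)
        (g,) = torch.autograd.grad(out[i, 0], inputs, create_graph=True)
        rows.append(g[i])     # keep d out[i] / d inputs[i]; drop cross-sample rows
    return torch.stack(rows)  # (N, F)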
Question
Is there a more efficient or idiomatic way in PyTorch to compute this elementwise gradient while preserving create_graph=True, without materializing the full (N, N, F) tensor and then extracting its diagonal?
Any pointers to built-in functions or custom tricks (e.g. clever use of torch.einsum, a custom autograd.Function, batching hacks, etc.) would be much appreciated!
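For reference, the kind of “batching hack” I have in mind is a single-backward sum trick like the sketch below (hypothetical name; as far as I can tell it only equals the diagonal when the cross-sample terms are exactly zero, so it isn’t a drop-in replacement with BatchNorm in training mode):

import torch

def method_jac_sum(inputs, forward_fn):
    # one backward pass: g[i, f] = sum_j d out[j] / d inputs[i, f],
    # which matches the elementwise gradient only if cross-sample terms vanish
    out = forward_fn(inputs)  # (N, 1)
    (g,) = torch.autograd.grad(out.sum(), inputs, create_graph=True)
    return g                  # (N, F)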