I'm computing higher-order derivatives using nested Jacobians from PyTorch/functorch.
$$f(x, y, z, \dots) : \mathbb{R}^{m_x} \times \mathbb{R}^{m_y} \times \mathbb{R}^{m_z} \times \dots \to \mathbb{R}^{m}$$
Given a function f(x, y, z, ...) and derivative orders for its first several arguments (nx, ny, nz, ...), the partial derivatives can be computed and used as a Taylor series representation of the original function.
I'm aware that nesting Jacobians introduces redundancy at high orders, e.g. both dxdyf and dydxf are computed. I'm looking for a style/code review of the `derivative` function. In particular, I would like to avoid using a `nested` keyword argument.
Code
"""
Derivative module.
"""
from __future__ import annotations
import torch
import functorch
from typing import Iterator
from typing import Callable
from math import factorial
from multimethod import multimethod
def flatten(table: tuple) -> Iterator:
    """
    Lazily flatten an arbitrarily nested tuple, depth first.

    Parameters
    ----------
    table: tuple
        (nested) tuple to flatten; any non-tuple value is treated as a leaf
        (lists in particular are NOT descended into)

    Returns
    -------
    generator (iterator) over the leaves, in left-to-right order
    """
    if not isinstance(table, tuple):
        # a leaf: emit it as-is
        yield table
        return
    for element in table:
        yield from flatten(element)
@multimethod
def derivative(order: int,
               function: Callable,
               *args,
               jacobian: Callable = functorch.jacfwd) -> list:
    """
    Compute derivatives of the input function with respect to its first argument up to a given order.

    Input function is expected to return a tensor or a (nested) list of tensors.
    The first function argument is expected to be a tensor; the remaining
    arguments are passed to the function unchanged.

    Parameters
    ----------
    order: int
        maximum derivative order (non-negative)
    function: Callable
        input function
    *args:
        input function arguments
    jacobian: Callable, optional, default=functorch.jacfwd
        functorch.jacfwd or functorch.jacrev

    Returns
    -------
    list [f, df, ..., d^order f] holding the function value followed by its
    derivatives (order + 1 entries), all evaluated at args

    Raises
    ------
    ValueError
        if order is negative
    """
    if order < 0:
        raise ValueError(f"derivative order must be non-negative, got {order}")

    def lift(step):
        # Turn a (value, history) function into its derivative: with has_aux=True
        # the jacobian differentiates only the first output and passes the second
        # one through, so lower-order results are accumulated as (history, current).
        def local(x, *xs):
            y, ys = jacobian(step, has_aux=True)(x, *xs)
            return y, (ys, y)
        return local

    def local(x, *xs):
        # order zero: the value itself, duplicated as the initial history
        y = function(x, *xs)
        return y, y

    for _ in range(order):
        local = lift(local)

    _, ys = local(*args)
    # history is a left-nested tuple ((((f, df), d2f), ...), dnf); flatten keeps the order
    return list(flatten(ys))
@multimethod
def derivative(order: tuple[int, ...],
               function: Callable,
               *args,
               jacobian: Callable = functorch.jacfwd) -> list:
    """
    Compute derivatives of the input function with respect to its first several
    arguments up to the corresponding given orders.

    Input function is expected to return a tensor or a (nested) list of tensors.
    The first len(order) function arguments are expected to be tensors; any
    further arguments are passed to the function unchanged.

    Parameters
    ----------
    order: tuple[int, ...]
        maximum derivative order for each of the leading arguments
    function: Callable
        input function
    *args:
        input function arguments
    jacobian: Callable, optional, default=functorch.jacfwd
        functorch.jacfwd or functorch.jacrev

    Returns
    -------
    nested list of depth len(order): entry [i][j][...] is the mixed derivative
    taken i times w.r.t. the first argument, j times w.r.t. the second, and so on

    Raises
    ------
    ValueError
        if order is empty or has more entries than there are function arguments
    """
    count = len(order)
    if count == 0:
        raise ValueError("at least one derivative order is required")
    if count > len(args):
        raise ValueError(
            f"got {count} derivative orders but only {len(args)} function arguments"
        )

    head, *tail = order
    first, *rest = args

    if not tail:
        # innermost level: differentiate w.r.t. this argument, everything after it is fixed
        return derivative(head, lambda x: function(x, *rest), first, jacobian=jacobian)

    def inner(x):
        # with the current argument fixed to x, build the table for the remaining
        # arguments; that (nested) list is then differentiated with respect to x
        return derivative(tuple(tail), lambda *xs: function(x, *xs), *rest, jacobian=jacobian)

    return derivative(head, inner, first, jacobian=jacobian)
@multimethod
def evaluate(table: list,
             point: list[torch.Tensor]) -> torch.Tensor:
    """
    Evaluate a (nested) table of derivatives as a Taylor polynomial at a given (deviation) point.

    Parameters
    ----------
    table: list
        (nested) table of derivatives, as produced by derivative
    point: list[torch.Tensor]
        deviation from the expansion point, one tensor per differentiated argument

    Returns
    -------
    value (torch.Tensor)
    """
    total = 0
    for degree, entry in enumerate(table):
        # the outermost list position is the derivative order of the first argument
        total = total + evaluate([degree], entry, point)
    return total
@multimethod
def evaluate(index: list[int], table: list, point: list[torch.Tensor]) -> torch.Tensor:
    """
    Recursive step: descend one level into a nested derivative table.

    Parameters
    ----------
    index: list[int]
        derivative orders already selected for the leading arguments
    table: list
        sub-table whose list position selects the order for the next argument
    point: list[torch.Tensor]
        evaluation (deviation) point

    Returns
    -------
    sum of the Taylor terms below this sub-table (torch.Tensor)
    """
    terms = (evaluate([*index, degree], entry, point) for degree, entry in enumerate(table))
    return sum(terms)
@multimethod
def evaluate(index: list[int], table: torch.Tensor, point: list[torch.Tensor]) -> torch.Tensor:
    """
    Evaluate a single Taylor term for one derivative tensor.

    Parameters
    ----------
    index: list[int]
        derivative order for each leading argument (index[k] is the order for argument k)
    table: torch.Tensor
        derivative tensor selected by index; it is NOT modified
    point: list[torch.Tensor]
        evaluation (deviation) point, one tensor per argument

    Returns
    -------
    term value (torch.Tensor): table contracted with the deviations and scaled
    by 1/(index[0]! * index[1]! * ...)
    """
    factor = 1.0
    # Work on a new binding: the previous in-place `*=` / `@=` could write into
    # the caller's derivative table, corrupting it for any later evaluation.
    term = table
    for position, order in enumerate(index):
        factor *= 1.0/factorial(order)
        value = point[position]
        if value.ndim > 0:
            # vector argument: contract one trailing axis per derivative order
            # (matmul with a 1-D right operand sums over the last axis of term)
            for _ in range(order):
                term = term @ value
        else:
            # scalar argument: no extra axes, only scale
            term = term * value**order
    return factor*term
Examples
print(torch.__version__)
print(functorch.__version__)
# 1.13.0+cu117
# 1.13.0+cu117


# explicit nested derivatives
# set test function
# f: R x R x R -> R
def f(x, y, z):
    return (1.0 + (2.0 + 3.0*z)*y)*x


# set derivatives order
nx, ny, nz = 1, 1, 1
# set evaluation point
px = torch.tensor(1.0, dtype=torch.float64)
py = torch.tensor(1.0, dtype=torch.float64)
pz = torch.tensor(1.0, dtype=torch.float64)
# compute derivatives (explicit nesting)
t1 = derivative(nx, lambda x: derivative(ny, lambda y: derivative(nz, lambda z: f(x, y, z), pz), py), px)
# compute derivatives (implicit nesting)
t2 = derivative((nx, ny, nz), f, px, py, pz)
# compare derivatives: walk t1 and t2 in lockstep at every nesting level and
# compare matching entries; torch.equal also works for non-scalar tensors
print(all(
    torch.equal(k1, k2)
    for i1, i2 in zip(t1, t2)
    for j1, j2 in zip(i1, i2)
    for k1, k2 in zip(j1, j2)
))
# evaluate at a different point
vx = torch.tensor(5.0, dtype=torch.float64)
vy = torch.tensor(5.0, dtype=torch.float64)
vz = torch.tensor(5.0, dtype=torch.float64)
print(f(vx, vy, vz))
print(evaluate(t1, [vx - px, vy - py, vz - pz]))
print(evaluate(t2, [vx - px, vy - py, vz - pz]))
# True
# tensor(4.300000000000e+02, dtype=torch.float64)
# tensor(4.300000000000e+02, dtype=torch.float64)
# tensor(4.300000000000e+02, dtype=torch.float64)
# derivatives with respect to several vector arguments
# set test function
# f: R^2 x R^3 x R^4 -> R^2
def f(x, y, z):
    x1, x2 = x
    y1, y2, y3 = y
    z1, z2, z3, z4 = z
    # shared squared-sum factors of each argument
    sx = (1 + x1 + x2)**2
    sy = (1 + y1 + y2 + y3)**2
    sz = (1 + z1 + z2 + z3 + z4)**2
    Z1, Z2, Z3, Z4 = sz, sz, sz, sz
    Y1, Y2, Y3 = (Z1 + Z2)*sy, (Z2 + Z3)*sy, (Z3 + Z4)*sy
    X1, X2 = (Y1 + Y2)*sx, (Y2 + Y3)*sx
    return torch.stack([X1, X2])


# set derivatives order
nx, ny, nz = 2, 2, 2
# set evaluation point (origin of each argument space)
px, py, pz = (torch.zeros(size, dtype=torch.float64) for size in (2, 3, 4))
# compute derivatives
table = derivative((nx, ny, nz), f, px, py, pz)
# evaluate at a different point
vx = torch.tensor([1.0, 2.0], dtype=torch.float64)
vy = torch.tensor([3.0, 4.0, 5.0], dtype=torch.float64)
vz = torch.tensor([6.0, 7.0, 8.0, 9.0], dtype=torch.float64)
deviation = [vx - px, vy - py, vz - pz]
print(f(vx, vy, vz))
print(evaluate(table, deviation))
# tensor([1.039417600000e+07, 1.039417600000e+07], dtype=torch.float64)
# tensor([1.039417600000e+07, 1.039417600000e+07], dtype=torch.float64)
# structure of derivatives
# [
#     [
#         [         f,          dz f,          dzdz f, ...],
#         [      dy f,       dy dz f,       dy dzdz f, ...],
#         [    dydy f,     dydy dz f,     dydy dzdz f, ...],
#         ...
#     ],
#     [
#         [      dx f,       dx dz f,       dx dzdz f, ...],
#         [   dx dy f,    dx dy dz f,    dx dy dzdz f, ...],
#         [ dx dydy f,  dx dydy dz f,  dx dydy dzdz f, ...],
#         ...
#     ],
#     [
#         [    dxdx f,     dxdx dz f,     dxdx dzdz f, ...],
#         [ dxdx dy f,  dxdx dy dz f,  dxdx dy dzdz f, ...],
#         [dxdx dydy f, dxdx dydy dz f, dxdx dydy dzdz f, ...],
#         ...
#     ],
#     ...
# ]
# shapes of: f, dx dz f, dxdx dydy dzdz f
for i, j, k in [(0, 0, 0), (1, 0, 1), (2, 2, 2)]:
    print(table[i][j][k].shape)
# torch.Size([2])
# torch.Size([2, 4, 2])
# torch.Size([2, 4, 4, 3, 3, 2, 2])
```