
I'm computing higher-order derivatives by nesting `jacobian` calls from PyTorch's functorch. $$f(x, y, z, \dots) : \mathbb{R}^{m_x} \times \mathbb{R}^{m_y} \times \mathbb{R}^{m_z} \times \dots \to \mathbb{R}^{m}$$ Given a function f(x, y, z, ...) and derivative orders for the first several arguments (nx, ny, nz, ...), the partial derivatives can be computed and used as a Taylor series representation of the original function.
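
Concretely, the table of partials is used as the truncated expansion around the expansion point (writing $\delta x = x - x_0$ and similarly for the other arguments; this is just the standard multivariate Taylor form, which the `evaluate` function below implements): $$f(x_0 + \delta x,\, y_0 + \delta y,\, \dots) \approx \sum_{i=0}^{n_x} \sum_{j=0}^{n_y} \cdots \frac{1}{i!\, j! \cdots}\, \partial_x^{i} \partial_y^{j} \cdots f(x_0, y_0, \dots)\, [\delta x]^{i}\, [\delta y]^{j} \cdots$$ where $[\delta x]^{i}$ denotes $i$-fold contraction of the derivative tensor with the deviation vector $\delta x$.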

I'm aware that nesting `jacobian` generates redundant work at higher orders, e.g. both dx dy f and dy dx f are computed even though they coincide for smooth functions. I'm looking for a style/code review of the `derivative` function. In particular, I would like to avoid the nested keyword-argument trick (binding `local=local` and `build=build` as default arguments inside the loops).
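
For reference, a minimal sketch of the redundancy (the function `g` and the evaluation points are invented for illustration; equality of the mixed partials is Schwarz's theorem):

```python
import torch
import functorch

def g(x, y):
    return torch.sin(x) * torch.cos(y)

x = torch.tensor(0.5, dtype=torch.float64)
y = torch.tensor(0.3, dtype=torch.float64)

# nesting computes both mixed partials, although they coincide for smooth g
dxdy = functorch.jacfwd(functorch.jacfwd(g, argnums=0), argnums=1)(x, y)
dydx = functorch.jacfwd(functorch.jacfwd(g, argnums=1), argnums=0)(x, y)
print(torch.allclose(dxdy, dydx))  # expected: True
```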

Code

```python
"""
Derivative module.

"""
from __future__ import annotations

from math import factorial
from typing import Callable
from typing import Iterator

import torch
import functorch

from multimethod import multimethod


def flatten(table: tuple) -> Iterator:
    """
    Flatten a nested tuple.

    Parameters
    ----------
    table: tuple
        input (nested) tuple

    Returns
    -------
    flattened tuple generator (iterator)

    """
    if isinstance(table, tuple):
        for item in table:
            yield from flatten(item)
    else:
        yield table

        
@multimethod
def derivative(order: int,
               function: Callable,
               *args,
               jacobian: Callable = functorch.jacfwd) -> list:
    """
    Compute input function derivatives with respect to the first argument up to a given order.

    The input function is expected to return a tensor or a (nested) list of tensors.
    The first function argument is expected to be a tensor.
    
    Parameters
    ----------
    order: int
        maximum derivative order
    function: Callable
        input function
    *args:
        input function arguments
    jacobian: Callable, optional, default=functorch.jacfwd
        functorch.jacfwd or functorch.jacrev

    Returns
    -------
    function derivatives (list)

    """
    def local(x, *xs):
        y = function(x, *xs)
        return y, y

    for _ in range(order):
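        # wrap the previous local in jacobian; the local=local default argument
        # freezes the current closure, so each layer differentiates the one
        # below it rather than the loop's final binding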
        def local(x, *xs, local=local):
            y, ys = jacobian(local, has_aux=True)(x, *xs)
            return y, (ys, y)

    _, ys = local(*args)

    return list(flatten(ys))


@multimethod
def derivative(order: tuple[int, ...],
               function: Callable,
               *args,
               jacobian: Callable = functorch.jacfwd) -> list:
    """
    Compute input function derivatives with respect to the first several function arguments up to the corresponding orders.

    The input function is expected to return a tensor or a (nested) list of tensors.
    The first several function arguments are expected to be tensors.

    Parameters
    ----------
    order: tuple[int, ...]
        maximum derivative orders
    function: Callable
        input function
    *args:
        input function arguments
    jacobian: Callable, optional, default=functorch.jacfwd
        functorch.jacfwd or functorch.jacrev

    Returns
    -------
    function derivatives (list)

    """
    pars = args[len(order):]
    
    def fixed(*args):
        return function(*args, *pars)

    def build(order, value):
        def local(*args):
            return derivative(order, lambda x: fixed(*args, x), value, jacobian=jacobian)
        return local
    
    (order, value), *rest = zip(order, args)
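    # fold the remaining (order, tensor) pairs from the inside out; binding
    # build=build(degree, tensor) as a default argument calls the previous
    # builder and freezes the resulting wrapper at definition time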
    for degree, tensor in reversed(rest):
        def build(order, value, build=build(degree, tensor)):
            def local(*args):
                return derivative(order, lambda x: build(*args, x), value, jacobian=jacobian)
            return local
        
    return build(order, value)()


@multimethod
def evaluate(table: list,
             point: list[torch.Tensor]) -> torch.Tensor:
    """
    Evaluate input table of derivatives at a given (deviation) point.
        
    Parameters
    ----------
    table: list
        input table of derivatives
    point: list[torch.Tensor]
        evaluation point
        
    Returns
    -------
    value (torch.Tensor)
        
    """
    return sum(evaluate([i], subtable, point) for i, subtable in enumerate(table))


@multimethod
def evaluate(index: list[int], table: list, point: list[torch.Tensor]) -> torch.Tensor:
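    # recurse through nested list levels, extending the multi-index of orders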
    return sum(evaluate(index + [i], subtable, point) for i, subtable in enumerate(table))


@multimethod
def evaluate(index: list[int], table: torch.Tensor, point: list[torch.Tensor]) -> torch.Tensor:
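    # leaf case: contract the derivative tensor with the deviation vectors and
    # apply the 1/n! Taylor factor for each differentiated argument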
    factor = 1.0
    for position, order in enumerate(index):
        factor *= 1.0/factorial(order)
        value = point[position]
        if value.ndim > 0:
            for _ in range(order): table @= value
        else:
            table *= value**order
    return factor*table

```

Examples

```python
print(torch.__version__)
print(functorch.__version__)
# 1.13.0+cu117
# 1.13.0+cu117

# explicit nested derivatives

# set test function

# f: R x R x R -> R

def f(x, y, z):
    return (1.0 + (2.0 + 3.0*z)*y)*x

# set derivatives order

nx, ny, nz = 1, 1, 1

# set evaluation point

px = torch.tensor(1.0, dtype=torch.float64)
py = torch.tensor(1.0, dtype=torch.float64)
pz = torch.tensor(1.0, dtype=torch.float64)

# compute derivatives (explicit nesting)

t1 = derivative(nx, lambda x: derivative(ny, lambda y: derivative(nz, lambda z: f(x, y, z), pz), py), px)

# compute derivatives (implicit nesting)

t2 = derivative((nx, ny, nz), f, px, py, pz)

# compare derivatives

print(all(k1 == k2 for i1, i2 in zip(t1, t2) for j1, j2 in zip(i1, i2) for k1, k2 in zip(j1, j2)))

# evaluate at a different point

vx = torch.tensor(5.0, dtype=torch.float64)
vy = torch.tensor(5.0, dtype=torch.float64)
vz = torch.tensor(5.0, dtype=torch.float64)

print(f(vx, vy, vz))
print(evaluate(t1, [vx - px, vy - py, vz - pz]))
print(evaluate(t2, [vx - px, vy - py, vz - pz]))

# True
# tensor(4.300000000000e+02, dtype=torch.float64)
# tensor(4.300000000000e+02, dtype=torch.float64)
# tensor(4.300000000000e+02, dtype=torch.float64)

# derivatives with respect to several vector arguments

# set test function

# f: R^2 x R^3 x R^4 -> R^2

def f(x, y, z):
    x1, x2 = x
    y1, y2, y3 = y
    z1, z2, z3, z4 = z
    
    Z1 = (1 + z1 + z2 + z3 + z4)**2
    Z2 = (1 + z1 + z2 + z3 + z4)**2
    Z3 = (1 + z1 + z2 + z3 + z4)**2
    Z4 = (1 + z1 + z2 + z3 + z4)**2
    
    Y1 = (Z1 + Z2)*(1 + y1 + y2 + y3)**2
    Y2 = (Z2 + Z3)*(1 + y1 + y2 + y3)**2
    Y3 = (Z3 + Z4)*(1 + y1 + y2 + y3)**2
    
    X1 = (Y1 + Y2)*(1 + x1 + x2)**2
    X2 = (Y2 + Y3)*(1 + x1 + x2)**2
    
    return torch.stack([X1, X2])

# set derivatives order

nx, ny, nz = 2, 2, 2

# set evaluation point

px = torch.tensor([0.0, 0.0], dtype=torch.float64)
py = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float64)
pz = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

# compute derivatives

table = derivative((nx, ny, nz), f, px, py, pz)

# evaluate at a different point

vx = torch.tensor([1.0, 2.0], dtype=torch.float64)
vy = torch.tensor([3.0, 4.0, 5.0], dtype=torch.float64)
vz = torch.tensor([6.0, 7.0, 8.0, 9.0], dtype=torch.float64)

print(f(vx, vy, vz))
print(evaluate(table, [vx - px, vy - py, vz - pz]))

# tensor([1.039417600000e+07, 1.039417600000e+07], dtype=torch.float64)
# tensor([1.039417600000e+07, 1.039417600000e+07], dtype=torch.float64)

# structure of derivatives

# [
#     [
#         [          f,           dz f,           dzdz f, ...],
#         [       dy f,        dy dz f,        dy dzdz f, ...],
#         [     dydy f,      dydy dz f,      dydy dzdz f, ...],
#         ...
#     ],
#     [
#         [       dx f,        dx dz f,        dx dzdz f, ...],
#         [    dx dy f,     dx dy dz f,     dx dy dzdz f, ...],
#         [  dx dydy f,   dx dydy dz f,   dx dydy dzdz f, ...],
#         ...
#     ],
#     [
#         [     dxdx f,      dxdx dz f,      dxdx dzdz f, ...],
#         [  dxdx dy f,   dxdx dy dz f,   dxdx dy dzdz f, ...],
#         [dxdx dydy f, dxdx dydy dz f, dxdx dydy dzdz f, ...],
#         ...
#     ],
#     ...
# ]

print(table[0][0][0].shape) # f
print(table[1][0][1].shape) # dx dz f
print(table[2][2][2].shape) # dxdx dydy dzdz f

# torch.Size([2])
# torch.Size([2, 4, 2])
# torch.Size([2, 4, 4, 3, 3, 2, 2])
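
# shape layout, inferred from the shapes above: the leading axis is the output
# dimension, followed by one axis of size 4 per z derivative, one of size 3 per
# y derivative, and one of size 2 per x derivative; evaluate contracts these
# trailing axes away starting with the x axes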
```