
I have a recursively defined function my_func that is jitted using jax.jit from the jax library. It is defined below:

# Imports

import jax
import jax.numpy as jnp
from functools import partial
import time


# Constants and subroutines used in the core recursive routine below ...

sx = jnp.asarray([[0,1.],[1.,0]], dtype=complex)
sy = jnp.asarray([[0,-1j],[1j,0]], dtype=complex)

def conj_op(A):
    return jnp.swapaxes(A, -1,-2).conj()

def commutator_herm(A, B):
    comm = A @ B
    comm = comm - conj_op(comm)
    return comm

def H(t):
    return jnp.cos(t) * sy

def X0(t):
    return sx

# Core recursive routine ...

@partial(jax.jit, static_argnames="k")
def my_func(t, k):
    if k==0:
        X_k = X0(t)
        return X_k
    else:
        X_km1 = lambda t: my_func(t,k-1)
        X_k = 1j * commutator_herm(H(t), X_km1(t)) + jax.jacfwd(X_km1, holomorphic=True)(t)
        return X_k

with the relevant test:

# Tests ...

t = jnp.asarray(1, dtype=complex)

seq_exec_times = []

for k in range(9,10): # or toggle to range(10) to compile sequentially
    start = time.time()
    my_func(t, k)
    dur = time.time() - start
    seq_exec_times.append(dur)

total_seq_exec_time = sum(seq_exec_times)

print("Sequential execution times:")
print(["{:.3e} s".format(x) for x in seq_exec_times])
print("Total execution time:")
print("{:.3e} s".format(total_seq_exec_time))

If I execute this function the first time only for k=9, then I get quite a long compilation time, which I figure is because tracing a recursive function like this one takes an effort that scales exponentially with recursion depth. The output is:

Sequential execution times:
['6.306e+01 s']   # First execution time when calling directly with k=9
Total execution time:
6.306e+01 s

But then I thought that in practice I need to evaluate my_func for increasing values of k=0,1,2,3... anyway. And if the lower levels have already been traced, then you only need to trace the next level of the tree, which should be more efficient. And indeed, executing k=0,1,2,...,8 before executing k=9 yields a slightly lower execution time the first time k=9 is evaluated:

Sequential execution times:
['3.797e-03 s',
'2.203e-02 s',
'3.487e-02 s',
'7.054e-02 s',
'1.779e-01 s',
'4.680e-01 s',
'1.326e+00 s',
'4.145e+00 s',
'1.456e+01 s',
'5.550e+01 s']    # First execution time of k=9 after calling k=0,1,2,3...,8 first
Total execution time:
7.631e+01 s

That said, this still scales exponentially with recursion depth, and I was naively expecting the compilation of k=9 to be more efficient. If the lower levels k=0,1,2,...,8 are already compiled, I would expect the compilation at the next level k=9 to be relatively simple: you should only need to trace the link between k=9 and k=8, and avoid going through the whole recursion tree again at the lower levels.

But it looks like I was wrong, and I'm curious to know why. And if I'm not wrong, how do I make this better?

This was run with jax 0.4.33 on macOS 15.6.1.

1 Answer


In general, you should avoid a recursive coding style when using JAX code with JIT, autodiff, or other transformations.

There are three different things at play here that complicate the analysis of runtimes:

  • tracing: this is the general process used in transforming JAX code, whether for jit or for autodiff transformations like jacfwd. I believe the main reason you are seeing different timings depending on the sequence of executions is the trace cache: for each value of k, the function will be traced only once, and subsequent calls will use the cached trace.

  • autodiff: the jacfwd call in your function retraces the original function and generates a sequence of operations representing the forward-mode Jacobian. I don't believe that there is any cache for this, so each time you call jacfwd the transformation will be recomputed from the cached trace.

  • compilation: I don't believe that the compilation pass currently makes use of previously-compiled units via the trace cache. Any control flow in JAX (loops, recursive calls, etc.) is effectively flattened before being passed to the compiler: in your case the number of operations looks to scale roughly as O[3^k]. Compilation cost is superlinear (often roughly quadratic) in the number of operations, so compilation will become very expensive as k gets larger. A rough way to check the operation count is sketched just below this list.
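
To see this growth concretely, you can trace an unjitted copy of the recursion with jax.make_jaxpr and count the primitive equations in the flattened program at each level. This sketch reuses the question's H, X0, commutator_herm, and t; the name my_func_nojit and the counting loop are only illustrative, not part of the original code.

# Sketch: count how many primitive operations end up in the flattened program
# at each recursion level. my_func_nojit is an illustrative, unjitted copy of
# my_func; it relies on the H, X0, commutator_herm, and t defined above.

def my_func_nojit(t, k):
    if k == 0:
        return X0(t)
    X_km1 = lambda t: my_func_nojit(t, k - 1)
    return 1j * commutator_herm(H(t), X_km1(t)) + jax.jacfwd(X_km1, holomorphic=True)(t)

for k in range(6):
    # make_jaxpr traces the function and returns the flattened program;
    # the length of .jaxpr.eqns is the number of primitive equations in it
    n_ops = len(jax.make_jaxpr(lambda t: my_func_nojit(t, k))(t).jaxpr.eqns)
    print(f"k={k}: {n_ops} equations")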

Unfortunately, there's not really any workaround for these issues. When using JAX, you should avoid deep Python control flow like for loops and recursion. You may be able to make progress by re-expressing your recursive function as an iterative one, using one of JAX's control flow operators like fori_loop to reduce the size of the traced program and cut down the compilation time.
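
For what it's worth, here is a minimal sketch of the fori_loop pattern on a toy, value-level recursion (step is a hypothetical stand-in for one level of work, not part of the question's code). The caveat is that this only applies when each level maps arrays to arrays; the nested jacfwd in the question composes functions, which cannot simply be moved inside a fori_loop body this way.

# Minimal sketch (not the OP's function): a value-level recursion
#   x_k = step(x_{k-1})
# written with lax.fori_loop, so the traced program contains a single loop
# body instead of k unrolled copies. step is a hypothetical stand-in.
import jax
import jax.numpy as jnp

def step(x):
    return jnp.cos(x) + 0.1 * x  # one level of work

@jax.jit
def f_iterative(x0, k):
    return jax.lax.fori_loop(0, k, lambda i, x: step(x), x0)

print(f_iterative(jnp.asarray(1.0), 9))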


3 Comments

Thanks. That's a shame; this is by far the best way to write this function. I'll have to rethink, or just hope I don't need k to be too large.
Final comment after playing with this more: it seems that the problem is inherent to the tracing of higher-order derivatives. If I just take a scalar function and apply grad to it k times, the first-execution time still scales exponentially with k (albeit at a lower rate). So it seems that as long as composition of derivatives is not made more efficient (maybe with extra caching, or with jax.experimental.jet), this will remain a problem. It's true that 10th-order derivatives are not something we often calculate, but it would be satisfying if such compositions were optimized within JAX.
This makes sense: computing the gradient of a single operation requires at least two operations (the gradient itself, and multiplication by the tangent), so the total number of operations is at least exponential in the number of gradient transformations. You can get around this with jet, which uses a Taylor series expansion to compute multiple derivative orders without multiple autodiff passes.
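
For reference, a minimal sketch of the jet call pattern on a scalar function (sin is only an illustrative example; check the jax.experimental.jet documentation for the exact coefficient convention before relying on the output values):

# Sketch: propagate a truncated Taylor series through f in a single pass with
# jax.experimental.jet, instead of composing several grad/jacfwd passes.
# The input path here is x(h) = x0 + h, so the output series encodes the
# expansion of f around x0 (see the jet docs for the coefficient convention).
import jax.numpy as jnp
from jax.experimental.jet import jet

f = lambda x: jnp.sin(x)

x0 = 1.0
order = 4
series_in = ((1.0,) + (0.0,) * (order - 1),)   # one coefficient list per primal

primal_out, series_out = jet(f, (x0,), series_in)
print(primal_out)   # f(x0)
print(series_out)   # higher-order terms from a single jet pass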
