I have a recursively defined function my_func that is compiled with jax.jit. It is defined below:
```python
# Imports
import jax
import jax.numpy as jnp
from functools import partial
import time

# Constants and subroutines used in the core recursive routine below ...
sx = jnp.asarray([[0, 1.], [1., 0]], dtype=complex)
sy = jnp.asarray([[0, -1j], [1j, 0]], dtype=complex)

def conj_op(A):
    # Conjugate transpose over the last two axes
    return jnp.swapaxes(A, -1, -2).conj()

def commutator_herm(A, B):
    # AB - (AB)^dagger; equals the usual commutator [A, B] when A and B are Hermitian
    comm = A @ B
    comm = comm - conj_op(comm)
    return comm

def H(t):
    return jnp.cos(t) * sy

def X0(t):
    return sx

# Core recursive routine ...
@partial(jax.jit, static_argnames="k")
def my_func(t, k):
    if k == 0:
        X_k = X0(t)
        return X_k
    else:
        X_km1 = lambda t: my_func(t, k - 1)
        X_k = 1j * commutator_herm(H(t), X_km1(t)) + jax.jacfwd(X_km1, holomorphic=True)(t)
        return X_k
```
with the relevant test:
```python
# Tests ...
t = jnp.asarray(1, dtype=complex)
seq_exec_times = []
for k in range(9, 10):  # or toggle to range(10) to compile sequentially
    start = time.time()
    # block_until_ready() so the timing also covers the asynchronously dispatched execution
    my_func(t, k).block_until_ready()
    dur = time.time() - start
    seq_exec_times.append(dur)
total_seq_exec_time = sum(seq_exec_times)
print("Sequential execution times:")
print(["{:.3e} s".format(x) for x in seq_exec_times])
print("Total execution time:")
print("{:.3e} s".format(total_seq_exec_time))
```
If I call this function for the first time directly with k=9, compilation takes quite long, which I figure is because tracing a recursive function like this one takes an effort that scales exponentially with recursion depth. The output is:
```
Sequential execution times:
['6.306e+01 s'] # First execution time when calling directly with k=9
Total execution time:
6.306e+01 s
```
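A quick way to test this hypothesis, as a sketch: trace an un-jitted copy of the same recursion with jax.make_jaxpr and count the equations in the resulting jaxpr for each k. Without jax.jit every level is inlined into one flat jaxpr, so len(jaxpr.jaxpr.eqns) is a rough proxy for how much work tracing and XLA compilation have to do. The name my_func_nojit is hypothetical. Since each jax.jacfwd call re-emits the whole lower-level program together with its tangent computation, I'd expect the count to more than double from one level to the next.

```python
import jax
import jax.numpy as jnp

sx = jnp.asarray([[0, 1.], [1., 0]], dtype=complex)
sy = jnp.asarray([[0, -1j], [1j, 0]], dtype=complex)

def conj_op(A):
    return jnp.swapaxes(A, -1, -2).conj()

def commutator_herm(A, B):
    comm = A @ B
    return comm - conj_op(comm)

def H(t):
    return jnp.cos(t) * sy

def X0(t):
    return sx

# Hypothetical un-jitted copy of my_func: every level is inlined into a single
# flat jaxpr, so its equation count can be read off directly.
def my_func_nojit(t, k):
    if k == 0:
        return X0(t)
    X_km1 = lambda s: my_func_nojit(s, k - 1)
    return 1j * commutator_herm(H(t), X_km1(t)) + jax.jacfwd(X_km1, holomorphic=True)(t)

t = jnp.asarray(1, dtype=complex)
eqn_counts = []
for k in range(7):
    # Trace only (no XLA compilation) and count the primitive equations.
    jaxpr = jax.make_jaxpr(lambda s: my_func_nojit(s, k))(t)
    eqn_counts.append(len(jaxpr.jaxpr.eqns))
print(eqn_counts)
```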
But then I thought that in practice I need to evaluate my_func for increasing values k=0,1,2,3,... anyway. If the lower levels have already been traced, only the next level of the tree should need tracing, which ought to be more efficient. And indeed, running k=0,1,...,8 before k=9 gives a somewhat lower time for the first evaluation of k=9 (55.5 s instead of 63.1 s), although the total over all levels is higher:
```
Sequential execution times:
['3.797e-03 s',
'2.203e-02 s',
'3.487e-02 s',
'7.054e-02 s',
'1.779e-01 s',
'4.680e-01 s',
'1.326e+00 s',
'4.145e+00 s',
'1.456e+01 s',
'5.550e+01 s'] # First execution time of k=9 after calling k=0,1,2,3...,8 first
Total execution time:
7.631e+01 s
```
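The recursion itself suggests why caching the lower levels helps so little. Writing C(A, B) = AB - (AB)^† for commutator_herm and d/dt for the derivative that jax.jacfwd computes, and applying the product (Leibniz) rule to the matrix product inside commutator_herm, differentiating the recurrence n times gives:

```latex
X_k = i\,\mathcal{C}\big(H, X_{k-1}\big) + \frac{d X_{k-1}}{dt},
\qquad \mathcal{C}(A,B) = AB - (AB)^\dagger

\frac{d^n X_k}{dt^n} = i \sum_{j=0}^{n} \binom{n}{j}\,
\mathcal{C}\!\left(\frac{d^j H}{dt^j},\, \frac{d^{n-j} X_{k-1}}{dt^{n-j}}\right)
+ \frac{d^{n+1} X_{k-1}}{dt^{n+1}}
```

So X_9 needs dX_8/dt rather than X_8 itself; dX_8/dt in turn needs d²X_7/dt², and so on down to d⁹X_0/dt⁹. With jax.jacfwd each of these higher derivatives comes from stacking jvp transformations on the traced lower levels, and autodiff operates on jaxprs, not on compiled XLA executables, so the executable already compiled for k=8 cannot simply be reused for its derivative. My guess is that this nesting of derivatives, rather than the plain recursion, is what keeps the cost exponential.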
That said, the first-call time still scales exponentially with recursion depth (for the larger k, each level takes roughly 3 to 4 times longer than the previous one), and I was naively expecting the compilation of k=9 to be much cheaper. If the lower levels k=0,1,...,8 are already compiled, I would expect compiling the next level k=9 to be relatively simple: JAX would only need to trace the link between k=9 and k=8 and could avoid going through the whole recursion tree at the lower levels again.
But it looks like I was wrong, and I'm curious to know why. And if I'm not wrong, how can I make this better?
This was run with JAX 0.4.33 on macOS 15.6.1.
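One possible way around the exponential cost, sketched under assumptions: instead of nesting jax.jacfwd, propagate a table of time derivatives level by level with the Leibniz rule, d^n X_k = 1j * sum_j C(n, j) commutator_herm(d^j H, d^(n-j) X_{k-1}) + d^(n+1) X_{k-1}. Level k only needs derivative orders up to K - k, so the traced program has roughly K³/6 commutator terms instead of growing exponentially. This swaps JAX's nested autodiff for a hand-written derivative recursion, and it assumes H(t) = cos(t) * sy (derivatives from the closed form d^j cos(t)/dt^j = cos(t + jπ/2)) and a constant X0. The names my_func_dp and dH are hypothetical. For a general H, the derivatives of H would have to come from elsewhere; Taylor-mode differentiation with jax.experimental.jet may be worth looking at, though it is not used here.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp

sx = jnp.asarray([[0, 1.], [1., 0]], dtype=complex)
sy = jnp.asarray([[0, -1j], [1j, 0]], dtype=complex)

def conj_op(A):
    return jnp.swapaxes(A, -1, -2).conj()

def commutator_herm(A, B):
    comm = A @ B
    return comm - conj_op(comm)

def dH(t, j):
    # Hypothetical helper: j-th t-derivative of H(t) = cos(t) * sy, using the
    # closed form d^j/dt^j cos(t) = cos(t + j*pi/2). Only valid for this H.
    return jnp.cos(t + j * jnp.pi / 2) * sy

@partial(jax.jit, static_argnames="k")
def my_func_dp(t, k):
    # D[n] holds the n-th t-derivative of X_level at t.
    # Level 0 assumes X0(t) = sx is constant, so only D[0] is nonzero.
    D = [sx] + [jnp.zeros_like(sx)] * k
    for level in range(1, k + 1):
        # Leibniz rule applied to the recurrence:
        # D^n X_level = 1j * sum_j C(n, j) commutator_herm(D^j H, D^(n-j) X_{level-1})
        #               + D^(n+1) X_{level-1}
        # Level `level` only needs orders n <= k - level.
        D = [
            1j * sum(math.comb(n, j) * commutator_herm(dH(t, j), D[n - j])
                     for j in range(n + 1))
            + D[n + 1]
            for n in range(k - level + 1)
        ]
    return D[0]
```

Since both versions implement the same recursion, my_func_dp(t, k) should agree with my_func(t, k) up to floating-point rounding. The trade-off is generality: the derivative table has to be written by hand for each H and X0, whereas the jacfwd version works for any differentiable H.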