779 questions
1
vote
3
answers
54
views
Assigning abstract attributes through the constructor in a `jax` python dataclass
I'm trying to subclass the `AbstractWrappedSolver` dataclass from the JAX library diffrax. The class has this definition:
class AbstractWrappedSolver(AbstractSolver[_SolverState]):
"""...
2
votes
2
answers
39
views
Skip computation of output leaves of `jax` pytree function when input leaves are set to `None`
I have a function fun that takes as argument a jax pytree and returns a jax pytree with the same structure. Sometimes, I don't want to calculate the function for a specific leaf of the pytree. In that ...
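One common approach (a minimal sketch; `fun_leaf` is a hypothetical per-leaf function standing in for the question's `fun`): tell `tree_map` to treat `None` as a leaf via `is_leaf`, then pass `None` through untouched.

```python
import jax

def fun_leaf(x):
    # hypothetical stand-in for the real per-leaf computation
    return x * 2

tree = {"a": 1.0, "b": None, "c": [3.0, None]}

out = jax.tree_util.tree_map(
    lambda x: None if x is None else fun_leaf(x),
    tree,
    # By default JAX treats None as an empty subtree, not a leaf;
    # is_leaf makes it visible to the mapped function.
    is_leaf=lambda x: x is None,
)
# out == {"a": 2.0, "b": None, "c": [6.0, None]}
```

This preserves the pytree structure, so `None` marks "skipped" leaves in both input and output.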
Best practices
1
vote
2
replies
41
views
Apply different function to each leaf of pytree in `jax`
Suppose that I have a jax pytree with n leaves and a set of distinct functions to apply to each leaf of the pytree. I can do the following:
from functools import partial
import jax
# Function to ...
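A compact alternative to partial application: build a pytree of functions with the same structure as the data and map them together. A minimal sketch with hypothetical leaf functions:

```python
import jax

tree = {"a": 1.0, "b": 2.0}
# A pytree of callables matching the structure of `tree` (callables are leaves).
funcs = {"a": lambda x: x + 10, "b": lambda x: x * 3}

# tree_map zips corresponding leaves of both trees: f is applied to its x.
out = jax.tree_util.tree_map(lambda f, x: f(x), funcs, tree)
# out == {"a": 11.0, "b": 6.0}
```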
1
vote
1
answer
63
views
Running an expensive function (containing a for loop) on multiple GPUs: pmap gives an out-of-memory error
I have an expensive function expensive_func, which I am trying to run for multiple input parameters stored in the array inputs of size (N, m) where N is the total number of cases. I want to perform ...
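A sketch of the usual sharding pattern, assuming `expensive_func` maps one `(m,)` row to a scalar (the function below is a placeholder): pad `N` up to a multiple of the device count, reshape to `(n_devices, N // n_devices, m)`, and `pmap` over the leading axis.

```python
import jax
import jax.numpy as jnp

def expensive_func(x):
    # placeholder for the real expensive per-row computation
    return jnp.sum(x ** 2)

n_dev = jax.device_count()
N, m = 8, 3
inputs = jnp.arange(N * m, dtype=jnp.float32).reshape(N, m)

# Pad N up to a multiple of the device count, then give each device one chunk.
pad = (-N) % n_dev
padded = jnp.pad(inputs, ((0, pad), (0, 0)))
chunks = padded.reshape(n_dev, -1, m)

# pmap distributes chunks across devices; vmap iterates within each chunk.
per_device = jax.pmap(jax.vmap(expensive_func))(chunks)
results = per_device.reshape(-1)[:N]
```

If vmapping the whole chunk still exhausts memory, replacing the inner `jax.vmap` with `jax.lax.map` processes rows sequentially per device and trades speed for a much smaller peak footprint.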
0
votes
1
answer
37
views
jax.distributed.initialize hangs when submitting a job with Slurm
I am attempting to run a JAX script in a distributed manner across multiple hosts using Slurm. The initialization code is as follows:
jax.distributed.initialize(
coordinator_address=f"{os....
1
vote
0
answers
29
views
PyMC + Diffrax: MCMC sampler hangs at 0% for hierarchical PK ODE model with JAX backend
Problem
I'm implementing a hierarchical pharmacokinetic (PK) model in PyMC using Diffrax for ODE solving via the icomo.jax2pytensor wrapper. The model compiles successfully and initializes, but MCMC ...
Advice
0
votes
1
replies
68
views
Can Ray pause/resume tasks at synchronization points when GPUs are limited?
I'm training multiple neural networks in parallel using Ray, where networks must synchronize at specific points during training (not just at completion) to share metadata and update hyperparameters ...
3
votes
0
answers
110
views
How to vectorize (ensemble) nnx.Modules with separate parameters using nnx.vmap in JAX/Flax
I have a vectorized (ensemble) Q-network implemented using Flax Linen that works as expected. Each critic in the ensemble has separate parameters, and the output is stacked along the first dimension (...
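Setting the NNX transform aside, the underlying mechanic can be sketched with plain `jax.vmap` over a stacked parameter pytree (a minimal hypothetical one-layer "critic"; not Flax's API): the parameters carry a leading ensemble axis, the shared input does not.

```python
import jax
import jax.numpy as jnp

def mlp_apply(params, x):
    # hypothetical single-layer critic standing in for the real network
    w, b = params
    return jnp.tanh(x @ w + b)

n_critics, in_dim, out_dim = 4, 8, 1
keys = jax.random.split(jax.random.PRNGKey(0), n_critics)

# Separate parameters per critic, stacked along a leading ensemble axis.
params = (
    jax.vmap(lambda k: jax.random.normal(k, (in_dim, out_dim)))(keys),
    jnp.zeros((n_critics, out_dim)),
)

# Map over the stacked parameter axis, broadcasting the shared input.
ensemble = jax.vmap(mlp_apply, in_axes=(0, None))
q_values = ensemble(params, jnp.ones(in_dim))
# q_values.shape == (4, 1): one output per critic, stacked first
```

`nnx.vmap` wraps the same idea for `nnx.Module`s; the `in_axes` choice (0 for parameters, `None` for the input) is what produces separate weights per ensemble member.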
1
vote
1
answer
129
views
Why does the order of return variables affect a JAX jitted function's performance so much?
In JAX, you can donate a function argument to save execution memory and time if the argument is not used any more.
If you know that one of the inputs is not needed after the computation, and if ...
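The donation mechanism the excerpt describes, as a minimal sketch: `donate_argnums` lets XLA reuse the donated input's buffer for an output of matching shape and dtype (on CPU the donation is typically ignored with a warning, so this is mainly a GPU/TPU optimization).

```python
import jax
import jax.numpy as jnp

# Donate x's buffer: XLA may write the output into it in place.
update = jax.jit(lambda x, y: x + y, donate_argnums=0)

x = jnp.ones(1000)
y = jnp.ones(1000)
out = update(x, y)
# After the call, x's buffer is considered consumed on backends that
# honor donation; reusing x there raises an error.
```

Donation only pays off when an output can alias the donated buffer, which is why output ordering and shapes can change whether the reuse actually happens.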
1
vote
0
answers
136
views
Is my JAX implementation of continuous wavelet transform correct?
I would like to implement continuous wavelet transform (CWT) using JAX. According to ChatGPT, it is in practice computed by performing a discrete convolution with a sampled wavelet function at ...
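The convolution-per-scale formulation can be sketched as follows (a hypothetical, non-normalized real Morlet wavelet; the window width and time grid are illustrative assumptions, not a validated CWT implementation):

```python
import jax.numpy as jnp
from jax import vmap

def morlet(t, w0=5.0):
    # real part of a (non-normalized) Morlet wavelet -- assumption for illustration
    return jnp.exp(-0.5 * t**2) * jnp.cos(w0 * t)

def cwt(signal, scales, width=63):
    t = jnp.linspace(-4.0, 4.0, width)
    def one_scale(s):
        # sample the wavelet at this scale, with 1/sqrt(s) energy normalization
        wavelet = morlet(t / s) / jnp.sqrt(s)
        return jnp.convolve(signal, wavelet, mode="same")
    # one discrete convolution per scale, vectorized over scales
    return vmap(one_scale)(scales)

sig = jnp.sin(jnp.linspace(0, 8 * jnp.pi, 256))
coeffs = cwt(sig, jnp.array([1.0, 2.0, 4.0]))
# coeffs.shape == (3, 256): one row of coefficients per scale
```

Checking this against a reference implementation (e.g. `scipy.signal.cwt` or PyWavelets) on a known signal is the sensible way to answer the correctness question.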
1
vote
1
answer
56
views
Grain (JAX) - equivalent to PyTorch `collate_fn` for batches
I defined a dataset class with __len__ and __getitem__ which returns a tuple of values. I can use `grain.transforms.Batch` to compose batches, but how do I specify how each item is combined into a ...
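Grain's exact hook aside, the usual collate pattern can be sketched framework-free with `tree_map` (a hypothetical `collate` mirroring PyTorch's default `collate_fn`): stack matching fields of the item tuples across the batch.

```python
import numpy as np
import jax

def collate(items):
    # items: list of tuples, e.g. [(x0, y0), (x1, y1), ...]
    # Stack corresponding fields across the batch, like torch's default collate_fn.
    return jax.tree_util.tree_map(lambda *xs: np.stack(xs), *items)

batch = collate([(np.zeros(3), 1), (np.ones(3), 2)])
# batch[0].shape == (2, 3); batch[1] == array([1, 2])
```

A function with this shape is what one would plug into Grain wherever it accepts a custom batching callable.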
1
vote
1
answer
176
views
How to JIT-compile a function in JAX when input dimensions grow over time?
I’m implementing time series models using JAX in Python. These models are computationally expensive and need to be retrained over time using an expanding window approach. To improve performance, I ...
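The standard trick for growing inputs, as a minimal sketch (the bucketed-padding scheme and `windowed_mean` body are illustrative assumptions): pad the data up to a fixed bucket size, pass the true length as a dynamic argument, and mark only the padded length as static, so recompilation happens once per bucket rather than per timestep.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)  # padded length is static -> one compile per bucket
def windowed_mean(x, n, padded_len):
    # x is padded to padded_len; n is the number of valid entries (dynamic)
    mask = jnp.arange(padded_len) < n
    return jnp.sum(jnp.where(mask, x, 0.0)) / n

def pad_to_bucket(x, bucket=64):
    # round the length up to the next multiple of `bucket`
    padded_len = ((len(x) + bucket - 1) // bucket) * bucket
    return jnp.pad(x, (0, padded_len - len(x))), len(x), padded_len

x, n, padded_len = pad_to_bucket(jnp.arange(10.0))
out = windowed_mean(x, n, padded_len)  # recompiles only when padded_len changes
```

Every window length from 1 to 64 reuses the same compiled executable; only crossing a bucket boundary triggers a new trace.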
0
votes
1
answer
149
views
Passing 4 arguments to a JIT function with 4 parameters raises "TypeError: jit() takes 1 positional argument but 5 positional arguments"
For GPU-optimized simulations I have a function whose header looks like this:
import jax
from functools import partial
@partial(partial, jax.jit, static_argnums=(2,4))
def init_particles_grid(...
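The excerpt already shows the cause: `@partial(partial, jax.jit, ...)` builds a decorator that defers `jax.jit` until call time, so the call arguments are passed to `jit` itself, producing the TypeError. A sketch of the fix (the function body here is hypothetical, since the original is elided):

```python
from functools import partial

import jax
import jax.numpy as jnp

# A single partial applies jax.jit once, at definition time.
# Note also: for a 4-parameter function the valid static_argnums
# indices are 0-3, so (2, 4) would be out of range.
@partial(jax.jit, static_argnums=(2, 3))
def init_particles_grid(pos, vel, nx, ny):
    # hypothetical body standing in for the question's elided one
    return pos * nx + vel * ny

out = init_particles_grid(jnp.ones(4), jnp.ones(4), 2, 3)
```

Equivalently, `init_particles_grid = jax.jit(init_particles_grid, static_argnums=(2, 3))` after an undecorated definition avoids `partial` entirely.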
3
votes
1
answer
1k
views
How to correctly install JAX with CUDA on Linux when `jax[cuda12_pip]` consistently falls back to the CPU version?
I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22 where every official installation method fails in a different way, ...
4
votes
1
answer
150
views
JAX crashes with `CUDNN_STATUS_INTERNAL_ERROR` when using `joblib` or `multiprocessing`, but works in a single process
I am running into a FAILED_PRECONDITION: DNN library initialization failed error when trying to parallelize a JAX function using either Python's multiprocessing library or joblib.
The strange part is ...