I want to solve a 2D differential equation with a neural network, working with the JAX library. The network approximates the function u = f(x, y) and looks something like this:

def f(params, inputs_x, inputs_y):
  inputs = jnp.concatenate((inputs_x, inputs_y), axis=1)
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs
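
For reference, here is a minimal sketch of how the params PyTree could be initialized for layer sizes [2, 5, 1] (the random-normal init scheme below is just illustrative, not my actual one):

import jax
import jax.numpy as jnp
from jax import nn as jnn

def init_params(key, sizes):
  # One (weights, biases) pair per layer transition.
  keys = jax.random.split(key, len(sizes) - 1)
  return [
      (0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
      for k, m, n in zip(keys, sizes[:-1], sizes[1:])
  ]

params = init_params(jax.random.PRNGKey(0), [2, 5, 1])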

So params is a PyTree holding the weight and bias matrices. For the 2D problem, take layer sizes like [2, 5, 1]. I pass a batch of 10 (x, y) points to the function, so inputs_x and inputs_y each have shape (10, 1), and the output should also have shape (10, 1). The real problem comes when I try to compute du/dx, du/dy, d2u/dx2, or d2u/dy2. I am writing something like this:

u = lambda x, y: f(params, x, y)
u_x = lambda x, y: vmap(jacfwd(u, argnums=0), in_axes=(0, 0))(x, y)
u_xx = lambda x, y: vmap(jacfwd(u_x, argnums=0), in_axes=(0, 0))(x, y)

I am getting errors.

When I solve a 1D differential equation, everything works fine. In that case, the neural network function looks like this:

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

u = lambda x: f(params, x)
u_x = lambda x: vmap(jacfwd(u, argnums=0))(x)

The layer sizes are [1, 5, 1]; I pass a batch of 10 inputs into the network function and compute the gradients using vmap. Everything works fine!
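
For concreteness, here is a minimal runnable version of that 1D setup, with placeholder parameters (the values are arbitrary; only the shapes matter):

import jax.numpy as jnp
from jax import nn as jnn
from jax import vmap, jacfwd

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

# Placeholder parameters for layer sizes [1, 5, 1]
params = [
    (jnp.ones((1, 5)), jnp.zeros(5)),
    (jnp.ones((5, 1)), jnp.zeros(1)),
]

u = lambda x: f(params, x)
u_x = lambda x: vmap(jacfwd(u, argnums=0))(x)

x = jnp.ones((10, 1))
print(u(x).shape)    # (10, 1)
print(u_x(x).shape)  # (10, 1, 1): one 1x1 Jacobian per sample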

As soon as I move to a 2D problem with two input neurons, the layer sizes become [2, 5, 1], and when I pass batches of 10 values for both x and y together, vmap doesn't work anymore. I want to compute du/dx, du/dy, d2u/dx2, and d2u/dy2 with functions like the ones above, and I expect each of the four derivative functions to return a result of shape (10, 1), but I am getting an error.


1 Answer


It looks like your function is not compatible with vmap because it expects explicit batch dimensions: inside vmap, each input is a single row of shape (1,), which has no axis 1 to concatenate along. You can fix this by concatenating along axis=-1 rather than axis=1, which works in both the batched and unbatched cases. Then your function calls could look something like the following:

from functools import partial
import jax
import jax.numpy as jnp
from jax import nn as jnn

def f(params, inputs_x, inputs_y):
  inputs = jnp.concatenate((inputs_x, inputs_y), axis=-1)
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

# Some example inputs and parameters
inputs_x = jnp.ones((10, 1))
inputs_y = jnp.ones((10, 1))
params = [
    (jnp.ones((2, 5)), jnp.ones(5)),
    (jnp.ones((5, 1)), jnp.ones(1)),
]

u = partial(f, params)

# u: (10,1)->(10,1)
print(u(inputs_x, inputs_y).shape)
# (10, 1)

# u: (1)->(1) batched to (10,1)->(10,1)
print(jax.vmap(u)(inputs_x, inputs_y).shape)
# (10, 1)

# ∇u: (1) -> (1,1) batched to (10,1)->(10,1,1)
print(jax.vmap(jax.jacobian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1)

# ∇²u: (1) -> (1,1,1) batched to (10,1)->(10,1,1,1)
print(jax.vmap(jax.hessian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1, 1)
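
If you specifically want du/dx, du/dy, d2u/dx2, and d2u/dy2 each as an array of shape (10, 1), you can differentiate with respect to each argument via argnums and flatten the trailing singleton axes. A sketch building on the u above (the reshape convention is my own choice):

# Per-sample derivatives come back as 1x1 (Jacobian) or 1x1x1
# (Hessian) arrays, so reshape(-1, 1) recovers shape (10, 1).
u_x  = jax.vmap(jax.jacfwd(u, argnums=0))(inputs_x, inputs_y).reshape(-1, 1)
u_y  = jax.vmap(jax.jacfwd(u, argnums=1))(inputs_x, inputs_y).reshape(-1, 1)
u_xx = jax.vmap(jax.hessian(u, argnums=0))(inputs_x, inputs_y).reshape(-1, 1)
u_yy = jax.vmap(jax.hessian(u, argnums=1))(inputs_x, inputs_y).reshape(-1, 1)

print(u_x.shape, u_y.shape, u_xx.shape, u_yy.shape)
# (10, 1) (10, 1) (10, 1) (10, 1)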