I want to solve a 2D differential equation with a neural network, working with the JAX library. The network approximates the function u = f(x, y) and looks something like this:
    import jax.numpy as jnp
    import jax.nn as jnn
    from jax import vmap, jacfwd

    def f(params, inputs_x, inputs_y):
        # stack the two coordinates into one (batch, 2) input
        inputs = jnp.concatenate((inputs_x, inputs_y), axis=1)
        for w, b in params:
            outputs = jnp.dot(inputs, w) + b   # affine layer
            inputs = jnn.swish(outputs)        # activation feeds the next layer
        return outputs                         # last layer is left linear
Here, params is a PyTree containing the weight and bias matrices. For the 2D problem, let's take the layer sizes to be something like [2, 5, 1]. A batch of 10 (x, y) points is passed to the function, so inputs_x and inputs_y both have shape (10, 1), and the output I want should therefore also have shape (10, 1). The real problem comes when I try to compute du/dx, du/dy, d2u/dx2 or d2u/dy2.
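For completeness, params is built roughly like this; init_params below is only a sketch of my initialisation (the actual scale and key handling may differ), but it produces the (weights, bias) PyTree structure the network expects:

    from jax import random

    def init_params(key, sizes):
        # one (weights, bias) pair per layer, e.g. sizes = [2, 5, 1]
        keys = random.split(key, len(sizes) - 1)
        return [(0.1 * random.normal(k, (m, n)), jnp.zeros(n))
                for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

    params = init_params(random.PRNGKey(0), [2, 5, 1])

With params in place, I am writing the derivatives like this: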
    u = lambda x, y: f(params, x, y)
    u_x = lambda x, y: vmap(jacfwd(u, argnums=0), in_axes=(0, 0))(x, y)
    u_xx = lambda x, y: vmap(jacfwd(u_x, argnums=0), in_axes=(0, 0))(x, y)
I am getting errors.
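For concreteness, this is roughly how I call these functions (the x and y values below are just placeholders for my actual collocation points). The forward pass itself is fine in my case; the derivative calls are where it breaks:

    # placeholder batch of 10 (x, y) points, each array of shape (10, 1)
    x = jnp.linspace(0.0, 1.0, 10).reshape(10, 1)
    y = jnp.linspace(0.0, 1.0, 10).reshape(10, 1)

    print(u(x, y).shape)   # (10, 1), as expected
    u_x(x, y)              # this call (and u_xx) is where the errors appear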
If I solve a 1D differential equation instead, everything works fine. In that case, the neural network function is something like this:
    def f(params, inputs):
        for w, b in params:
            outputs = jnp.dot(inputs, w) + b
            inputs = jnn.swish(outputs)
        return outputs
    u = lambda x: f(params, x)
    u_x = lambda x: vmap(jacfwd(u, argnums=0))(x)
The layer sizes are [1, 5, 1]; I pass a batch of 10 inputs of shape (10, 1) into the network and compute the gradients using vmap. Everything works fine!
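Put together with the sketched init_params from above, this 1D version runs end to end:

    params = init_params(random.PRNGKey(0), [1, 5, 1])   # 1D layer sizes
    x = jnp.linspace(0.0, 1.0, 10).reshape(10, 1)         # batch of 10 inputs

    print(u(x).shape)   # (10, 1)
    u_x(x)              # runs without errors, one du/dx per batch element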
As soon as I move to the 2D problem with two input neurons, the layer sizes become [2, 5, 1], I pass a batch of 10 points for x and y together, and vmap no longer works. I want to compute du/dx, du/dy, d2u/dx2 and d2u/dy2 from the network using the four functions below, and I expect each of them to return a result of shape (10, 1), but instead I get errors.
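Concretely, going back to the 2D network and the [2, 5, 1] parameters from the top, the four derivative functions I am aiming for look like this (u_y and u_yy are just the argnums=1 counterparts of u_x and u_xx, which I assume are the right analogue), and I expect each of them to return an array of shape (10, 1):

    u = lambda x, y: f(params, x, y)
    u_x = lambda x, y: vmap(jacfwd(u, argnums=0), in_axes=(0, 0))(x, y)      # du/dx
    u_y = lambda x, y: vmap(jacfwd(u, argnums=1), in_axes=(0, 0))(x, y)      # du/dy
    u_xx = lambda x, y: vmap(jacfwd(u_x, argnums=0), in_axes=(0, 0))(x, y)   # d2u/dx2
    u_yy = lambda x, y: vmap(jacfwd(u_y, argnums=1), in_axes=(0, 0))(x, y)   # d2u/dy2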