Are PINNs inherently worse at solving 2nd-order PDEs?

49 views Asked by At

(Note: the equations were originally posted as images because Stack Overflow doesn't support LaTeX.)

I've been trying to solve the 1D wave equation with a PINN:

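$$\frac{\partial^2 f}{\partial t^2} = c^2\,\frac{\partial^2 f}{\partial x^2}, \qquad x \in [0, 1],\; t \in [0, 1],\; c = 1$$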

With the initial and boundary conditions:

$$f(x, 0) = \sin(2\pi x), \qquad f(0, t) = f(1, t) = 0$$

For my neural network, I chose a very basic model: a single hidden layer of 100 softplus units that maps the inputs (x, t) to one scalar output. It is shown below:

import jax
import jax.numpy as jnp
from jax import jit, grad, vmap

# Flat parameter vector consumed by `model`: 100 x-weights, 100 t-weights,
# 100 hidden biases, 100 output weights and 1 output bias -> 4 * 100 + 1 = 401.
key = jax.random.PRNGKey(seed=0)
params = jax.random.normal(key, (4 * 100 + 1,))

def softplus(x):
    """Return ``log(1 + exp(x))`` elementwise, computed without overflow.

    The naive formula overflows ``exp(x)`` to ``inf`` for large positive
    ``x`` (roughly x > 88 in float32), which then poisons the loss and its
    gradients. ``logaddexp(0, x) = log(exp(0) + exp(x))`` is the same
    function evaluated in a numerically stable way.
    """
    return jnp.logaddexp(0.0, x)

def model(params, x, t, activation=softplus, hidden=100):
    """Evaluate a one-hidden-layer network f(x, t).

    Intended for scalar ``x`` and ``t``; batching over many points is done
    by wrapping this function in ``vmap``.

    Args:
        params: flat parameter vector laid out as
            ``[wx (hidden), wt (hidden), b0 (hidden), w1 (hidden), b2 (1)]``,
            i.e. ``4 * hidden + 1`` entries (401 for the default width).
        x: spatial coordinate.
        t: time coordinate.
        activation: elementwise nonlinearity applied to the hidden layer.
        hidden: number of hidden units. Defaults to 100, the width the
            original hard-coded slices assumed, so existing callers are
            unaffected.

    Returns:
        ``w1 · activation(x * wx + t * wt + b0) + b2``.
    """
    wx = params[:hidden]
    wt = params[hidden:2 * hidden]
    b0 = params[2 * hidden:3 * hidden]
    w1 = params[3 * hidden:4 * hidden]
    b2 = params[4 * hidden]
    l1 = activation(x * wx + t * wt + b0)  # hidden-layer activations
    return jnp.dot(w1, l1) + b2

I set up my arrays:

# 1-D sample points along each axis of the unit domain.
# NOTE(review): `samples` is not defined in this snippet; it must be set
# before this code runs.
x, t = jnp.linspace(0, 1, samples), jnp.linspace(0, 1, samples)

# Constant helper vectors, one entry per sample.
zeros, ones = jnp.zeros(samples), jnp.ones(samples)

# Targets: the value sin(2*pi*x) at t = 0, and zero on both spatial walls.
initial = jnp.sin(2 * jnp.pi * x)
bc = zeros

Then I took the second derivatives of my model:

# Second derivatives of the network output with respect to its inputs.
# Argument 1 of `model` is x and argument 2 is t, so differentiating twice
# with respect to the same argnum gives ∂²f/∂x² and ∂²f/∂t² respectively.
df_2dx_2, df_2dt_2 = (grad(grad(model, argnums=k), argnums=k) for k in (1, 2))

# Batched, compiled versions: params is shared (in_axis None) while x and t
# are mapped together along axis 0, so output i is evaluated at (x[i], t[i]).
model_vect, df_2dx_2_vect, df_2dt_2_vect = (
    jit(vmap(fn, in_axes=(None, 0, 0)))
    for fn in (model, df_2dx_2, df_2dt_2)
)

Then I defined the three losses: the PDE loss, the boundary loss, and the initial-condition loss:

# The PDE (what we refer to as F(x, t))
@jit
def pde(params, x, t, c=1):
    """Return the wave-equation residual ∂²f/∂t² - c² ∂²f/∂x² per point.

    ``x`` and ``t`` are mapped together along axis 0, so entry i of the
    result is the residual at the single point ``(x[i], t[i])``.
    """
    return df_2dt_2_vect(params, x, t) - \
        c ** 2 * df_2dx_2_vect(params, x, t)


def loss(params, x, t, initial=initial, zeros=zeros, bc=bc, ones=ones):
    """Return the total PINN loss (PDE residual + initial + two boundaries).

    ``x`` and ``t`` are 1-D samples of each axis. The PDE residual is
    enforced on their full tensor-product grid (``len(x) * len(t)``
    collocation points). The initial condition is enforced at ``(x, 0)``
    and the boundary conditions at ``(0, t)`` and ``(1, t)``.
    """
    # The vectorized helpers pair x[i] with t[i] (vmap in_axes=(None, 0, 0)).
    # Passing the raw axes would therefore test the PDE only at the points
    # (x[i], t[i]). Because x and t are the same linspace here, that is just
    # the diagonal line x == t. Building the grid enforces the residual over
    # the whole space-time domain.
    xx, tt = jnp.meshgrid(x, t, indexing="ij")
    eq_loss = pde(params, xx.ravel(), tt.ravel())
    initial_cond_loss = model_vect(params, x, zeros) - initial
    # NOTE(review): the wave equation is second order in t, so a well-posed
    # problem also needs an initial-velocity condition on ∂f/∂t(x, 0).
    # None is enforced here; confirm against the intended problem statement.
    bc_left_loss = model_vect(params, zeros, t) - bc
    bc_right_loss = model_vect(params, ones, t) - bc
    return (
        jnp.mean(eq_loss ** 2)
        + jnp.mean(initial_cond_loss ** 2)
        + jnp.mean(bc_left_loss ** 2)
        + jnp.mean(bc_right_loss ** 2)
    )


# Gradient of the total loss with respect to the flat parameter vector.
loss_grad = jit(grad(loss, 0))

I then trained the model:

epochs = 10_000
lr = 0.001

# Full-batch gradient descent on the combined loss.
for epoch in range(epochs):
    # Log the current loss every 100 steps, before that step's update.
    if not epoch % 100:
        print(f"Epoch: {epoch} loss: {loss(params, x, t)}")
    gradient = loss_grad(params, x, t)
    # JAX arrays are immutable, so the update rebinds `params`.
    params = params - lr * gradient

However, no matter how many epochs I train for, the model never gets anywhere close to a sine-curve shape; it always looks almost like a straight line. Is this because this partial differential equation is second order?

0

There are 0 answers