(Note: images are for equations as SO doesn't support LaTeX)
I've been trying to solve the 1D wave equation with a PINN: $\frac{\partial^2 f}{\partial t^2} - c^2 \frac{\partial^2 f}{\partial x^2} = 0$
With the initial and boundary conditions (as enforced by the loss below): $f(x, 0) = \sin(2\pi x)$ and $f(0, t) = f(1, t) = 0$
For my neural network, I chose a very basic model: the inputs (x, t) feed a single hidden layer of 100 units, followed by a scalar output layer. This model is shown below:
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
# Fixed seed so the random initialisation is reproducible between runs.
key = jax.random.PRNGKey(0)
# Flat parameter vector drawn from a standard normal distribution:
# four groups of 100 values plus one trailing scalar (sliced apart in `model`).
params = jax.random.normal(key, shape=(4 * 100 + 1,))
def softplus(x):
    """Return the softplus activation ``log(1 + exp(x))`` of ``x``.

    The value is computed as ``logaddexp(x, 0)`` instead of the literal
    formula: ``exp(x)`` overflows to ``inf`` for large positive ``x``, which
    made the naive expression return ``inf`` (and poisoned its gradients),
    while this form stays finite and is also more precise for negative ``x``.

    Args:
        x: A scalar or array of pre-activation values.

    Returns:
        The element-wise softplus of ``x``.
    """
    return jnp.logaddexp(x, 0.0)
def model(params, x, t, activation=softplus):
    """Evaluate a one-hidden-layer network at the point ``(x, t)``.

    ``params`` is a flat vector that is sliced into four consecutive groups
    of 100 entries (weights applied to ``x``, weights applied to ``t``,
    hidden biases, output weights) followed by a single output bias at
    index 400.

    Args:
        params: Flat parameter vector, indexed up to position 400.
        x: Spatial input; multiplied against the 100 ``x`` weights.
        t: Temporal input; multiplied against the 100 ``t`` weights.
        activation: Non-linearity applied to the hidden pre-activations.

    Returns:
        The dot product of the output weights with the activated hidden
        layer, plus the output bias.
    """
    hidden = 100
    # Four equally sized, back-to-back slices of the flat parameter vector.
    w_x, w_t, bias_hidden, w_out = (
        params[i * hidden:(i + 1) * hidden] for i in range(4)
    )
    bias_out = params[4 * hidden]

    pre_activation = x * w_x + t * w_t + bias_hidden
    return jnp.dot(w_out, activation(pre_activation)) + bias_out
I set up my arrays:
# NOTE(review): `samples` is not defined in the visible part of this script;
# it must be bound to the desired number of sample points before this runs.

# Evenly spaced sample points on [0, 1] for space and for time. Both arrays
# have the same length because the vmapped functions in this file map over
# x and t element by element.
x = jnp.linspace(0, 1, num=samples)
t = jnp.linspace(0, 1, num=samples)

# Constant arrays used to pin one coordinate (t = 0, x = 0, x = 1).
zeros = jnp.zeros(samples)
ones = jnp.ones(samples)

# Target values: a sine profile for the initial condition and zero on the
# boundaries (`bc` is the same array object as `zeros`).
initial = jnp.sin(2 * jnp.pi * x)
bc = zeros
Then I took the second derivatives of my model:
# Second partial derivatives of the network output, obtained by applying
# `grad` twice to the same positional argument of `model`
# (argument 1 is x, argument 2 is t).
df_2dx_2 = grad(grad(model, argnums=1), argnums=1)  # ∂^2f / ∂x^2
df_2dt_2 = grad(grad(model, argnums=2), argnums=2)  # ∂^2f / ∂t^2

# Vectorized, compiled versions: `params` is shared (axis None) while x and t
# are mapped together along their leading axis.
model_vect = jit(vmap(model, in_axes=(None, 0, 0)))
df_2dx_2_vect = jit(vmap(df_2dx_2, in_axes=(None, 0, 0)))
df_2dt_2_vect = jit(vmap(df_2dt_2, in_axes=(None, 0, 0)))
And set the 3 losses - the PDE loss, the boundary loss, and the initial loss:
# The PDE residual (what we refer to as F(x, t))
@jit
def pde(params, x, t, c=1):
    """Return the wave-equation residual at each paired point ``(x[i], t[i])``.

    Computes ``∂^2f/∂t^2 - c^2 * ∂^2f/∂x^2`` using the vectorized second
    derivatives of the model.

    Args:
        params: Model parameters, passed through to the derivative functions.
        x: Array of spatial coordinates.
        t: Array of temporal coordinates, mapped together with ``x``.
        c: Wave speed; it enters the residual squared.

    Returns:
        The residual evaluated at every mapped point.
    """
    d2f_dt2 = df_2dt_2_vect(params, x, t)
    d2f_dx2 = df_2dx_2_vect(params, x, t)
    return d2f_dt2 - c ** 2 * d2f_dx2
def loss(params, x, t, initial=initial, zeros=zeros, bc=bc, ones=ones):
    """Return the total PINN loss: PDE residual + initial + boundary terms.

    Each term is a mean squared error and the four terms are summed with
    equal weight.

    The PDE residual is evaluated on the full tensor-product grid of ``x``
    and ``t``. The vectorized functions map over their inputs element by
    element, so passing ``x`` and ``t`` directly would only sample the
    residual at the pairs ``(x[i], t[i])``; with identical grids for both
    axes that is just the diagonal ``x == t`` of the domain, which leaves
    almost all of the interior unconstrained.

    Args:
        params: Model parameters.
        x: 1-D array of spatial sample points.
        t: 1-D array of temporal sample points.
        initial: Target values of the model at ``t = 0`` for each ``x``.
        zeros: Array of zeros, used as the ``t = 0`` and ``x = 0`` coordinate.
        bc: Target values of the model on both spatial boundaries.
        ones: Array of ones, used as the ``x = 1`` coordinate.

    Returns:
        Scalar loss value.
    """
    # Pair every x with every t so the residual covers the whole domain.
    # The number of residual points is therefore len(x) * len(t).
    grid_x, grid_t = jnp.meshgrid(x, t)
    eq_loss = pde(params, grid_x.ravel(), grid_t.ravel())

    # NOTE(review): the wave equation is second order in time, so it usually
    # also needs a condition on ∂f/∂t at t = 0. No such term is enforced
    # here; confirm against the problem statement.
    initial_cond_loss = model_vect(params, x, zeros) - initial
    bc_left_loss = model_vect(params, zeros, t) - bc
    bc_right_loss = model_vect(params, ones, t) - bc

    return (
        jnp.mean(eq_loss ** 2)
        + jnp.mean(initial_cond_loss ** 2)
        + jnp.mean(bc_left_loss ** 2)
        + jnp.mean(bc_right_loss ** 2)
    )
# Compiled gradient of `loss` with respect to its first argument (`params`).
loss_grad = jit(grad(loss, argnums=0))
I then trained the model:
# Plain gradient descent with a fixed step size.
epochs = 10000
lr = 1e-3

for epoch in range(epochs):
    # Report the loss every 100 steps, before that step's update is applied.
    if not epoch % 100:
        print(f"Epoch: {epoch} loss: {loss(params, x, t)}")

    gradient = loss_grad(params, x, t)
    params = params - lr * gradient
However, the model never manages to fit anywhere close to a sine curve shape regardless of the number of epochs used - it always looks like nearly a straight line. Is this due to the fact that this partial differential equation is a second-order one?