I am working on a differential equation solver written in JAX. A common workflow I come across is something like this:
```python
from functools import partial

import jax.numpy as jnp
from jax import jit

# Function to integrate.
@jit
def dxdt(t, x):
    return -x**2

# Euler method for simplicity.
# f is a Python function, not an array, so jit must treat it as static.
@partial(jit, static_argnums=0)
def integrator(f, t, x, dt):
    return x + f(t, x) * dt

t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]
x_list = []
# Initialize x.
x = 0.
for t in t_arr:
    x_list.append(x)
    x = integrator(dxdt, t, x, dt)
x_arr = jnp.array(x_list)
```
My question is whether there is a way to 'vectorize' that for-loop using JAX. I recognize that `jax.vmap()` would not be appropriate here, since the variable `x` is updated in each iteration of the loop. Is there a more JAX-friendly approach to this workflow?
This sort of sequential operation, where each step depends on the previous one, is supported in JAX via `jax.lax.scan`. Here's how you might do the equivalent of your computation using `scan`: