Appending to list alternative workflow for JAX


I am working on a differential equation solver written in JAX. A common workflow I come across is something like this:

from functools import partial

import jax.numpy as jnp
from jax import jit

# Function to integrate.
@jit
def dxdt(t, x):
    return -x**2

# Euler method for simplicity. f is a Python function, so it must be marked
# static: passing a callable as a regular argument to a jitted function fails.
@partial(jit, static_argnums=0)
def integrator(f, t, x, dt):
    return x + f(t, x) * dt

t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]

x_list = []

# initialize x.
x = 0.

for t in t_arr:
    x_list.append(x)
    x = integrator(dxdt, t, x, dt)

x_arr = jnp.array(x_list)

My question is whether there is a way to 'vectorize' that for-loop using JAX.

I recognize that jax.vmap() would not be appropriate here, since the variable x changes in each iteration of the for-loop. Is there a more JAX-friendly approach to this workflow?

Answer by jakevdp:

This sort of sequential operation, where each step is dependent on the last, is supported in JAX via jax.lax.scan. Here's how you might do the equivalent of your computation using scan:

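import jax

def scan_body(carry, t):
  # carry is the state passed from one step to the next.
  x, dt = carry
  new_x = integrator(dxdt, t, x, dt)
  # Return (next carry, per-step output). Emitting x before the update
  # mirrors x_list.append(x); scan stacks these outputs into x_arr.
  return (new_x, dt), x

_, x_arr = jax.lax.scan(scan_body, (0., dt), t_arr)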
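Here is a minimal self-contained sketch that puts the pieces together and checks the scan version against the original Python loop. A few choices here are assumptions, not part of the answer above: `solve` and `solve_loop` are hypothetical helper names, `dt` is closed over rather than carried, the whole solve is wrapped in `jax.jit`, and the initial condition is `1.0` instead of `0.` (with `x = 0.` the solution of dx/dt = -x**2 stays at zero, so the comparison would be trivial).

```python
from functools import partial

import jax
import jax.numpy as jnp


def dxdt(t, x):
    # Right-hand side of dx/dt = -x**2.
    return -x**2


# f is a Python callable, so it is marked static for jit.
@partial(jax.jit, static_argnums=0)
def integrator(f, t, x, dt):
    # One explicit Euler step.
    return x + f(t, x) * dt


@jax.jit
def solve(x0, t_arr):
    # Hypothetical helper: the scan-based solver, compiled as a whole.
    dt = t_arr[1] - t_arr[0]

    def scan_body(x, t):
        # New carry is the next state; the stacked output is the state
        # *before* the step, matching x_list.append(x) in the loop.
        return integrator(dxdt, t, x, dt), x

    _, xs = jax.lax.scan(scan_body, x0, t_arr)
    return xs


def solve_loop(x0, t_arr):
    # Hypothetical helper: the original Python loop with list appends.
    dt = t_arr[1] - t_arr[0]
    x = x0
    x_list = []
    for t in t_arr:
        x_list.append(x)
        x = integrator(dxdt, t, x, dt)
    return jnp.array(x_list)


t_arr = jnp.linspace(0, 10, 100)
x0 = jnp.float32(1.0)

x_scan = solve(x0, t_arr)
x_loop = solve_loop(x0, t_arr)
print(x_scan.shape)                   # (100,)
print(jnp.allclose(x_scan, x_loop))  # True
```

Closing over `dt` instead of carrying it is only a stylistic choice; carrying it, as in the answer, works just as well. If you want a more literal replacement for list appends, `jax.lax.fori_loop` with a preallocated array updated via `x_arr.at[i].set(x)` is another option, but `scan` fits this pattern more directly because it stacks the per-step outputs for you.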