Vectorized minimization and root finding in JAX


I have a family of functions parameterized by args

f(x, args)

and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My first attempt was to loop over the values of args and call a scipy.optimize routine at each iteration, but that takes too long. I believe the computation can be sped up with vectorization. My next attempt was to use jax.vmap inside jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value for args.

Alternatively, I can code my own vectorized optimization method, e.g. bisection. By vectorized I mean operating on arrays for a fixed number of iterations, without stopping early when one of the optimization problems reaches its error tolerance before the others.

I was hoping to use an already optimized, off-the-shelf algorithm if an implementation is available in JAX. This thread is related, but there the args are not changing.
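For reference, the fixed-iteration bisection described above could be sketched roughly as follows. This is only an illustration under stated assumptions: the quadratic f, the bracket [-10, 10], the iteration count, and the helper name bisect_min are all hypothetical, and the sketch assumes the derivative changes sign exactly once inside the bracket for every value of args.

```python
import jax
import jax.numpy as jnp

def f(x, args):
    # Hypothetical example family: a parabola with its minimum at x = b.
    a, b = args
    return a + (x - b) ** 2

# Derivative of f with respect to x (scalar in, scalar out).
df = jax.grad(f)

def bisect_min(args, lo, hi, num_iters=50):
    # Assumes df(lo, args) < 0 < df(hi, args), so one minimum is bracketed.
    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        go_right = df(mid, args) < 0  # negative slope: minimum is right of mid
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
        return lo, hi

    # Fixed number of iterations, no early stopping.
    lo, hi = jax.lax.fori_loop(0, num_iters, body, (lo, hi))
    return 0.5 * (lo + hi)

# One bracket per problem; vmap maps over the leading axis of every input.
N = 1000
a_vals = jnp.linspace(0.0, 1.0, N)
b_vals = jnp.linspace(-3.0, 3.0, N)
lo0 = jnp.full(N, -10.0)
hi0 = jnp.full(N, 10.0)

x_min = jax.vmap(bisect_min)((a_vals, b_vals), lo0, hi0)
```

Because every problem runs the same number of iterations, the loop has a static shape and vectorizes cleanly; the cost is that easy problems do as much work as hard ones.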


There is 1 answer

jakevdp (best answer)

You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:

import jax
import jax.numpy as jnp
from jax.scipy import optimize

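def f(x, args):
  # Objective: for fixed (a, b) this is minimized at x = b.
  a, b = args
  return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
  # Solve one minimization problem for a single (a, b) pair.
  x0 = jnp.array([1.0])
  args = (a, b)
  # The extra arguments are passed as a 1-tuple so f receives args as one object.
  return optimize.minimize(f, x0, (args,), method="BFGS")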

a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]
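As a rough sketch of how the same pattern might scale to the N = 1000 case from the question, the vmapped solver can also be wrapped in jax.jit so the batched problem is compiled once. The random parameter values, the name batched_find_min, and the error check against b are assumptions made for illustration, not part of the example above.

```python
import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
    # Same objective as above: minimized at x = b.
    a, b = args
    return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
    x0 = jnp.array([1.0])
    return optimize.minimize(f, x0, ((a, b),), method="BFGS")

# Compile the batched solver once; calls with the same shapes reuse it.
batched_find_min = jax.jit(jax.vmap(find_min))

# Hypothetical parameter values, one (a, b) pair per problem.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a_vals = jax.random.uniform(key_a, (1000,))
b_vals = jax.random.uniform(key_b, (1000,), minval=-3.0, maxval=3.0)

results = batched_find_min(a_vals, b_vals)
x_min = results.x[:, 0]  # results.x has shape (1000, 1)

# For this objective the exact minimizer is b, so the error can be checked.
max_err = jnp.max(jnp.abs(x_min - b_vals))
```

Note that jax.scipy.optimize.minimize differentiates f automatically and takes no separate derivative argument, so a hand-written derivative is not needed here.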