I have a family of functions parameterized by args, f(x, args), and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My first attempt was to loop through the different values of args and call a scipy.optimize minimizer at each iteration, but it takes too long. I believe the operations can be sped up with vectorization. My next attempt was to use jax.vmap inside a jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value for args.
Alternatively, I could code my own vectorized optimization method, e.g. bisection. By vectorized I mean doing operations on arrays for a fixed number of iterations, without stopping early when one of the optimization problems has already reached the error tolerance. However, I was hoping to use an already optimized, off-the-shelf algorithm if an implementation is available in JAX. This thread is related, but there the args are not changing.
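For reference, the fixed-iteration vectorized bisection described above could look roughly like the sketch below. It is a minimal illustration under stated assumptions: the objective f and its parameter p are hypothetical stand-ins for f(x, args), f is assumed unimodal in x on the starting bracket, and jax.grad is used in place of the hand-written derivative. Bisection is applied to the derivative, so each step keeps the half of the bracket where the slope changes sign.

```python
import jax
import jax.numpy as jnp

def f(x, p):
    # Hypothetical objective, convex in x, with its minimum at x = p.
    return (x - p) ** 2 + 1.0

# Derivative of f with respect to x (its first argument).
df = jax.grad(f)

def bisect_min(df, lo, hi, params, n_iter=60):
    """Run n_iter bisection steps on df(x, p) = 0 for every problem at once.

    lo, hi and params are arrays of shape (N,); each [lo, hi] is assumed to
    bracket the minimizer of a function that is unimodal on that bracket.
    """
    df_batched = jax.vmap(df)  # evaluate the derivative for all N problems

    def step(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # A positive slope means the minimum lies to the left of mid.
        go_left = df_batched(mid, params) > 0
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    # Fixed iteration count: no problem stops early, so array shapes stay static.
    lo, hi = jax.lax.fori_loop(0, n_iter, step, (lo, hi))
    return 0.5 * (lo + hi)

params = jnp.linspace(-1.0, 1.0, 1000)  # N = 1000 parameter values
x_min = bisect_min(df, jnp.full_like(params, -5.0), jnp.full_like(params, 5.0), params)
```

The cost is always n_iter derivative evaluations per problem, which is the price of never stopping early.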
You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:
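A minimal sketch, assuming a hypothetical quadratic objective f(x, p) whose minimum is at x = p (substitute your own function; p plays the role of args). jax.scipy.optimize.minimize requires x0 to be a 1-D array and computes the gradient of f itself via autodiff, so the hand-written derivative is not needed here. The per-problem function find_min takes a single scalar parameter, and jax.vmap maps it over the whole batch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def f(x, p):
    # Hypothetical objective; minimize() requires x to be a 1-D array.
    return jnp.sum((x - p) ** 2) + 1.0

def find_min(p):
    # Solve a single problem for one scalar parameter p.
    x0 = jnp.zeros(1)
    result = minimize(f, x0, args=(p,), method="BFGS")
    return result.x[0]

# vmap maps find_min over a batch of parameters; jit compiles the batched solver.
find_min_batched = jax.jit(jax.vmap(find_min))

params = jnp.linspace(-1.0, 1.0, 1000)  # N = 1000 parameter values
x_min = find_min_batched(params)
```

Under vmap, the optimizer's internal while loop keeps running until every problem in the batch has converged, and entries that have already converged are held fixed. This is close to the fixed-iteration behaviour described in the question, except that the stopping point is set by the slowest problem rather than a hard-coded count.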