Using JAX and Vectorizing a Function


Here is my code. M(y) is an inverse CDF (quantile function) that maps a probability y between 0 and 1 to a value; find_y inverts it numerically, returning the probability y that corresponds to a given value M.

import jax.numpy as jnp
from jax import jit, vmap, lax
from jaxopt import Bisection

def find_y(M, a1, a2, a3):
    """
    Finds the value of y that corresponds to a given value of M(y), using the bisection method implemented with JAX.

    Parameters:
    M (float): The desired value of M(y).
    a1 (float): The value of coefficient a1.
    a2 (float): The value of coefficient a2.
    a3 (float): The value of coefficient a3.

    Returns:
    float: The value of y that corresponds to the given value of M(y).
    """
    # Define a function that returns the value of M(y) for a given y
    @jit
    def M_fn(y):
        eps = 1e-8  # Small epsilon to avoid dividing by zero as y approaches 1
        return a1 + a2 * jnp.log(y / (1 - y + eps)) + a3 * (y - 0.5) * jnp.log(y / (1 - y + eps))

    # Define a function that returns the difference between M(y) and M
    @jit
    def f(y):
        return M_fn(y) - M

    # Set the bracketing interval for the root-finding function
    interval = (1e-7, 1 - 1e-7)

    # Use the bisection function to find the root
    y = Bisection(f, *interval).run().params

    # Return the value of y
    return y

## test the algorithm
a1 = 16.
a2 = 3.396
a3 = 0.0

y = find_y(16, a1, a2, a3)
print(y)
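For the test values above (a3 = 0), the only y-dependence is the logit term, so M(y) = a1 + a2 * log(y / (1 - y)) can be inverted exactly as y = sigmoid((M - a1) / a2). A small sketch of that closed form gives a reference to check find_y against, including for array-valued M. The helper name find_y_closed_form is hypothetical and only valid when a3 is 0:

```python
import jax
import jax.numpy as jnp

def find_y_closed_form(M, a1, a2):
    # Hypothetical helper, valid only when a3 == 0:
    # M = a1 + a2 * log(y / (1 - y))  =>  y = sigmoid((M - a1) / a2)
    return jax.nn.sigmoid((M - a1) / a2)

print(find_y_closed_form(16., 16., 3.396))  # 0.5, since M == a1

# Works elementwise on arrays, so it also serves as a reference for vectorized M
M = jnp.array([4., 8., 16., 32.])
print(find_y_closed_form(M, 16., 3.396))
```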

I would like to pass an array for the argument M instead of a scalar, but no matter what I try, I get an error (usually about converting a traced value to a Boolean). Any ideas? Thanks!!
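That kind of error can be reproduced without jaxopt. Under jax.vmap (and jax.jit) a function receives tracers rather than concrete numbers, so a Python if on a comparison involving them cannot be decided. A minimal sketch with two hypothetical functions, abs_python and abs_traced, contrasts that with jnp.where, which selects elementwise and works on traced values:

```python
import jax
import jax.numpy as jnp

def abs_python(x):
    # A Python `if` needs a concrete bool; under vmap, x is a tracer
    if x > 0:
        return x
    return -x

def abs_traced(x):
    # jnp.where selects elementwise and works on traced values
    return jnp.where(x > 0, x, -x)

xs = jnp.array([-2.0, 3.0])

try:
    jax.vmap(abs_python)(xs)
    python_if_failed = False
except Exception as err:  # TracerBoolConversionError in recent JAX versions
    python_if_failed = True
    print(type(err).__name__)

print(jax.vmap(abs_traced)(xs))  # [2. 3.]
```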

Accepted answer, by jakevdp:

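You can do this with jax.vmap, as long as you set check_bracket=False in Bisection (see here). With the default check_bracket=True, Bisection checks whether the interval brackets a root using a Python if statement on the function values at the endpoints; under vmap those values are tracers, so the condition cannot be evaluated, which is the Boolean-trace error you are seeing: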

y = Bisection(f, *interval, check_bracket=False).run().params

With that change to your function, you can pass a vector of values for M like this:

import jax
M = jnp.array([4, 8, 16, 32])
result = jax.vmap(find_y, in_axes=(0, None, None, None))(M, a1, a2, a3)
print(result)
[0.02837202 0.08661279 0.5        0.99108815]
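If you would rather not rely on jaxopt at all, bisection is short enough to write directly in JAX. The sketch below swaps in a hand-rolled loop (lax.fori_loop with a fixed number of halvings) for jaxopt.Bisection; find_y_bisect, num_iters and the default bounds are illustrative assumptions, and it assumes M(y) is increasing on the interval (true when a2 > 0 and a3 = 0). Because every step uses jnp.where instead of a Python if, it runs under jax.vmap, and since the bounds are created with the shape of M it also accepts an array for M directly:

```python
import jax
import jax.numpy as jnp
from jax import lax

def M_fn(y, a1, a2, a3):
    # Same M(y) as in the question, without the epsilon (y stays inside (0, 1))
    logit = jnp.log(y / (1 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit

def find_y_bisect(M, a1, a2, a3, lower=1e-7, upper=1 - 1e-7, num_iters=60):
    # Hypothetical pure-JAX bisection (not jaxopt); assumes M(y) is increasing
    # on [lower, upper], e.g. a2 > 0 and a3 == 0.
    M = jnp.asarray(M, dtype=jnp.float32)
    lo0 = jnp.full_like(M, lower)  # bounds take M's shape, so array input also works
    hi0 = jnp.full_like(M, upper)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # M(mid) below the target: the root is in the upper half, so move lo up
        go_up = M_fn(mid, a1, a2, a3) < M
        return jnp.where(go_up, mid, lo), jnp.where(go_up, hi, mid)

    # A fixed iteration count and jnp.where keep every step free of Python branching
    lo, hi = lax.fori_loop(0, num_iters, body, (lo0, hi0))
    return 0.5 * (lo + hi)

M = jnp.array([4., 8., 16., 32.])
batched = jax.vmap(find_y_bisect, in_axes=(0, None, None, None))(M, 16., 3.396, 0.0)
print(batched)
print(find_y_bisect(M, 16., 3.396, 0.0))  # direct array input, no vmap
```

Unlike jaxopt.Bisection, this version performs no bracket check and no early stopping; num_iters=60 is simply comfortably more halvings than float32 can resolve on this interval.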