# Using JAX and Vectorizing a Function


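Here is my code. `M_fn` is an inverse CDF (quantile function) that maps a probability `y` between 0 and 1 to a value, and `find_y` inverts it numerically: given a value `M`, it returns the probability `y` for which M(y) = M.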

```python
import jax.numpy as jnp
from jax import jit, vmap, lax
from jaxopt import Bisection

def find_y(M, a1, a2, a3):
    """
    Finds the value of y that corresponds to a given value of M(y), using the bisection method implemented with JAX.

    Parameters:
    M (float): The desired value of M(y).
    a1 (float): The value of coefficient a1.
    a2 (float): The value of coefficient a2.
    a3 (float): The value of coefficient a3.

    Returns:
    float: The value of y that corresponds to the given value of M(y).
    """
    # Define a function that returns the value of M(y) for a given y
    @jit
    def M_fn(y):
        eps = 1e-8  # A small epsilon to avoid dividing by zero when y = 1
        return a1 + a2 * jnp.log(y / (1 - y + eps)) + a3 * (y - 0.5) * jnp.log(y / (1 - y + eps))

    # Define a function that returns the difference between M(y) and M
    @jit
    def f(y):
        return M_fn(y) - M

    # Set the bracketing interval for the root-finding function
    interval = (1e-7, 1 - 1e-7)

    # Use the bisection function to find the root
    y = Bisection(f, *interval).run().params

    # Return the value of y
    return y

## test the algorithm
a1 = 16.
a2 = 3.396
a3 = 0.0

y = find_y(16, a1, a2, a3)
print(y)
```
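For reference, ignoring the small `eps` in the denominator (it only matters as y approaches 1), `M_fn` computes

$$M(y) = a_1 + a_2 \ln\frac{y}{1-y} + a_3\left(y - \tfrac{1}{2}\right)\ln\frac{y}{1-y},$$

and `find_y` solves $M(y) - M = 0$ for $y \in [10^{-7},\, 1 - 10^{-7}]$. In the test case $a_3 = 0$ (and $a_2 > 0$, so $M(y)$ is strictly increasing and the root is unique), which means the equation can also be inverted by hand:

$$\ln\frac{y}{1-y} = \frac{M - a_1}{a_2} \quad\Longrightarrow\quad y = \frac{1}{1 + e^{-(M - a_1)/a_2}}.$$

With $M = a_1 = 16$ the exact root is $y = 1/(1 + e^{0}) = 0.5$, which is a convenient value to compare the bisection output against, including when `M` is an array.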

I would like to pass an array for the argument `M` instead of a scalar, but no matter what I try, I get an error (usually something about converting a traced value to a Boolean). Any ideas? Thanks!