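Here is my code. `M_fn` is an inverse CDF (quantile function): it returns a value M(y) for a probability y between 0 and 1. `find_y` inverts it with bisection, returning the probability y that corresponds to a given value M.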
```python
import jax.numpy as jnp
from jax import jit, vmap, lax
from jaxopt import Bisection


def find_y(M, a1, a2, a3):
    """
    Finds the value of y that corresponds to a given value of M(y), using the bisection method implemented with JAX.

    Parameters:
    M (float): The desired value of M(y).
    a1 (float): The value of coefficient a1.
    a2 (float): The value of coefficient a2.
    a3 (float): The value of coefficient a3.

    Returns:
    float: The value of y that corresponds to the given value of M(y).
    """
    # Define a function that returns the value of M(y) for a given y
    @jit
    def M_fn(y):
        eps = 1e-8  # A small epsilon to avoid dividing by zero when y is 1
        return a1 + a2 * jnp.log(y / (1 - y + eps)) + a3 * (y - 0.5) * jnp.log(y / (1 - y + eps))

    # Define a function that returns the difference between M(y) and M
    @jit
    def f(y):
        return M_fn(y) - M

    # Set the bracketing interval for the root-finding function
    interval = (1e-7, 1 - 1e-7)
    # Use the bisection function to find the root
    y = Bisection(f, *interval).run().params
    # Return the value of y
    return y


# Test the algorithm
a1 = 16.
a2 = 3.396
a3 = 0.0
y = find_y(16, a1, a2, a3)
print(y)
```
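As a sanity check for the test case: with a3 = 0 the relation can be inverted in closed form, so the bisection result can be compared against the exact value (the small `eps` in the code is ignored here).

$$M(y) = a_1 + a_2 \ln\frac{y}{1-y} + a_3\left(y - \tfrac{1}{2}\right)\ln\frac{y}{1-y}$$

$$a_3 = 0:\quad \ln\frac{y}{1-y} = \frac{M - a_1}{a_2} \;\Longrightarrow\; y = \frac{1}{1 + e^{-(M - a_1)/a_2}}$$

With M = a1 = 16 the exponent is zero, so the printed value should be close to y = 0.5.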
I would like to pass an array for the argument M instead of a scalar, but no matter what I try, I get an error (usually something about a Boolean trace). Any ideas? Thanks!
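One likely culprit (an assumption, not verified against this exact setup) is a Python `if` evaluated on a traced value: under `vmap`, `M` becomes a tracer, and any Boolean test that depends on it cannot collapse to a single `True`/`False`. jaxopt's `Bisection` checks the bracket this way by default, so passing `check_bracket=False` to it may already be enough. As an alternative that avoids jaxopt, here is a minimal sketch of a hand-written bisection that only uses `jnp.where` and `lax.fori_loop`, so it can be batched with `vmap`. It assumes M(y) is increasing in y on the bracket (true for a2 > 0 and a3 = 0, as in the test case); the names `find_y_batched` and `n_iter` and the choice of 60 iterations are illustrative, not from the original code.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap


def M_fn(y, a1, a2, a3):
    # Same M(y) as in the question; log(y) - log1p(-y) equals log(y / (1 - y))
    logit = jnp.log(y) - jnp.log1p(-y)
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


def find_y(M, a1, a2, a3, n_iter=60):
    """Bisection for M_fn(y) = M with no Python `if` on traced values.

    Assumes M_fn is increasing in y on [1e-7, 1 - 1e-7].
    """
    M = jnp.asarray(M, dtype=jnp.float32)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # If M(mid) is still below the target, the root lies in [mid, hi]
        go_right = M_fn(mid, a1, a2, a3) < M
        # jnp.where selects elementwise instead of branching in Python
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    init = (jnp.float32(1e-7), jnp.float32(1 - 1e-7))
    lo, hi = lax.fori_loop(0, n_iter, body, init)
    return 0.5 * (lo + hi)


# Map over the first argument (M) only; the coefficients stay scalars
find_y_batched = jax.jit(vmap(find_y, in_axes=(0, None, None, None)))

Ms = jnp.array([10.0, 16.0, 20.0])
print(find_y_batched(Ms, 16.0, 3.396, 0.0))
```

Because the loop runs a fixed number of iterations, every element of the batch does the same work, which is what lets `vmap` batch it; 60 halvings of this bracket are more than enough to reach float32 resolution.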