Using JAX and Vectorizing a Function

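Here is my code. M_fn is an inverse CDF (quantile function) that maps a probability y between 0 and 1 to a value, and find_y inverts it numerically: given a value M, it returns the probability y for which M(y) = M.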
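Written out, ignoring the small eps the code adds to 1 - y, the function, the root-finding problem that find_y hands to the bisection solver, and a generic bisection step look like this. The step assumes M(y) is increasing on the bracket (true when a2 > 0 and a3 = 0). The last line is the exact inverse for the special case a3 = 0, obtained by solving a1 + a2 ln(y/(1-y)) = M for y, which makes a convenient check.

```latex
\begin{aligned}
M(y) &= a_1 + a_2 \ln\frac{y}{1-y} + a_3\left(y - \tfrac{1}{2}\right)\ln\frac{y}{1-y}, \qquad 0 < y < 1,\\
f(y) &= M(y) - M = 0, \qquad y \in \left[10^{-7},\ 1 - 10^{-7}\right],\\
m &= \tfrac{1}{2}(\ell + h), \qquad [\ell, h] \leftarrow
  \begin{cases} [m, h] & \text{if } f(m) < 0,\\ [\ell, m] & \text{otherwise,} \end{cases}\\
a_3 = 0:\quad y &= \frac{1}{1 + e^{-(M - a_1)/a_2}}.
\end{aligned}
```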

import jax.numpy as jnp
from jax import jit, vmap, lax
from jaxopt import Bisection

def find_y(M, a1, a2, a3):
    """Find the value of y that corresponds to a given value of M(y),
    using the bisection method implemented with JAX.

    Args:
        M (float): The desired value of M(y).
        a1 (float): The value of coefficient a1.
        a2 (float): The value of coefficient a2.
        a3 (float): The value of coefficient a3.

    Returns:
        float: The value of y that corresponds to the given value of M(y).
    """
    # Define a function that returns the value of M(y) for a given y
    def M_fn(y):
        eps = 1e-8  # A small epsilon to avoid dividing by zero as y approaches 1
        return a1 + a2 * jnp.log(y / (1 - y + eps)) + a3 * (y - 0.5) * jnp.log(y / (1 - y + eps))

    # Define a function that returns the difference between M(y) and M
    def f(y):
        return M_fn(y) - M

    # Set the bracketing interval for the root-finding function
    interval = (1e-7, 1 - 1e-7)

    # Use the bisection function to find the root
    y = Bisection(f, *interval).run().params

    # Return the value of y
    return y

## test the algorithm
a1 = 16.
a2 = 3.396
a3 = 0.0

y = find_y(16, a1, a2, a3)
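Because the test case uses a3 = 0, M(y) reduces to a1 + a2 log(y / (1 - y)), whose inverse is the logistic sigmoid. A minimal sketch of that check follows; `find_y_closed_form` is a hypothetical helper name, it ignores the eps term, and it is only valid when a3 = 0. Since it uses only elementwise jnp operations, it accepts an array of M values directly.

```python
import jax.numpy as jnp
from jax.nn import sigmoid


def find_y_closed_form(M, a1, a2):
    """Hypothetical helper: exact inverse of M(y) = a1 + a2 * log(y / (1 - y)).

    Only valid when a3 == 0; ignores the small eps used in the question's M_fn.
    """
    return sigmoid((jnp.asarray(M, dtype=jnp.float32) - a1) / a2)


# Elementwise, so an array of M values works without vmap
y_check = find_y_closed_form(jnp.array([10.0, 16.0, 22.0]), 16.0, 3.396)
print(y_check)  # middle entry is 0.5, since M = a1 gives log(y / (1 - y)) = 0
```

Comparing `find_y(16, a1, a2, a3)` against `y_check[1]` gives a quick sanity check of the bisection result.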

I would like to pass an array for the argument M instead of a scalar, but no matter what I try, I get an error (usually something about a traced Boolean value). Any ideas? Thanks!
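One likely culprit, hedged since the exact error text isn't shown: by default jaxopt's `Bisection` checks that f changes sign over the bracket, and the jaxopt documentation notes the method cannot be jitted while `check_bracket=True`. Constructing it as `Bisection(f, *interval, check_bracket=False)` and then wrapping `find_y` in `vmap` may be enough. If you would rather avoid that dependency, below is a minimal sketch of a hand-written bisection in plain JAX. It assumes M(y) is increasing on the bracket (true for a2 > 0 and a3 = 0, as in the test), it runs a fixed number of iterations instead of checking a tolerance, and `find_y_bisect` and `n_iter` are hypothetical names. Every branch decision uses `jnp.where` rather than a Python `if`, so nothing forces a traced value into a Python bool.

```python
import jax
import jax.numpy as jnp
from jax import lax


def M_fn(y, a1, a2, a3):
    # Same quantile function as in the question, including the eps guard
    eps = 1e-8
    logit = jnp.log(y / (1 - y + eps))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


def find_y_bisect(M, a1, a2, a3, lo=1e-7, hi=1 - 1e-7, n_iter=60):
    """Hypothetical helper: solve M_fn(y) = M by bisection on [lo, hi].

    Assumes M_fn is increasing in y on the bracket. Uses a fixed iteration
    count and jnp.where instead of Python `if`, so it traces under jit/vmap.
    """
    M = jnp.asarray(M, dtype=jnp.float32)
    # Start the bracket with the same shape/dtype as M so the loop carry is stable
    init = (jnp.full_like(M, lo), jnp.full_like(M, hi))

    def body(_, bounds):
        left, right = bounds
        mid = 0.5 * (left + right)
        # If M_fn(mid) < M the root lies in [mid, right]; otherwise in [left, mid]
        go_right = M_fn(mid, a1, a2, a3) < M
        return jnp.where(go_right, mid, left), jnp.where(go_right, right, mid)

    left, right = lax.fori_loop(0, n_iter, body, init)
    return 0.5 * (left + right)


# Batch over M only; the coefficients are shared across the batch (in_axes=None)
find_y_vec = jax.jit(jax.vmap(find_y_bisect, in_axes=(0, None, None, None)))

Ms = jnp.array([10.0, 16.0, 22.0])
ys = find_y_vec(Ms, 16.0, 3.396, 0.0)
print(ys)  # middle entry is close to 0.5
```

Under `vmap`, `M` inside `find_y_bisect` is a scalar tracer, so the same function also works when called directly with a single scalar M. A fixed iteration count keeps the loop shape static, which is what lets `jit` and `vmap` handle it without data-dependent Python control flow.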