Bayesian inference with NumPyro, file I/O, and external model calls

I'm using NumPyro to run Bayesian inference over a simulator, so I need to perform file I/O to call this external simulation. Specifically, my hierarchical model includes a function (propagate) that writes the parameters to a file and then runs an external simulation. However, I'm running into problems with JAX tracer objects when passing the sampled parameters to this function.

Here's a simplified version of my function:

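import os
import subprocess

import jax
from jinja2 import Environment, FileSystemLoader

# sim_path, input_template and get_hist_output are defined elsewhere in my script.

def propagate(param01, param02, param03, qoi='pressure'):

    # Trying to treat params as constants
    param01 = jax.lax.stop_gradient(param01)
    param02 = jax.lax.stop_gradient(param02)
    param03 = jax.lax.stop_gradient(param03)

    render_dict = {
        'param01': param01,
        'param02': param02,
        'param03': param03,
        'output_name': "delete"
    }

    # Render the simulator input file from the Jinja2 template
    env = Environment(loader=FileSystemLoader(sim_path))
    template = env.get_template(input_template)

    input_sample = template.render(render_dict)
    with open(os.path.join(sim_path, 'INPUT'), 'w') as f:
        f.write(input_sample)

    # Run the Windows simulator under Wine inside the simulation directory
    process = subprocess.run(['wine', 'simulator.exe'], cwd=sim_path,
                             capture_output=True, text=True)
    if process.returncode != 0:
        print(process.stderr)

    # Read the simulator output and keep the first 239 values of the QoI
    output_file = os.path.join(sim_path, 'delete.dat')
    if not os.path.exists(output_file):
        raise FileNotFoundError("Expected output file not found.")
    df_hist = get_hist_output(output_file)
    returned_var = df_hist[qoi].values[:239]

    return returned_var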
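Because propagate runs plain Python (Jinja2, open, subprocess) on values that, inside the sampler, are JAX tracers, one commonly used bridge is jax.pure_callback: it hands concrete NumPy arrays to a Python function and tells JAX only the shape and dtype of the result. The sketch below is an illustration under stated assumptions, not the code from this post: run_simulator_np is a hypothetical stand-in that evaluates a quadratic instead of calling the real simulator, and N_OUT = 239 mirrors the [:239] slice above.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_OUT = 239  # length of the simulated series kept by propagate()

def run_simulator_np(p1, p2, p3):
    # Hypothetical stand-in for propagate(): it receives concrete NumPy
    # values, so rendering a template, writing INPUT and calling
    # subprocess.run would all be legal at this point.
    t = np.linspace(0.0, 1.0, N_OUT)
    out = float(p1) + float(p2) * t + float(p3) * t**2
    return out.astype(np.float32)

def propagate_traced(p1, p2, p3):
    # JAX cannot trace into the callback, so declare the result shape/dtype.
    out_spec = jax.ShapeDtypeStruct((N_OUT,), jnp.float32)
    return jax.pure_callback(run_simulator_np, out_spec, p1, p2, p3)

# Works under jit, where p1, p2, p3 are tracers outside the callback.
result = jax.jit(propagate_traced)(1.0, 2.0, 3.0)
print(result.shape)
```

JAX treats the callback as pure, so it may skip or repeat calls; a simulator that is deterministic for given inputs fits that assumption, but parallel calls sharing one INPUT file would not. pure_callback also has no derivative rule by default, so differentiating through propagate_traced (which NUTS needs) would still fail.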

In the NumPyro model, propagate is called like this:

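import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# dataset, rv_inputs (name -> (low, high)) and qoi are defined elsewhere in my script.

def model_numpyro(dataset, rv_inputs):
    # Uniform priors on the simulator parameters
    param_dict = {}
    for k, v in rv_inputs.items():
        param_dict[k] = numpyro.sample(k, dist.Uniform(v[0], v[1]))

    # Simulated quantity of interest for the sampled parameters
    prod_m = propagate(**param_dict)

    # Observation noise and likelihood of the measured data
    prod_s = numpyro.sample('s_mua', dist.HalfNormal(1e2))
    numpyro.sample('mua', dist.Normal(prod_m, prod_s), obs=dataset[qoi].values[:239])

sample_size = 100

nuts_kernel = NUTS(model_numpyro)
mcmc = MCMC(nuts_kernel, num_samples=sample_size, num_warmup=sample_size // 10)
mcmc.run(jax.random.PRNGKey(3), dataset, rv_inputs)
mcmc.print_summary()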
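NUTS needs the gradient of the log density with respect to every sampled parameter, and in model_numpyro that gradient has to flow through prod_m = propagate(...). A Python callback has no derivative rule, so one possible workaround (an assumption here, not something from the post) is to wrap the callback in jax.custom_jvp and supply a forward-difference Jacobian computed from extra simulator runs. black_box_np is a hypothetical stand-in for the simulator, EPS is an assumed step size, and the parameters are packed into a single vector theta.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_OUT = 239  # length of the simulated series
EPS = 1e-4   # finite-difference step (an assumption; tune it for the real model)

def black_box_np(theta):
    # Hypothetical stand-in for the external simulator; theta is a concrete
    # NumPy array of shape (n_params,). Computed in float64.
    theta = np.asarray(theta, dtype=np.float64)
    t = np.linspace(0.0, 1.0, N_OUT)
    return theta[0] + theta[1] * t + theta[2] * t**2

def primal_cb(theta):
    return black_box_np(theta).astype(np.float32)

def jacobian_cb(theta):
    # Forward differences: one extra simulator run per parameter.
    theta = np.asarray(theta, dtype=np.float64)
    base = black_box_np(theta)
    cols = []
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = EPS
        cols.append((black_box_np(theta + step) - base) / EPS)
    return np.stack(cols, axis=1).astype(np.float32)  # shape (N_OUT, n_params)

@jax.custom_jvp
def simulate(theta):
    out_spec = jax.ShapeDtypeStruct((N_OUT,), jnp.float32)
    return jax.pure_callback(primal_cb, out_spec, theta)

@simulate.defjvp
def simulate_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    y = simulate(theta)
    jac_spec = jax.ShapeDtypeStruct((N_OUT, theta.shape[0]), jnp.float32)
    jac = jax.pure_callback(jacobian_cb, jac_spec, theta)
    return y, jac @ theta_dot  # linear in theta_dot, so reverse mode also works

theta = jnp.array([1.0, 2.0, 3.0])
theta_grad = jax.grad(lambda th: simulate(th).sum())(theta)
print(theta_grad)
```

Each gradient costs n_params + 1 simulator runs (4 here), and forward differences are only meaningful if the simulator output is smooth in its parameters. If that is too expensive or too noisy, a gradient-free kernel such as NumPyro's SA avoids the Jacobian altogether.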

The problem seems to be related to automatic differentiation. When I run the MCMC sampler with this setup, I get a ConcretizationTypeError, which indicates that an abstract tracer value is being used in a context that requires a concrete value.
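The error class itself can be reproduced without the simulator. NumPyro's NUTS evaluates the model on JAX tracers (under transformations such as grad and jit), and any Python operation that needs an actual number, such as float(), fails on a tracer. The helper names below (write_input, traced_step) are hypothetical.

```python
import jax
import jax.numpy as jnp

def write_input(value):
    # Mimics a file-writing helper that needs a concrete Python float.
    return float(value)

@jax.jit
def traced_step(x):
    # Inside jit, x is an abstract tracer, not a number.
    concrete = write_input(x)
    return jnp.asarray(concrete)

def raises_concretization_error():
    try:
        traced_step(1.5)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

print(raises_concretization_error())
print(write_input(jnp.float32(1.5)))  # outside tracing, the same helper works
```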

I've tried a few approaches, such as using jax.lax.stop_gradient (as shown in the code) and converting the tracer objects to concrete values with jax.device_get().item(), but neither resolved the issue.
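A small check of why stop_gradient cannot help here: it changes how a value is differentiated, not whether it is traced. Inside a traced function its output is still a tracer, so converting it to a Python float fails in the same way, and its only effect on differentiation is a zero gradient. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def with_stop_gradient(x):
    # stop_gradient only blocks differentiation; y is still a tracer here.
    y = jax.lax.stop_gradient(x)
    return jnp.asarray(float(y))  # float() still needs a concrete value

def still_fails():
    try:
        with_stop_gradient(2.0)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

def scaled(x):
    return jax.lax.stop_gradient(x) * 3.0

zero_grad = jax.grad(scaled)(1.0)  # gradient is blocked, so this is 0.0
print(still_fails(), zero_grad)
```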
