I'm using NumPyro to run Bayesian inference over a simulator, so I need to perform file I/O to call an external simulation. Specifically, my hierarchical model includes a function (propagate) that writes parameters to a file and then runs the external simulation. However, I'm running into problems with JAX tracer objects when passing parameters to this function.
Here's a simplified version of my function:
```python
import os
import subprocess

import jax
from jinja2 import Environment, FileSystemLoader

# sim_path, input_template and get_hist_output are defined elsewhere in my code.

def propagate(param01, param02, param03, qoi='pressure'):
    # Trying to treat params as constants
    param01 = jax.lax.stop_gradient(param01)
    param02 = jax.lax.stop_gradient(param02)
    param03 = jax.lax.stop_gradient(param03)
    render_dict = {
        'param01': param01,
        'param02': param02,
        'param03': param03,
        'output_name': "delete"
    }
    env = Environment(loader=FileSystemLoader(sim_path))
    template = env.get_template(input_template)
    input_sample = template.render(render_dict)
    with open('INPUT', 'w') as f:
        f.write(input_sample)
    process = subprocess.run(['wine', 'simulator.exe'], capture_output=True, text=True)
    if process.returncode != 0:
        print(process.stderr)
    if not os.path.exists('delete.dat'):
        raise FileNotFoundError("Expected output file not found.")
    df_hist = get_hist_output(sim_path + '/delete.dat')
    returned_var = df_hist[qoi].values[:239]
    return returned_var
```
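Because NumPyro traces the model, the parameters that reach propagate are JAX tracers rather than numbers, so anything that needs a concrete value (rendering the template, writing the file, launching wine) cannot run on them directly. A common way to call host-side code from traced JAX code is jax.pure_callback, which runs a plain Python function on concrete NumPy values and only needs the output shape and dtype up front. This is a minimal sketch assuming the output is a float32 array of length 239 (matching the [:239] slice); fake_simulator and propagate_jax are hypothetical names, and fake_simulator stands in for the file-writing and wine steps.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_OUT = 239  # assumed output length, matching the [:239] slice in propagate


def fake_simulator(p1, p2, p3):
    # Hypothetical stand-in for propagate(): it runs on the host with concrete
    # NumPy values, so template rendering, file I/O and subprocess calls would work here.
    p1, p2, p3 = float(p1), float(p2), float(p3)
    t = np.linspace(0.0, 1.0, N_OUT)
    return (p1 + p2 * t + p3 * t**2).astype(np.float32)


def propagate_jax(p1, p2, p3):
    # Declare the result's shape/dtype so JAX can trace without running the simulator.
    out_spec = jax.ShapeDtypeStruct((N_OUT,), jnp.float32)
    return jax.pure_callback(fake_simulator, out_spec, p1, p2, p3)


# Works under jit, i.e. when the inputs are tracers.
result = jax.jit(propagate_jax)(1.0, 2.0, 3.0)
```

In the real model, the body of fake_simulator would be the existing propagate code (converting each parameter with float(...) before rendering), and the callback must return exactly the declared shape and dtype. Note that pure_callback has no derivative rule, so on its own it is enough for a gradient-free sampler but not for NUTS.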
In the NumPyro model, propagate is used like this:
```python
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# dataset, rv_inputs ({name: (low, high)}) and qoi are defined elsewhere.

def model_numpyro(dataset, rv_inputs):
    param_dict = {}
    for k, v in rv_inputs.items():
        param_dict[k] = numpyro.sample(k, dist.Uniform(v[0], v[1]))
    prod_m = propagate(**param_dict)
    prod_s = numpyro.sample('s_mua', dist.HalfNormal(1e2))
    mua_rv = numpyro.sample('mua', dist.Normal(prod_m, prod_s), obs=dataset[qoi].values[:239])

sample_size = 100
nuts_kernel = NUTS(model_numpyro)
mcmc = MCMC(nuts_kernel, num_samples=sample_size, num_warmup=sample_size//10)
mcmc.run(jax.random.PRNGKey(3), dataset, rv_inputs)
mcmc.print_summary()
```
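NUTS differentiates the log density with respect to every sampled parameter, so the simulator output also needs a derivative, which neither the external program nor a plain callback provides. If a gradient-free sampler is not an option, one workaround is to give the callback a jax.custom_jvp rule based on central finite differences. This is a sketch under two assumptions: the simulator is smooth and cheap enough that two extra runs per parameter per gradient are acceptable, and EPS = 1e-3 is a reasonable step (it would need tuning). fake_simulator, _call and simulate are hypothetical names; fake_simulator again stands in for the external program and takes all parameters as one array.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_OUT = 239  # assumed output length
EPS = 1e-3   # finite-difference step (assumption; tune to the simulator)


def fake_simulator(p):
    # Hypothetical stand-in for the external simulator; p is a NumPy array of parameters.
    t = np.linspace(0.0, 1.0, N_OUT)
    return (p[0] + p[1] * t + p[2] * t**2).astype(np.float32)


def _call(p):
    # Run the host-side simulator on concrete values via a callback.
    spec = jax.ShapeDtypeStruct((N_OUT,), jnp.float32)
    return jax.pure_callback(fake_simulator, spec, p)


@jax.custom_jvp
def simulate(p):
    return _call(p)


@simulate.defjvp
def simulate_jvp(primals, tangents):
    (p,), (dp,) = primals, tangents
    y = _call(p)
    # Central differences: one Jacobian column per parameter (two simulator runs each).
    cols = []
    for i in range(p.shape[0]):
        step = jnp.zeros_like(p).at[i].set(EPS)
        cols.append((_call(p + step) - _call(p - step)) / (2 * EPS))
    jac = jnp.stack(cols, axis=-1)  # shape (N_OUT, n_params)
    return y, jac @ dp


p0 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
grad = jax.grad(lambda p: jnp.sum(simulate(p)))(p0)
```

In the model, the sampled values would be stacked before the call, e.g. prod_m = simulate(jnp.stack([param01, param02, param03])). Alternatively, a gradient-free kernel such as numpyro.infer.SA avoids derivatives entirely, though it may need more samples to mix.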
The problem seems to be related to automatic differentiation. When I run the MCMC sampler with this setup, I get a ConcretizationTypeError, which indicates that an abstract tracer value is being used in a context that requires a concrete value.

I've tried a few approaches, such as the jax.lax.stop_gradient calls shown in the code and converting the tracer objects to concrete values with jax.device_get().item(), but neither has resolved the issue.
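A small reproduction of why these attempts do not help, as far as I understand it: jax.lax.stop_gradient only zeroes the derivative flowing through a value; it does not turn a tracer into a number. When a function is traced with jax.jit (as happens when NumPyro compiles the sampler), the output of stop_gradient is still an abstract tracer, so converting it to a Python float raises the same ConcretizationTypeError. Converting with jax.device_get(...).item() runs into the same wall, since under jit there is no concrete value to fetch yet. The function name with_stop_gradient is illustrative.

```python
import jax
import jax.numpy as jnp


def with_stop_gradient(x):
    # stop_gradient blocks differentiation only; under jit, y is still a tracer.
    y = jax.lax.stop_gradient(x)
    return float(y) * 2.0  # float() needs a concrete value


# Outside jit, the input is a concrete array, so float() succeeds.
eager = with_stop_gradient(jnp.float32(3.0))

# Under jit, the same call fails with ConcretizationTypeError.
try:
    jax.jit(with_stop_gradient)(3.0)
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    jit_error = err

# What stop_gradient actually does: d/dx [stop_gradient(x)**2 + x] is 1, not 2x + 1.
grad_at_3 = jax.grad(lambda x: jax.lax.stop_gradient(x) ** 2 + x)(3.0)
```

The values only exist when the compiled code runs, which is the situation host callbacks such as jax.pure_callback are designed for.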