NumPyro: error when using MCMC with a model that uses scan


I am trying to get the following model to work:

import jax.numpy as jnp
from numpyro import deterministic, sample
from numpyro.contrib.control_flow import scan
from numpyro.distributions import Gamma, Normal

def model_dynamic(self, hemp_size_t, values_t):
    # Unpack the values at time t
    t, actions_performed = values_t

    # Check whether harvesting is performed at time step t
    harvest = self.is_performed("harvest-hemp", actions_performed)

    # Compute the state at time t + 1
    # (hemp_can_grow_t is not defined in this snippet)
    hemp_size_t1 = deterministic("hemp_size", (hemp_size_t + 0.05 * hemp_can_grow_t) * (1 - harvest))

    # Compute the yield at time t
    test = deterministic("yield_test", hemp_size_t * harvest)  # Works fine.

    # Each of the following alternatives (tried one at a time) fails with
    # AssertionError: assert rng_key is not None
    sample("yield", Normal(test, 0.1))
    # sample("yield", Gamma(1, 0.1), sample_shape=(1, 1))
    # sample("yield", Normal(hemp_size_t * harvest, 0.1))

    return hemp_size_t1, None

def model(self, *args, **kwargs):
    # Initial hemp size (fixed at zero, not sampled)
    hemp_size = jnp.zeros((1, 1))

    # Create a vector of time indices with shape (len(self.policy), 1, 1)
    time_indices = jnp.arange(len(self.policy)).reshape(-1, 1, 1)

    # Call the scan function that unrolls the model over time
    scan(self.model_dynamic, hemp_size, (time_indices, self.policy))

The only observed variable is the yield, and as soon as I try to sample it from a distribution, I get the following error, which seems to come from within the scan function:

File "/usr/local/lib/python3.10/dist-packages/numpyro/contrib/control_flow/scan.py", line 47, in _subs_wrapper
    assert rng_key is not None
AssertionError

For inference, I am using MCMC as follows:

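import jax
import numpyro

prng = jax.random.PRNGKey(0)
prng, _rng_key = jax.random.split(prng)
cond_model = numpyro.handlers.condition(model, data=data)
# self.kernel is built elsewhere (not shown), presumably from cond_model, e.g. NUTS(cond_model)
mcmc = numpyro.infer.MCMC(self.kernel, num_chains=4, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key=_rng_key)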

I found a workaround: passing an rng_key to the sample function makes the error go away.

import jax
from numpyro import sample
from numpyro.distributions import Gamma

rng_key = jax.random.PRNGKey(0)
sample("yield", Gamma(1, 0.1), sample_shape=(1, 1), rng_key=rng_key)

However, I still don't understand why. I have other models that are built similarly but do not need an rng_key to be passed to the sample function.
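One plausible explanation, going by how the per-step substitution in numpyro/contrib/control_flow/scan.py (the _subs_wrapper in the traceback) appears to work in recent versions: when a conditioned value has one more dimension than the site and its leading length is shorter than the scan length, scan uses the data for the first steps and samples the remaining ones, and only that branch asserts rng_key is not None. MCMC evaluates the potential energy without a seed handler, so no key is available there. The function below is a hypothetical pure-Python paraphrase of that dispatch on shapes, not NumPyro's actual code.

```python
def subs_branch(value_shape, site_shape, scan_length):
    """Paraphrase of how scan treats a conditioned value at one site.

    value_shape: shape of the array passed to condition for this site
    site_shape:  sample_shape + distribution shape of one step's site
    scan_length: number of scan steps
    """
    value_ndim = len(value_shape)
    site_ndim = len(site_shape)
    if value_ndim == site_ndim:
        # The same value is used at every step
        return "reuse"
    if value_ndim == site_ndim + 1:
        if value_shape[0] == scan_length:
            # One observation per step: value[i] is used at step i
            return "index"
        if value_shape[0] < scan_length:
            # Steps past the end of the data are sampled, which needs an rng_key
            return "forecast"
        return "error: series longer than scan"
    return "error: unexpected ndim"


# The question's yield site has per-step shape (1, 1)
T = 6  # hypothetical len(self.policy)
print(subs_branch((T, 1, 1), (1, 1), T))  # index
print(subs_branch((1, 1, T), (1, 1), T))  # forecast
```

Under this reading, data["yield"] in the question may be shaped, for example, (1, 1, T) or (T', 1, 1) with T' < len(self.policy), and the other models would work because their observed arrays have exactly one entry per scan step.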
