I am trying to get the following model to work:
def model_dynamic(self, hemp_size_t, values_t):
    """Advance the hemp model by one time step; passed to ``scan`` as the step function.

    Args:
        hemp_size_t: The hemp size at time ``t`` (the value carried between steps).
        values_t: A pair ``(t, actions_performed)`` for the current step.

    Returns:
        A tuple ``(hemp_size_t1, None)``: the hemp size at time ``t + 1`` and
        ``None`` as the second element.
    """
    # Unpack the values at time t
    t, actions_performed = values_t
    # Check whether harvesting is performed at time step t
    harvest = self.is_performed("harvest-hemp", actions_performed)
    # Compute the states at time t + 1: grow the current size, then multiply by (1 - harvest).
    # NOTE(review): presumably `harvest` is a 0/1 indicator so that harvesting resets the size to 0 — confirm.
    # NOTE(review): `hemp_can_grow_t` is not defined anywhere in this snippet — confirm where it comes from.
    hemp_size_t1 = deterministic(f"hemp_size", (hemp_size_t + 0.05 * hemp_can_grow_t) * (1 - harvest))
    # Compute the yield at time t
    test = deterministic("yield_test", hemp_size_t * harvest) # Works ok.
    # The three `sample` statements below all use the same site name "yield", and each is
    # annotated with the assertion that fails.
    # NOTE(review): presumably these are alternative attempts tried one at a time rather than
    # statements meant to coexist — confirm.
    sample("yield", Normal(test, 0.1)) # assert rng_key is not None
    sample("yield", Gamma(1, 0.1), sample_shape=(1, 1, )) # assert rng_key is not None
    sample("yield", Normal(hemp_size_t * harvest, 0.1)) # assert rng_key is not None
    return hemp_size_t1, None
def model(self, *args, **kwargs):
    """Unroll ``model_dynamic`` over every entry of ``self.policy`` using ``scan``.

    The positional and keyword arguments are accepted but not used, and
    nothing is returned.
    """
    # One time index per policy entry, with two trailing unit axes appended.
    n_steps = len(self.policy)
    step_ids = jnp.arange(0, n_steps)[:, None, None]
    # The initial hemp size is a (1, 1) array of zeros.
    initial_size = jnp.zeros((1, 1))
    # Each step receives its time index together with the matching policy entry.
    scan(self.model_dynamic, initial_size, (step_ids, self.policy))
The only observed variable is the yield, and as soon as I try to sample it from a distribution, I get the following error, which seems to come from within the scan function:
File "/usr/local/lib/python3.10/dist-packages/numpyro/contrib/control_flow/scan.py", line 47, in _subs_wrapper
assert rng_key is not None
AssertionError
In terms of inference, I am using MCMC as follows:
# Start from a fixed seed and split off a key for this MCMC run.
prng = jax.random.PRNGKey(0)
# NOTE(review): this line uses the bare name `random` while the line above uses `jax.random` —
# confirm that `random` refers to `jax.random`.
prng, _rng_key = random.split(prng)
# Condition the model on the observed data.
# NOTE(review): `cond_model` is not referenced again in this snippet; the MCMC object below is
# built from `self.kernel` — confirm that the kernel was constructed from the conditioned model.
cond_model = numpyro.handlers.condition(model, data=data)
mcmc = numpyro.infer.MCMC(self.kernel, num_chains=4, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key=_rng_key)
I found the following fix: passing an rng_key to the sample function fixes the error.
# Workaround: build a key explicitly and hand it to `sample` through its `rng_key` argument.
# NOTE(review): the key is created from the constant seed 0 — confirm whether it is rebuilt
# (and therefore identical) on every call of the step function.
rng_key = jax.random.PRNGKey(0)
sample("yield", Gamma(1, 0.1), sample_shape=(1, 1, ), rng_key=rng_key)
However, I still don't understand why. I have other models that are built similarly but do not need the rng_key to be provided to the sample function.