I am writing a Markov chain Monte Carlo (MCMC) simulation in JAX that involves a long series of sampling steps. I currently rely on Haiku's `PRNGSequence` to do the pseudorandom number generator (PRNG) key bookkeeping:
```python
import haiku as hk

def step(key, context):
    key_seq = hk.PRNGSequence(key)
    x1 = sampler(next(key_seq), context_1)
    ...
    xn = other_sampler(next(key_seq), context_n)
```
Question:
Since Haiku has been discontinued, I am looking for an alternative to PRNGSequence.
I find the standard JAX approach:
```python
import jax

def step(key, context):
    key, subkey = jax.random.split(key)
    x1 = sampler(subkey, context_1)
    ...
    key, subkey = jax.random.split(key)
    xn = other_sampler(subkey, context_n)
```
unsatisfactory on two accounts:
- Very error-prone: it is easy to slip up and reuse a key. This is especially problematic in MCMC simulations, which are sensitive to the resulting correlated samples and very difficult to debug.
- It is quite bulky: splitting keys roughly doubles the length of my code.
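To illustrate the first point, here is a minimal sketch (assuming JAX 0.4.16 or later for `jax.random.key`) of how a single missed split silently produces identical "random" draws:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Correct usage: split first, then consume the fresh subkey.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))

# Mistake: forgetting to split again before the next draw reuses `subkey`.
b = jax.random.normal(subkey, (3,))

# The two supposedly independent draws are identical; no error is raised.
print(jnp.array_equal(a, b))
```

Nothing fails loudly here, which is why this kind of bug is hard to track down in a long MCMC chain.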
Any suggestions on how to mitigate these problems?
Thanks!
Hylke
If all you need is a simple class that handles key splitting for you locally, why not define it yourself? A suitable one takes only a few lines of code.
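A minimal sketch of such a class is below. The name `PRNGSequence` simply mirrors Haiku's and is otherwise an assumption; it uses only `jax.random.split` and assumes JAX 0.4.16 or later for `jax.random.key`. Each call to `next()` splits the stored key, keeps one half as the new internal state, and returns the other half:

```python
import jax


class PRNGSequence:
    """Minimal stand-in for hk.PRNGSequence: yields a fresh subkey on each next()."""

    def __init__(self, key):
        self._key = key

    def __iter__(self):
        return self

    def __next__(self):
        # Keep one half of the split as the new state and hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


seq = PRNGSequence(jax.random.key(0))
k1 = next(seq)
k2 = next(seq)
x = jax.random.normal(next(seq), (3,))
```

Used inside a function such as the `step` in the question, constructing the sequence from the `key` argument (`key_seq = PRNGSequence(key)`) keeps the mutable state local to that one call, which is the same pattern the Haiku version followed.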
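As always, though, you have to be careful about this kind of hidden state when you're using JAX transformations like `jit`: the Python body of a jitted function only runs while it is being traced, so state mutated inside it is not updated on later calls. See JAX Sharp Bits: Pure Functions for more information on this.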
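If hidden state is a concern, one pure alternative (a different technique from a stateful helper class, and not part of the original answer) is to split every key a step needs in a single call and unpack them. This is a sketch under the assumption that the number of draws per step is known in advance; the samplers below are placeholders standing in for the question's `sampler` and `other_sampler`:

```python
import jax


@jax.jit
def step(key):
    # Split all keys this step needs at once; nothing is carried between
    # calls, so the function stays pure and is safe under jit.
    k1, k2, k3 = jax.random.split(key, 3)
    x1 = jax.random.normal(k1, (2,))
    x2 = jax.random.uniform(k2, (2,))
    x3 = jax.random.bernoulli(k3, 0.5, (2,))
    return x1, x2, x3


key = jax.random.key(42)
for _ in range(3):
    # Thread the outer key explicitly: one split per MCMC iteration.
    key, subkey = jax.random.split(key)
    x1, x2, x3 = step(subkey)
```

The trade-off is that the key count must be updated by hand when a draw is added, but there is only one split line per step and no state for `jit` to capture.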