Migrating from Haiku: alternative to Haiku's PRNGSequence?

I am writing a Markov chain Monte Carlo (MCMC) simulation in JAX that involves a long series of sampling steps. I currently rely on Haiku's PRNGSequence for the pseudorandom number generator (PRNG) key bookkeeping:

import haiku as hk

def step(key, context):
  key_seq = hk.PRNGSequence(key)  # each next(key_seq) yields a fresh subkey
  x1 = sampler(next(key_seq), context_1)
  ...
  xn = other_sampler(next(key_seq), context_n)

Question:

Since Haiku is no longer actively developed (it is in maintenance mode), I am looking for an alternative to PRNGSequence.

I find the standard JAX approach:

import jax

def step(key, context):
  key, subkey = jax.random.split(key)  # keep `key` for later splits, consume `subkey` once
  x1 = sampler(subkey, context_1)
  ...
  key, subkey = jax.random.split(key)
  xn = other_sampler(subkey, context_n)

unsatisfactory on two accounts:

  • Very error-prone: it is easy to slip up and reuse a key. This is especially problematic in MCMC simulations, where the resulting correlations between draws bias the results and are very difficult to debug.
  • Quite bulky: splitting keys roughly doubles the size of my code.
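One standard JAX idiom that reduces the bulk (a different technique from a PRNGSequence-style iterator) is to split once into exactly as many subkeys as a step consumes and unpack them. The sketch below assumes a step with two sampling calls; `sampler`, `other_sampler`, and the scalar contexts are hypothetical stand-ins for the question's own code.

```python
import jax

def sampler(key, context):
    # Hypothetical stand-in for the first sampler in the question.
    return context + jax.random.normal(key)

def other_sampler(key, context):
    # Hypothetical stand-in for the last sampler in the question.
    return context * jax.random.uniform(key)

def step(key, context_1, context_n):
    # One split call yields exactly as many subkeys as this step consumes.
    # Unpacking into distinct names makes each key's single use easy to audit,
    # and a count mismatch fails loudly with a ValueError on unpacking.
    key_1, key_n = jax.random.split(key, 2)
    x1 = sampler(key_1, context_1)
    xn = other_sampler(key_n, context_n)
    return x1, xn

x1, xn = step(jax.random.PRNGKey(0), 1.0, 2.0)
```

The caller still has to pass a fresh `key` into every call of `step`; the idiom only collapses the repeated split lines inside a single step.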

Any suggestions on how to mitigate these problems?

Thanks!

Hylke

Answer by jakevdp (accepted):

If all you need is a simple class that handles key splitting locally, why not define it yourself? A suitable one takes only a few lines, for example:

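import jax

class PRNGSequence:
  """Minimal stand-in for hk.PRNGSequence: each next() returns a fresh subkey."""
  def __init__(self, key):
    self._key = key
  def __iter__(self):
    return self
  def __next__(self):
    # Keep one half of the split as the new internal state, hand out the other.
    self._key, key = jax.random.split(self._key)
    return key

def step(key):
  key_seq = PRNGSequence(key)
  print(jax.random.uniform(next(key_seq)))
  print(jax.random.uniform(next(key_seq)))

step(jax.random.PRNGKey(0))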
# 0.10536897
# 0.2787192

As always, though, be careful with this kind of hidden state when using JAX transformations such as jit: see JAX Sharp Bits: Pure Functions for details.
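One way to keep such a class safe under transformations is to build the sequence inside the traced function from a key that is passed in and carried out explicitly, so its mutable state never outlives a single trace. The sketch below assumes a random-walk Metropolis sampler for a standard-normal target; `log_density`, `mh_step`, `run_chain`, and the proposal scale are hypothetical choices for illustration, run under `jax.jit` and `jax.lax.scan`.

```python
import functools

import jax
import jax.numpy as jnp

class PRNGSequence:
    # Same idea as the class in the answer, plus __iter__ so it is a proper iterator.
    def __init__(self, key):
        self._key = key
    def __iter__(self):
        return self
    def __next__(self):
        self._key, key = jax.random.split(self._key)
        return key

def log_density(x):
    # Hypothetical target: standard normal, up to an additive constant.
    return -0.5 * x ** 2

def mh_step(carry, _):
    x, key = carry
    # A fresh sequence per iteration, built from the carried key. All mutation
    # happens while scan traces this body, so no state escapes the transformation.
    key_seq = PRNGSequence(key)
    proposal = x + 1.0 * jax.random.normal(next(key_seq))
    log_u = jnp.log(jax.random.uniform(next(key_seq)))
    accept = log_u < log_density(proposal) - log_density(x)
    x_new = jnp.where(accept, proposal, x)
    # Draw one more key to carry forward, so the next iteration never
    # reuses a key that was already consumed here.
    return (x_new, next(key_seq)), x_new

@functools.partial(jax.jit, static_argnames="num_steps")
def run_chain(key, x0, num_steps):
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    _, samples = jax.lax.scan(mh_step, (x0, key), None, length=num_steps)
    return samples

samples = run_chain(jax.random.PRNGKey(0), 0.0, num_steps=5000)
```

The design keeps the convenience of next() inside each step while the key itself stays explicit at the function boundary, which is what jit and scan require of a pure function.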