How to get reproducible samples from Tensorflow Probability: tfp.mcmc.sample_chain?

Running a simple Bayesian regression model, I am not able to replicate the results across multiple runs on a GPU. How can I set up tfp.mcmc.sample_chain so that it generates reproducible results on a GPU? Seeding sample_chain didn't work for me.
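For reference, my understanding is that recent TFP releases also accept a stateless seed for sample_chain: a pair of integers (or a shape-[2] int32 tensor) instead of a single Python int. A minimal sketch of that pattern on a toy target (the kernel, names, and values here are purely illustrative, not my actual model):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy standard-normal target, just to illustrate the seeding call.
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=tfd.Normal(0., 1.).log_prob,
    step_size=0.1,
    num_leapfrog_steps=3)

# Stateless seed: a pair of ints (or a shape-[2] int32 tensor).
states = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.constant(0.),
    kernel=kernel,
    trace_fn=None,
    seed=[1, 2])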

The test code snippet:


import os
import random
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_probability as tfp

sns.reset_defaults()
#sns.set_style('whitegrid')
#sns.set_context('talk')
sns.set_context(context='talk',font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
%matplotlib inline

tfd = tfp.distributions
tfb = tfp.bijectors

dtype = tf.float64
dfhogg = pd.DataFrame(np.array([[1, 201, 592, 61, 9, -0.84],
                                 [2, 244, 401, 25, 4, 0.31],
                                 [3, 47, 583, 38, 11, 0.64],
                                 [4, 287, 402, 15, 7, -0.27],
                                 [5, 203, 495, 21, 5, -0.33],
                                 [6, 58, 173, 15, 9, 0.67],
                                 [7, 210, 479, 27, 4, -0.02],
                                 [8, 202, 504, 14, 4, -0.05],
                                 [9, 198, 510, 30, 11, -0.84],
                                 [10, 158, 416, 16, 7, -0.69],
                                 [11, 165, 393, 14, 5, 0.30],
                                 [12, 201, 442, 25, 5, -0.46],
                                 [13, 157, 317, 52, 5, -0.03],
                                 [14, 131, 311, 16, 6, 0.50],
                                 [15, 166, 400, 34, 6, 0.73],
                                 [16, 160, 337, 31, 5, -0.52],
                                 [17, 186, 423, 42, 9, 0.90],
                                 [18, 125, 334, 26, 8, 0.40],
                                 [19, 218, 533, 16, 6, -0.78],
                                 [20, 146, 344, 22, 5, -0.56]]),
                   columns=['id','x','y','sigma_y','sigma_x','rho_xy'])


## for convenience zero-base the 'id' and use as index
dfhogg['id'] = dfhogg['id'] - 1
dfhogg.set_index('id', inplace=True)

## standardize (mean-center and divide by one standard deviation)
dfhoggs = (dfhogg[['x','y']] - dfhogg[['x','y']].mean(0)) / dfhogg[['x','y']].std(0)
dfhoggs['sigma_y'] = dfhogg['sigma_y'] / dfhogg['y'].std(0)
dfhoggs['sigma_x'] = dfhogg['sigma_x'] / dfhogg['x'].std(0)

X_np = dfhoggs['x'].values
sigma_y_np = dfhoggs['sigma_y'].values
Y_np = dfhoggs['y'].values

def sample(seed):

  mdl_ols_batch = tfd.JointDistributionSequential([
      # b0
      tfd.Normal(loc=tf.cast(0, dtype), scale=1.),
      # b1
      tfd.Normal(loc=tf.cast(0, dtype), scale=1.),
      # likelihood
      #   Using Independent to ensure the log_prob is not incorrectly broadcasted
      lambda b1, b0: tfd.Independent(
          tfd.Normal(
              # Parameter transformation
              loc=b0[..., tf.newaxis] + b1[..., tf.newaxis]*X_np[tf.newaxis, ...],
              scale=sigma_y_np[tf.newaxis, ...]),
          reinterpreted_batch_ndims=1
      ),
  ])
  
  
  @tf.function(autograph=False, experimental_compile=True)
  def run_chain(init_state, 
                step_size,
                target_log_prob_fn,
                unconstraining_bijectors,
                num_steps=500, 
                burnin=50):

    def trace_fn(_, pkr):
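      # pkr holds the previous step's kernel results: DualAveraging wraps
      # TransformedTransitionKernel, which wraps the NoUTurnSampler results.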
      return (
          pkr.inner_results.inner_results.target_log_prob,
          pkr.inner_results.inner_results.leapfrogs_taken,
          pkr.inner_results.inner_results.has_divergence,
          pkr.inner_results.inner_results.energy,
          pkr.inner_results.inner_results.log_accept_ratio
      )

    kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn,
        step_size=step_size),
      bijector=unconstraining_bijectors)

    hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=kernel,
      num_adaptation_steps=burnin,
      step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
          inner_results=pkr.inner_results._replace(step_size=new_step_size)),
      step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
      log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio
    )

    # Sampling from the chain.
    chain_state, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=hmc,
        trace_fn=trace_fn,
        seed=seed
    )
    return chain_state, sampler_stat

  nchain = 4
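  # NOTE: this prior draw is unseeded, so init_state differs on every call
  # (see the note below the snippet).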
  b0, b1, _ = mdl_ols_batch.sample(nchain)
  init_state = [b0, b1]
  step_size = [tf.cast(i, dtype=dtype) for i in [.1, .1]]
  target_log_prob_fn = lambda *x: mdl_ols_batch.log_prob(x + (Y_np, ))

  # bijectors to map constrained parameters to the real line
  unconstraining_bijectors = [
      tfb.Identity(),
      tfb.Identity(),
  ]

  samples, sampler_stat = run_chain(
      init_state, step_size, target_log_prob_fn, unconstraining_bijectors)
  print(tf.reduce_sum(samples))
  

seed = 24
  
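# Run 1: set every library seed I could find, then sample.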
os.environ['TF_DETERMINISTIC_OPS'] = 'true'
os.environ['PYTHONHASHSEED'] = f'{seed}'
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)
sample(seed)

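# Run 2: identical seeding; on GPU the two printed sums still differ.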
os.environ['TF_DETERMINISTIC_OPS'] = 'true'
os.environ['PYTHONHASHSEED'] = f'{seed}'
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)
sample(seed)
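One more detail I'm unsure about: init_state comes from an unseeded mdl_ols_batch.sample(nchain), so each call to sample() starts the chains from different positions even when everything else is seeded. If it matters for diagnosis, seeding that draw would look like this (Distribution.sample accepts a seed keyword):

# Variant of the init_state draw inside sample(), with an explicit seed,
# so every run starts the chains from identical positions:
b0, b1, _ = mdl_ols_batch.sample(nchain, seed=seed)
init_state = [b0, b1]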
