target_log_prob_fn for JointDistributionNamed

I'm currently learning Bayesian approaches through the book "Rethinking" and would like to use TensorFlow Probability. Specifically, I'm working on this chapter: https://colab.research.google.com/github/ksachdeva/rethinking-tensorflow-probability/blob/master/notebooks/04_geocentric_models.ipynb In this example, the model is defined using a function. However, I prefer the JointDistributionNamed approach, as it closely aligns with how models are written in "Rethinking." I'm facing issues when running HMC, particularly with target_log_prob_fn: I'm unsure what type of object HamiltonianMonteCarlo passes to that function.

Here is my code and the corresponding error message:

import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
tfb = tfp.bijectors
import pandas as pd
d = pd.read_csv('rethinking-master/data/Howell1.csv',sep = ';')
height= d.height
model = tfd.JointDistributionNamedAutoBatched(dict(
    s = tfd.Sample(tfd.Exponential(1), sample_shape=1),
    alpha = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
    beta = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
    y = lambda s,alpha,beta: tfd.Independent(tfd.Normal( alpha + beta * d.weight.values,s),
                                          reinterpreted_batch_ndims=1),
))

def _trace_fn_transitioned(_, pkr):
    return pkr.inner_results.inner_results.log_accept_ratio

num_chains = 4
num_leapfrog_steps = 4
step_size = 0.8
burnin = 500
params = ['s', 'alpha', 'beta']
init_state = list(model.sample(num_chains))[:-1]
bijectors = [tfb.Identity() for i in init_state]
observed_data=(d.height.values)

def target_log_prob_fn (**x):
    print(x[0])
    return model.log_prob(model.sample(y = observed_data, **x))


hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn, num_leapfrog_steps=num_leapfrog_steps, step_size=step_size
    )

inner_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel, bijector=bijectors
)

kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=inner_kernel,
    target_accept_prob=0.8,
    num_adaptation_steps=int(0.8 * burnin),
    log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)

tfp.mcmc.sample_chain(
        num_results=544,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=kernel,
        trace_fn=_trace_fn_transitioned,
    )

error: target_log_prob_fn() takes 0 positional arguments but 3 were given

Thanks,

There is 1 answer

Giorgio On BEST ANSWER

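The error comes from how the HMC kernel calls target_log_prob_fn. When current_state is a list, the TFP MCMC kernels pass each state part as a positional argument, so a function declared as target_log_prob_fn(**x) accepts no positional arguments and fails with exactly the message you see. To keep the **x signature you would have to change how the kernel passes its arguments, and I don't know of a straightforward way to do that.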
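To make the mismatch concrete, here is a minimal pure-Python sketch. It assumes the kernel calls the target as fn(*state_parts), and it uses plain floats as stand-ins for the state tensors. The names keyword_only_fn and as_positional are hypothetical helpers for illustration, not part of TFP.

```python
params = ["s", "alpha", "beta"]

def keyword_only_fn(**x):
    # Same signature style as the question's target_log_prob_fn(**x).
    return sorted(x)

# Plain floats standing in for the three state tensors.
state_parts = [1.0, 0.0, 0.5]

# A positional call, as assumed for the MCMC kernel, reproduces the error.
try:
    keyword_only_fn(*state_parts)
    message = ""
except TypeError as err:
    message = str(err)

def as_positional(fn, names):
    """Wrap a keyword-based fn so it accepts positional state parts."""
    def wrapped(*parts):
        # Pair each positional part with its parameter name.
        return fn(**dict(zip(names, parts)))
    return wrapped

adapted = as_positional(keyword_only_fn, params)
result = adapted(*state_parts)
print(message)
print(result)
```

A wrapper of this kind is one possible way to keep a keyword-based function while still presenting a positional signature to the caller. It relies on the order of params matching the order of the state parts.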

The standard way to define target_log_prob_fn is to take the parameters as positional arguments:

def target_log_prob_fn(s, alpha, beta):
    return model.log_prob(s=s, alpha=alpha, beta=beta, y=observed_data)

Moreover, there's an issue with how you're building the initial state and the bijectors. model.sample(num_chains) returns a dict, and calling list(...) on a dict yields its keys, not the sampled tensors. Since you have three parameters (s, alpha, beta), your initial state should be a list of three tensors, one per parameter, in the same order as the arguments of target_log_prob_fn. The bijectors should match these parameters as well. You can do it like this:

init_state = [model.sample()[param].numpy() for param in params]    
bijectors = [tfb.Identity() for _ in params]
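The difference between the two ways of building init_state can be checked without TFP. The sketch below uses a plain dict with made-up values as a stand-in for what model.sample() returns; the dict contents are an assumption for illustration only.

```python
params = ["s", "alpha", "beta"]

# Hypothetical stand-in for the dict returned by model.sample().
sample = {
    "s": [1.2],
    "alpha": [0.1],
    "beta": [-0.3],
    "y": [150.0, 160.0],
}

# What the question's code builds: iterating a dict yields its keys.
keys_only = list(sample)[:-1]

# Selecting by name yields the values, in the order given by params.
init_state = [sample[name] for name in params]

print(keys_only)
print(init_state)
```

Indexing by name also fixes the order of the state parts, so it does not depend on the order in which the dict happens to store its entries.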

With these changes the sampler runs on my machine, so it should work for you as well. It may not be the reply you were looking for, but I hope it still helps.