I'm using Numpyro to create a simple linear regression model consisting of two variables, the aim is to obtain a similar graph to https://num.pyro.ai/en/latest/tutorials/bayesian_regression.html (3rd graph).
I have used numpyro to generate 2000 samples of data and all of the code below runs as expected.
def model(data=None):
mu = numpyro.sample("mu", dist.Normal(0.0, 0.2))
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# Run NUTS.
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
rng_key_, data=data.AgeUncScaled.values
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
Now, when moving onto the final step:
def plot_regression(x, y_mean, y_hpdi):
# Sort values for plotting by x axis
idx = jnp.argsort(x)
age = x[idx]
mean = y_mean[idx]
hpdi = y_hpdi[:, idx]
age_unc = dataset.AgeUncScaled.values[idx]
# Plot
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
ax.plot(age, mean)
ax.plot(age, age_unc, "o")
ax.fill_between(age, hpdi[0], hpdi[1], alpha=0.3, interpolate=True)
return ax
posterior_mu = (
jnp.expand_dims(samples_1["mu"], -1)
)
mean_mu = jnp.mean(posterior_mu, axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)
ax = plot_regression(dataset.AgeScaled.values, mean_mu, hpdi_mu)
I run into all kinds of errors (including type & index) ... not too sure what is going on and would be grateful for any help :)