Error associated with using NumPyro to create a linear regression model


I'm using NumPyro to create a simple linear regression model with two variables. The aim is to obtain a graph similar to the 3rd graph in https://num.pyro.ai/en/latest/tutorials/bayesian_regression.html.

I have used NumPyro's NUTS sampler to draw 2000 posterior samples, and all of the code below runs as expected.

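```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.infer import MCMC, NUTS

def model(data=None):
    # Intercept-only model: mu is a single scalar shared by every observation
    mu = numpyro.sample("mu", dist.Normal(0.0, 0.2))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_, data=data.AgeUncScaled.values
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
```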

Now, when I move on to the final step, plotting the regression:

```python
def plot_regression(x, y_mean, y_hpdi):
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    age = x[idx]
    mean = y_mean[idx]
    hpdi = y_hpdi[:, idx]
    age_unc = dataset.AgeUncScaled.values[idx]

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(age, mean)
    ax.plot(age, age_unc, "o")
    ax.fill_between(age, hpdi[0], hpdi[1], alpha=0.3, interpolate=True)
    return ax

# Shape (num_samples, 1): the intercept samples with a trailing axis added
posterior_mu = jnp.expand_dims(samples_1["mu"], -1)

mean_mu = jnp.mean(posterior_mu, axis=0)  # shape (1,)
hpdi_mu = hpdi(posterior_mu, 0.9)  # shape (2, 1)
ax = plot_regression(dataset.AgeScaled.values, mean_mu, hpdi_mu)
```
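`plot_regression` reorders everything with `idx`, the argsort of `x`, so it only makes sense if `y_mean` has shape (N,) and `y_hpdi` has shape (2, N), one value per data point. A small sketch of a shape check that makes this contract explicit; `check_plot_inputs` is a hypothetical helper, not part of NumPyro.

```python
import numpy as np


def check_plot_inputs(x, y_mean, y_hpdi):
    """Hypothetical helper: verify the shapes that plot_regression relies on."""
    x, y_mean, y_hpdi = np.asarray(x), np.asarray(y_mean), np.asarray(y_hpdi)
    if x.ndim != 1:
        raise ValueError(f"x must be 1-D, got shape {x.shape}")
    n = x.shape[0]
    # One posterior mean per data point
    if y_mean.shape != (n,):
        raise ValueError(f"y_mean must have shape ({n},), got {y_mean.shape}")
    # Lower and upper interval bounds per data point
    if y_hpdi.shape != (2, n):
        raise ValueError(f"y_hpdi must have shape (2, {n}), got {y_hpdi.shape}")
```

Calling it with `mean_mu` and `hpdi_mu` as computed in the question would raise, since they have shapes (1,) and (2, 1). An explicit check helps because JAX clamps out-of-bounds indices on retrieval instead of raising, so `y_mean[idx]` on a length-1 JAX array can silently repeat its single value.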
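Here `posterior_mu` contains only the intercept samples, so `mean_mu` and `hpdi_mu` hold one value rather than one per data point. The tutorial's version of this expression also adds a slope term multiplied by the predictor values, which broadcasts the result to shape (num_samples, N). A NumPy sketch of that broadcasting, assuming posterior draws of a hypothetical intercept `a` and slope `b`; it uses an equal-tailed quantile interval as a stand-in for HPDI (NumPyro's `hpdi(mu, 0.9)` on an (S, N) array returns the same (2, N) shape).

```python
import numpy as np


def posterior_line(a_samples, b_samples, x, prob=0.9):
    """Sketch: per-sample regression lines, their mean, and an equal-tailed interval.

    a_samples, b_samples: shape (S,) draws of a hypothetical intercept and slope.
    x: shape (N,) predictor values.
    """
    a = np.asarray(a_samples)[:, None]  # (S, 1)
    b = np.asarray(b_samples)[:, None]  # (S, 1)
    mu = a + b * np.asarray(x)[None, :]  # broadcasts to (S, N)
    mean = mu.mean(axis=0)  # (N,)
    tail = (1.0 - prob) / 2.0
    # Equal-tailed interval, NOT an HPDI; same (2, N) layout as numpyro's hpdi
    interval = np.quantile(mu, [tail, 1.0 - tail], axis=0)
    return mu, mean, interval
```

The returned `mean` and `interval` have exactly the shapes `plot_regression` indexes with `idx`.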

I run into all kinds of errors (including type and index errors). I'm not sure what is going on and would be grateful for any help :)
