Pyro and conditional probability


I'm working my way through this tutorial: An Introduction to Inference in Pyro

What I don't understand is the following. In order to get $p(\text{weight} \mid \text{guess}, \text{measurement} = 9.5)$ we can use the pyro.condition function with

import pyro
import pyro.distributions as dist

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    print(weight)
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

and conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})
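For reference, because the prior on weight and the likelihood of measurement are both Normal, the conditional distribution the tutorial is after has a closed form (standard Gaussian conjugacy). The posterior is Normal with mean $(0.75^2 \cdot \text{guess} + 1^2 \cdot 9.5)/(1^2 + 0.75^2)$ and variance $1^2 \cdot 0.75^2/(1^2 + 0.75^2)$. The sketch below only evaluates that formula in plain Python. The function name `normal_posterior` and its parameters are illustrative and are not part of Pyro.

```python
import math


def normal_posterior(guess, measurement, prior_sd=1.0, noise_sd=0.75):
    """Closed-form p(weight | guess, measurement) for the scale model:
    weight ~ Normal(guess, prior_sd), measurement ~ Normal(weight, noise_sd)."""
    prior_var = prior_sd ** 2
    noise_var = noise_sd ** 2
    # Precision-weighted average of the prior mean and the observation
    mean = (guess * noise_var + measurement * prior_var) / (prior_var + noise_var)
    var = (prior_var * noise_var) / (prior_var + noise_var)
    return mean, math.sqrt(var)


mean, sd = normal_posterior(0.3, 9.5)
print(f"posterior mean = {mean:.3f}, sd = {sd:.3f}")  # posterior mean = 6.188, sd = 0.600
```

For guess = 0.3 the posterior is centred near 6.19 with standard deviation 0.6. A forward run of the model instead returns a draw from the prior Normal(0.3, 1), such as -1.09.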

I wrote the following script:

    pyro.set_rng_seed(101)
    scale(0.3) # prints weight: tensor(-1.0905)
    pyro.set_rng_seed(101)
    conditioned_scale(0.3) # prints weight: tensor(-1.0905)

Both calls print the same sample for weight. Isn't the tutorial saying that conditioned_scale gives us a sample of weight from a distribution conditioned on measurement = 9.5? If so, shouldn't the weight samples differ, since the first call observes no data while the second conditions on it?

Thanks!

Answer by Ola Rønning:

Running the model will not produce samples from the posterior; you'll need to run inference (like SVI or MCMC).
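As a minimal illustration of what running inference means here, the sketch below uses a hand-written random-walk Metropolis sampler (a basic MCMC method, swapped in for illustration; it is not Pyro's MCMC API) on the unnormalized density $p(\text{weight} \mid \text{guess})\,p(\text{measurement} = 9.5 \mid \text{weight})$. The function names, step size, seed and sample counts are assumptions chosen for this example.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x, where scale is the standard deviation."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_joint(weight, guess, measurement):
    """Unnormalized log posterior: log p(weight | guess) + log p(measurement | weight)."""
    return normal_logpdf(weight, guess, 1.0) + normal_logpdf(measurement, weight, 0.75)


def metropolis(guess, measurement, num_samples=20000, warmup=2000, step=0.5, seed=101):
    """Random-walk Metropolis sampler for p(weight | guess, measurement)."""
    rng = random.Random(seed)
    w = guess  # start the chain at the prior mean
    lp = log_joint(w, guess, measurement)
    samples = []
    for i in range(warmup + num_samples):
        proposal = w + rng.gauss(0.0, step)  # symmetric proposal, so no Hastings correction
        lp_prop = log_joint(proposal, guess, measurement)
        # Accept with probability min(1, p(proposal) / p(current))
        if lp_prop >= lp or rng.random() < math.exp(lp_prop - lp):
            w, lp = proposal, lp_prop
        if i >= warmup:  # discard warm-up draws
            samples.append(w)
    return samples


samples = metropolis(guess=0.3, measurement=9.5)
print(sum(samples) / len(samples))  # should land near 6.19
```

The chain's mean should land close to 6.19, the mean of the exact Gaussian posterior for guess = 0.3, whereas simply calling the conditioned model keeps returning prior draws of weight. In Pyro itself you would typically use pyro.infer.MCMC with a NUTS kernel, or SVI with a guide.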

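condition replaces the value at a sample site with the value you specify. Because you specified a value only for measurement, the weight site is unaffected: it is still drawn from its prior, Normal(guess, 1), before measurement is ever reached. The model you've written defines the joint density $p(\text{weight}, \text{measurement} \mid \text{guess}) = \mathcal{N}(\text{weight}; \text{guess}, 1^2)\,\mathcal{N}(\text{measurement}; \text{weight}, 0.75^2)$, and by conditioning you've stated measurement = 9.5. That turns measurement into an observed site whose log-probability is added to the joint density used by inference, but it does not change how weight is sampled when you run the model forward. Conversely, conditioning on weight instead, i.e. conditioned_scale = pyro.condition(scale, data={"weight": 9.5}), with the same seed will produce different measurements, because measurement depends on weight. Below I've written the same program in NumPyro. You should also check out https://forum.pyro.ai/.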

from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.handlers


def scale(rng_key, guess):
    w_key, m_key = random.split(rng_key)

    weight = numpyro.sample("weight", dist.Normal(guess, 1.0), rng_key=w_key)
    print(weight)
    return numpyro.sample("measurement", dist.Normal(weight, 0.75), rng_key=m_key)


if __name__ == '__main__':
    rng_key = random.PRNGKey(0)
    print(scale(rng_key, 0.3))  # -0.49476373

    # Fix the weight site to 9.5; measurement is then drawn from Normal(9.5, 0.75)
    conditioned_scale = numpyro.handlers.condition(scale, data={"weight": 9.5})
    print(conditioned_scale(rng_key, 0.3))  # 8.561346