Problem in Implementing a Graphical Model Using Pyro

I am trying to implement this graphical model using Pyro:

[Image: diagram of the graphical model]

My implementation is:

def model(data): 
    p = pyro.sample('p', dist.Beta(1, 1))

    label_axis = pyro.plate("label_axis", data.shape[0], dim=-3)
    f_axis = pyro.plate("f_axis", data.shape[1], dim=-2)

    with label_axis:
        l = pyro.sample('l', dist.Bernoulli(p))
    
    with f_axis:
        e = pyro.sample('e', dist.Beta(1, 10))

    with label_axis, f_axis:
        f = pyro.sample('f', dist.Bernoulli(1-e), obs=data)
        f = l*f + (1-l)*(1-f)
    return f

However, this doesn't seem right to me. The problem is "f": its distribution is not a plain Bernoulli(1-e), because it also depends on l. To sample f, I draw from a Bernoulli distribution and then flip the sampled value if l=0. But I don't think this changes the value that Pyro stores behind the scenes for "f", which would be a problem during inference, right?

I wanted to use iterative plates instead of vectorized ones so that I could use control statements inside the plate. But apparently this is not possible, because I am reusing plates.

How can I correctly implement this PGM? Do I need to write a custom distribution? Or can I hack Pyro and change the stored value for "f" myself? Any type of help is appreciated! Cheers!

There is 1 answer

Answer by Arash Khoeini (best answer)

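Here is a correct implementation. Rather than modifying the sampled value afterwards, it folds the dependence on l into the Bernoulli probability of "f":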

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    # Prior over the probability that a label equals 1.
    p = pyro.sample('p', dist.Beta(1, 1))

    # Rows of data index labels (dim=-2); columns index f (dim=-1).
    label_axis = pyro.plate("label_axis", data.shape[0], dim=-2)
    f_axis = pyro.plate("f_axis", data.shape[1], dim=-1)

    with label_axis:
        l = pyro.sample('l', dist.Bernoulli(p))

    with f_axis:
        e = pyro.sample('e', dist.Beta(1, 10))

    with label_axis, f_axis:
        # f is 1 with probability 1 - e when l = 1, and e when l = 0.
        prob = l * (1 - e) + (1 - l) * e
        return pyro.sample('f', dist.Bernoulli(prob), obs=data)

# Synthetic data: 20 labels by 4 columns.
data = dist.Bernoulli(0.5).sample((20, 4))
mcmc = MCMC(NUTS(model), num_samples=500, warmup_steps=500)
mcmc.run(data)
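As a sanity check on the `prob` line, here is a small sketch in plain Python (no Pyro needed). It assumes the intended generative story is the one from the question: draw f from Bernoulli(1-e), then flip it when l=0. The helper names `flipped_prob` and `direct_prob` are made up for illustration. Under that assumption, the probability that the flipped value equals 1 matches l*(1-e) + (1-l)*e for both values of l.

```python
from itertools import product

def flipped_prob(l, e):
    """P(f' = 1) when f ~ Bernoulli(1 - e) and f' = l*f + (1 - l)*(1 - f)."""
    total = 0.0
    for f in (0, 1):
        p_f = (1 - e) if f == 1 else e        # probability of this raw draw
        f_prime = l * f + (1 - l) * (1 - f)   # flip the draw when l == 0
        total += p_f * f_prime
    return total

def direct_prob(l, e):
    """Probability used directly in the answer's Bernoulli site."""
    return l * (1 - e) + (1 - l) * e

# Compare both expressions on a small grid of labels and error rates.
for l, e in product((0, 1), (0.0, 0.1, 0.25, 0.5, 0.9)):
    assert abs(flipped_prob(l, e) - direct_prob(l, e)) < 1e-12
```

Because the two expressions agree, observing "f" under Bernoulli(prob) gives the observed site the same distribution as the sample-then-flip procedure, without any arithmetic after the sample statement.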