TensorFlow Probability: retrieving a specific random variable from a joint distribution

I'm new to TensorFlow Probability. I'm building a hierarchical model with the JointDistributionSequential API:

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# D, S, T and N (and L, used further down) are the problem dimensions, defined elsewhere.
jds = tfd.JointDistributionSequential([
    # mu_g ~ uniform on sphere (a vMF with zero concentration)
    tfd.VonMisesFisher(
        mean_direction=[1.] + [0.] * (D - 1),
        concentration=0.,
        validate_args=True,
        name="mu_g"
    ),
    # epsilon ~ Exponential
    tfd.Exponential(
        rate=1.,
        validate_args=True,
        name="epsilon"
    ),
    # mu_s ~ von Mises-Fisher centered on mu_g, batch shape (S,)
    lambda epsilon, mu_g: tfd.VonMisesFisher(
        mean_direction=mu_g,
        # tf.fill keeps this a tensor op (np.array on tensors breaks in graph mode)
        concentration=tf.fill([S], epsilon),
        validate_args=True,
        name="mu_s"
    ),
    # sigma ~ Exponential
    tfd.Exponential(
        rate=1.,
        validate_args=True,
        name="sigma"
    ),
    # mu_t_s ~ von Mises-Fisher centered on mu_s, batch shape (T, S)
    lambda sigma, mu_s: tfd.VonMisesFisher(
        mean_direction=mu_s,
        concentration=tf.fill([T, S], sigma),
        validate_args=True,
        name="mu_t_s"
    ),
    # kappa ~ Exponential
    tfd.Exponential(
        rate=1.,
        validate_args=True,
        name="kappa"
    ),
    # x_t_s ~ mixture of L groups of vMF, batch shape (N, T, S)
    lambda kappa, mu_t_s: tfd.VonMisesFisher(
        mean_direction=mu_t_s,
        concentration=tf.fill([N, T, S], kappa),
        validate_args=True,
        name="x_t_s"
    ),
])
```

I then intend to create a mixture of those models using the Mixture API:

```python
l = tfp.distributions.Categorical(
    probs=np.array(
        [
            [
                [
                    [1.0/L]*L
                ]*S
            ]*T
        ]*N
    ),
    name="l"
)

mixture = tfd.Mixture(
    cat=l,
    components=[
        jds
    ] * L,
    validate_args=True
)
```

This doesn't work. What I actually want to mix over are the random variables at the "end" of the hierarchical model, x_t_s, which have batch shape (N, T, S). I guess I need to pass those to the mixture's components argument, but I can't find an easy way to retrieve them from the model object.

Does anybody see a way around this problem?

Note that I tried using jds.model[-1] instead of jds, but that gives back the lambda function itself rather than a distribution, which isn't what I need here.

Answer by Brian Patton

Several thoughts here.

  1. Consider SphericalUniform for the first distribution.
  2. For Mixtures of the same type, consider using MixtureSameFamily.
  3. Put the mixture into the hierarchical model. i.e. instead of that last distribution being a vMF, it could be a MixtureSameFamily(Categorical(...), VonMisesFisher(...)).
  4. If you later want to access the components, you can call ds, xs = jds.sample_distributions() and look at ds[-1].components_distribution.

Feel free to email [email protected] w/ questions too.