I'm new to TensorFlow Probability. I am building a hierarchical model with the JointDistributionSequential API:
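import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
# D (sphere dimension), S, T, N and L (number of mixture groups) are
# integer sizes assumed to be defined elsewhere.
jds = tfp.distributions.JointDistributionSequential(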
[
# mu_g ~ uniform on sphere
tfp.distributions.VonMisesFisher(
mean_direction= [1] + [0]*(D-1),
concentration=0,
validate_args=True,
name="mu_g"
),
# epsilon ~ Exponential
tfp.distributions.Exponential(
rate=1,
validate_args=True,
name="epsilon"
),
# mu_s ~ von Mises Fisher centered on mu_g
lambda epsilon, mu_g: tfp.distributions.VonMisesFisher(
mean_direction=mu_g,
concentration=np.array(
[epsilon]*S
),
validate_args=True,
name="mu_s"
),
# sigma ~ Exponential
tfp.distributions.Exponential(
rate=1,
validate_args=True,
name="sigma"
),
# mu_t_s ~ von Mises Fisher centered on mu_s
lambda sigma, mu_s: tfp.distributions.VonMisesFisher(
mean_direction=mu_s,
concentration=np.array(
[
[sigma]*S
]*T
),
validate_args=True,
name="mu_t_s"
),
# kappa ~ Exponential
tfp.distributions.Exponential(
rate=1,
validate_args=True,
name="kappa"
),
# x_t_s ~ mixture of L groups of vMF
lambda kappa, mu_t_s: tfp.distributions.VonMisesFisher(
mean_direction=mu_t_s,
concentration=np.array(
[
[
[
kappa
]*S
]*T
]*N
),
validate_args=True,
name="x_t_s"
)
]
)
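The nested lists passed as concentration only build constant arrays, and the batch shapes follow from broadcasting. A NumPy sketch tracing the shapes through the hierarchy for a single joint sample, assuming the usual TFP rule that a VonMisesFisher's batch shape is the broadcast of mean_direction (minus its last, event axis) with concentration. The sizes and the helper vmf_batch_shape are hypothetical.

```python
import numpy as np

N, T, S, D = 4, 3, 2, 3  # hypothetical sizes
kappa = 1.5

nested = np.array([[[kappa] * S] * T] * N)  # as written in the question
filled = np.full((N, T, S), kappa)          # the same constant array, built directly
assert np.array_equal(nested, filled)


def vmf_batch_shape(mean_direction_shape, concentration_shape):
    """Assumed vMF rule: broadcast mean_direction without its event axis with concentration."""
    return np.broadcast_shapes(tuple(mean_direction_shape[:-1]), tuple(concentration_shape))


mu_g = (D,)                                         # one direction on the sphere
mu_s = vmf_batch_shape(mu_g, (S,)) + (D,)           # concentration of shape (S,)
mu_t_s = vmf_batch_shape(mu_s, (T, S)) + (D,)       # concentration of shape (T, S)
x_t_s_batch = vmf_batch_shape(mu_t_s, (N, T, S))    # concentration of shape (N, T, S)
print(mu_t_s)       # (3, 2, 3)
print(x_t_s_batch)  # (4, 3, 2)
```

This matches the (N, T, S) batch shape the question expects for x_t_s, with each sample being a D-dimensional unit vector.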
I then intend to build a mixture of L copies of this model using the Mixture API:
l = tfp.distributions.Categorical(
probs=np.array(
[
[
[
[1.0/L]*L
]*S
]*T
]*N
),
name="l"
)
mixture = tfd.Mixture(
cat=l,
components=[
jds
] * L,
validate_args=True
)
This doesn't work. What I actually want to mix over are the random variables at the "end" of the hierarchical model, x_t_s, which have batch shape (N, T, S). I guess I need to pass those to the mixture's components argument, but I can't easily retrieve them from the model object.
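To pin down what mixing over x_t_s means: for each (n, t, s) independently, the density is a weighted sum over L von Mises-Fisher components, log p(x) = logsumexp_l [log w_l + log vMF(x | mu_l, kappa)]. Below is a minimal NumPy sketch of that log-density (not TFP code), restricted to D = 3 where the vMF normalizer has the closed form kappa / (4 pi sinh kappa). The helper names, the logits parameterization, and the assumption that each component l has its own mean direction are illustrative. As far as I know, this per-(n, t, s) marginalization is what tfd.MixtureSameFamily computes when given a Categorical with batch shape (N, T, S) and a VonMisesFisher with batch shape (N, T, S, L).

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def vmf_log_prob_3d(x, mean_direction, concentration):
    """von Mises-Fisher log-density on the unit sphere in R^3 (closed form)."""
    k = np.asarray(concentration, dtype=float)
    dot = np.sum(np.asarray(x, float) * np.asarray(mean_direction, float), axis=-1)
    with np.errstate(divide="ignore", invalid="ignore"):
        # log(k / sinh(k)), stable for large k; its limit at k = 0 is 0.
        log_k_over_sinh = np.where(
            k > 0, np.log(k) - k - np.log1p(-np.exp(-2.0 * k)) + np.log(2.0), 0.0
        )
    return log_k_over_sinh - np.log(4.0 * np.pi) + k * dot


def mixture_vmf_log_prob(x, logits, mean_directions, concentration):
    """Log-density of an L-component vMF mixture, independently per (n, t, s).

    x:               (N, T, S, 3) unit vectors (the x_t_s observations)
    logits:          (N, T, S, L) unnormalized mixture weights
    mean_directions: (N, T, S, L, 3) one unit mean direction per component
    concentration:   scalar kappa, or an array broadcastable to (N, T, S, L)
    Returns an array of shape (N, T, S).
    """
    log_w = logits - logsumexp(logits, axis=-1)[..., None]  # log mixture weights
    # Add a component axis to x so it broadcasts against the L mean directions.
    comp = vmf_log_prob_3d(x[..., None, :], mean_directions, concentration)
    return logsumexp(log_w + comp, axis=-1)  # marginalize the component label l


rng = np.random.default_rng(0)
N, T, S, L = 2, 3, 4, 5


def unit(shape):
    v = rng.standard_normal(shape)
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


x = unit((N, T, S, 3))
mus = unit((N, T, S, L, 3))
logits = rng.standard_normal((N, T, S, L))
lp = mixture_vmf_log_prob(x, logits, mus, concentration=4.0)
print(lp.shape)  # (2, 3, 4)
```

The mixture sits only at the last level: upstream variables (mu_g, mu_s, mu_t_s, kappa) would determine the component parameters, and only x_t_s carries the extra component axis L.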
Does anybody see a way around this problem?
Note that I tried using jds.model[-1] instead of jds, but that is the lambda function itself, not a distribution, so it isn't what I need here.
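Why jds.model[-1] is only a lambda: JointDistributionSequential calls each callable with the values sampled so far, most recent first, and only the object it returns is a distribution. The toy mimic below is pure Python/NumPy and entirely hypothetical (ToyDist and toy_sample_distributions are not TFP names). It only illustrates that ordering and why a method like sample_distributions, which returns the built distributions alongside the samples, is needed to reach the final one.

```python
import inspect

import numpy as np


class ToyDist:
    """Hypothetical stand-in for a distribution: wraps a sampler rng -> value."""

    def __init__(self, sampler, name):
        self.sampler = sampler
        self.name = name

    def sample(self, rng):
        return self.sampler(rng)


def toy_sample_distributions(model, rng):
    """Toy mimic of JointDistributionSequential.sample_distributions.

    Distribution objects are used as-is. A callable is called with the values
    sampled so far, most recent first, passing as many values as it has
    parameters; the object it returns is the distribution for that position.
    """
    dists, values = [], []
    for elem in model:
        if isinstance(elem, ToyDist):
            dist = elem
        else:
            n_args = len(inspect.signature(elem).parameters)
            dist = elem(*values[::-1][:n_args])
        dists.append(dist)
        values.append(dist.sample(rng))
    return dists, values


model = [
    ToyDist(lambda rng: 2.0, name="mu"),
    ToyDist(lambda rng: 0.5, name="scale"),
    # Like the question's last element: a lambda, not a distribution.
    lambda scale, mu: ToyDist(lambda rng: rng.normal(mu, scale), name="x"),
]
dists, values = toy_sample_distributions(model, np.random.default_rng(0))
print(type(model[-1]).__name__)  # function
print(dists[-1].name)            # x
```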
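Several thoughts here.
- Consider using SphericalUniform for the first distribution.
- For Mixtures of the same type, consider using MixtureSameFamily.
- So the last distribution in the model could be something like MixtureSameFamily(Categorical(...), VonMisesFisher(...)).
- To get at the distributions the lambdas actually build, use ds, xs = jds.sample_distributions(), and look at ds[-1].components_distribution.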
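For intuition about the uniform prior on mu_g that SphericalUniform provides, here is a NumPy sketch of the standard way to sample uniformly on the unit sphere: normalize isotropic Gaussian draws. The helper sample_spherical_uniform is hypothetical and not a TFP function; the library counterpart would be tfd.SphericalUniform(dimension=D).

```python
import numpy as np


def sample_spherical_uniform(n, dimension, rng):
    """Draw n points uniformly on the unit sphere in R^dimension.

    An isotropic Gaussian vector divided by its norm is uniformly distributed
    on the sphere, because the Gaussian density depends only on the radius.
    """
    z = rng.standard_normal((n, dimension))
    return z / np.linalg.norm(z, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
pts = sample_spherical_uniform(100_000, 3, rng)
print(pts.shape)  # (100000, 3)
# Each coordinate should average about 0 and each squared coordinate about 1/3.
print(pts.mean(axis=0), (pts ** 2).mean(axis=0))
```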
Feel free to email [email protected] w/ questions too.