I am struggling to implement a model in which the concentration parameter of a Dirichlet variable depends on another variable.
The situation is the following:
A system fails due to faulty components (there are three components, only one fails at each test/observation).
The probability of failure of the components is dependent on the temperature.
Here is a (commented) short implementation of the situation:
import numpy as np
import pymc3 as pm
import theano.tensor as tt
# Temperature data : 3 cold temperatures and 3 warm temperatures
T_data = np.array([10, 12, 14, 80, 90, 95])
# Data of failures of 3 components : [0,0,1] means component 3 failed
F_data = np.array([[0, 0, 1],
                   [0, 0, 1],
                   [0, 0, 1],
                   [1, 0, 0],
                   [1, 0, 0],
                   [1, 0, 0]])
n_component = 3
# When temperature is cold : Component 3 fails
# When temperature is warm : Component 1 fails
# Component 2 never fails
# Number of observations :
n_obs = len(F_data)
# Each observed failure can be modeled as a Multinomial F ~ M(n_test, p) with parameters
# - n_test : number of tests per observation (fixed, here 1)
# - p : probability of failure of each component (shape (n_obs, 3))
# The probability of failure of components follows a Dirichlet distribution p ~ Dir(alpha) with parameters:
# - alpha : concentration (shape (n_obs, 3))
# The Dirichlet distribution ensures the probabilities sum to 1
# The alpha parameters (and the probability of failures) depend on the temperature alpha ~ a + b * T
# - a : bias term (shape (1,3))
# - b : describes temperature dependency of alpha (shape (1,3))
# The prior on "a" is a normal distribution with mean 1/2 and std 0.001
# a ~ N(1/2, 0.001)
# The prior on "b" is a normal distribution with mean 0 and std 0.001
# b ~ N(0, 0.001)
# Coding it all with pymc3
with pm.Model() as model:
    a = pm.Normal('a', 1/2, 1/(0.001**2), shape = n_component)
    b = pm.Normal('b', 0, 1/(0.001**2), shape = n_component)
    # I generate 3 alpha values (corresponding to the 3 components) for each of the 6 temperatures
    # I tried different ways to compute alpha but nothing worked out
    alphas = pm.Deterministic('alphas', a + b * tt.stack([T_data, T_data, T_data], axis=1))
    #alphas = pm.Deterministic('alphas', a + b[None, :] * T_data[:, None])
    #alphas = pm.Deterministic('alphas', a + tt.outer(T_data,b))
    # I think I should get 3 probabilities (corresponding to the 3 components) for each of the 6 temperatures
    #p = pm.Dirichlet('p', alphas, shape = n_component)
    p = pm.Dirichlet('p', alphas, shape = (n_obs,n_component))
    # Multinomial is observed and takes values from F_data
    F = pm.Multinomial('F', 1, p, observed = F_data)
with model:
    trace = pm.sample(5000)
I get the following error in the sample function:
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 73, in run
self._start_loop()
File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 113, in _start_loop
point, stats = self._compute_point()
File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 139, in _compute_point
point, stats = self._step_method.step(self._point)
File "/anaconda3/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 247, in step
apoint, stats = self.astep(array)
File "/anaconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 117, in astep
'might be misspecified.' % start.energy)
ValueError: Bad initial energy: inf. The model might be misspecified.
"""
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
ValueError: Bad initial energy: inf. The model might be misspecified.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-5-121fdd564b02> in <module>()
1 with model:
2 #start = pm.find_MAP()
----> 3 trace = pm.sample(5000)
/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
438 _print_step_hierarchy(step)
439 try:
--> 440 trace = _mp_sample(**sample_args)
441 except pickle.PickleError:
442 _log.warning("Could not pickle model, sampling singlethreaded.")
/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
988 try:
989 with sampler:
--> 990 for draw in sampler:
991 trace = traces[draw.chain - chain]
992 if trace.supports_sampler_stats and draw.stats is not None:
/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py in __iter__(self)
303
304 while self._active:
--> 305 draw = ProcessAdapter.recv_draw(self._active)
306 proc, is_last, draw, tuning, stats, warns = draw
307 if self._progress is not None:
/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
221 if msg[0] == 'error':
222 old = msg[1]
--> 223 six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
224 elif msg[0] == 'writing_done':
225 proc._readable = True
/anaconda3/lib/python3.6/site-packages/six.py in raise_from(value, from_value)
RuntimeError: Chain 1 failed.
Any suggestions?
Misspecified Model
The alphas are taking on nonpositive values under your current parameterization, whereas the Dirichlet distribution requires them to be positive, making the model misspecified.
In Dirichlet-Multinomial regression, one uses an exponential link function to mediate between the range of the linear model and the domain of the Dirichlet-Multinomial, namely,
alpha = exp(a + b * T)
There are details on this in the MGLM package documentation.
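To make the positivity issue concrete, here is a minimal sketch in plain Python. It uses the temperatures from the question; the coefficient values for a and b are made up for illustration and are not fitted estimates. It compares the identity link from the question with an exponential link.

```python
import math

# Temperatures from the question
T_data = [10, 12, 14, 80, 90, 95]

# Hypothetical coefficients, one per component (illustration only, not estimates)
a = [0.5, 0.5, 0.5]
b = [0.05, 0.0, -0.05]


def linear_alphas(t):
    """Identity link: alpha_k = a_k + b_k * t, which can be zero or negative."""
    return [a_k + b_k * t for a_k, b_k in zip(a, b)]


def exp_link_alphas(t):
    """Exponential link: alpha_k = exp(a_k + b_k * t), which is always positive."""
    return [math.exp(eta) for eta in linear_alphas(t)]


linear = [linear_alphas(t) for t in T_data]
linked = [exp_link_alphas(t) for t in T_data]

# The identity link produces invalid (nonpositive) Dirichlet concentrations ...
has_nonpositive = any(x <= 0 for row in linear for x in row)
# ... while the exponential link never does.
all_positive = all(x > 0 for row in linked for x in row)

print("identity link has nonpositive alphas:", has_nonpositive)
print("exp link alphas all positive:", all_positive)
```

Any real-valued a and b give valid concentrations under the exponential link, which is why the sampler can start from an arbitrary point without hitting an infinite energy for this reason.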
Dirichlet-Multinomial Regression Model
If we implement this model we can achieve decent model convergence and sampling.
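The model code itself is not shown here, so the following is only a sketch of what a Dirichlet-Multinomial regression with an exponential link could look like in PyMC3; it is not necessarily the exact model behind the results described below. The standardized temperature, the N(0, 1) priors, and the helper names standardize and build_dmr_model are assumptions made for illustration. NumPy and PyMC3 are imported inside the function so that the data preparation can run on its own.

```python
import statistics

# Data from the question
T_data = [10, 12, 14, 80, 90, 95]
F_data = [[0, 0, 1],
          [0, 0, 1],
          [0, 0, 1],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0]]
n_obs = len(F_data)
n_component = len(F_data[0])


def standardize(values):
    """Center to mean 0 and scale to population std 1 (assumed to help sampling)."""
    mean = statistics.mean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / std for v in values]


T_z = standardize(T_data)


def build_dmr_model():
    """Build (but do not sample) a Dirichlet-Multinomial regression model."""
    # Imported here so the data preparation above runs without PyMC3 installed.
    import numpy as np
    import pymc3 as pm

    T = np.asarray(T_z)[:, None]  # shape (n_obs, 1), broadcasts against shape (3,)
    F = np.asarray(F_data)

    with pm.Model() as model:
        # Assumed weakly informative priors on the standardized scale
        a = pm.Normal('a', mu=0, sd=1, shape=n_component)
        b = pm.Normal('b', mu=0, sd=1, shape=n_component)
        # Exponential link keeps every concentration strictly positive
        alphas = pm.Deterministic('alphas', pm.math.exp(a + b * T))
        p = pm.Dirichlet('p', a=alphas, shape=(n_obs, n_component))
        pm.Multinomial('F', n=1, p=p, observed=F)
    return model
```

Sampling would then be something like pm.sample(5000, tune=10000) inside a with build_dmr_model(): block. The draw and tuning counts are illustrative, and the way a higher target acceptance rate is passed to NUTS may differ between PyMC3 versions.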
Model Outcomes
The sampling in this model isn't perfect. For example, there are still a number of divergences even with the increased target acceptance rate and additional tuning.
Trace Plots
We can see the traces for a and b look good, and the mean locations make sense with the data.
Pair Plot
While correlation is less of a problem for NUTS, having uncorrelated posterior sampling is ideal. For the most part we're seeing low correlation, with some slight structure within the a components.
Posterior Plots
Finally, we can look at the posterior plots of p and confirm they make sense with the data.
Alternative Model
The advantage of the Dirichlet-Multinomial is handling overdispersion. It might be worth trying the simpler Multinomial Logistic Regression / Softmax Regression, since it runs significantly faster and doesn't exhibit any of the sampling problems coming up in the DMR model.
In the end, you could run both and perform model comparison to see if the Dirichlet-Multinomial really is adding explanatory value.
Model
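The code for this model is not shown, so here is a hedged sketch of a softmax (multinomial logistic) regression in PyMC3 for the data in the question. The standardized temperature, the N(0, 1) priors, and the names softmax and build_softmax_model are assumptions for illustration. The plain-Python softmax only demonstrates what the link does; the PyMC3 model itself relies on Theano's tt.nnet.softmax, which normalizes each row of a 2-D input.

```python
import math
import statistics

# Data from the question
T_data = [10, 12, 14, 80, 90, 95]
F_data = [[0, 0, 1],
          [0, 0, 1],
          [0, 0, 1],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0]]
n_component = len(F_data[0])

# Standardized temperature (an assumption, intended to help sampling)
mean_T = statistics.mean(T_data)
std_T = statistics.pstdev(T_data)
T_z = [(t - mean_T) / std_T for t in T_data]


def softmax(eta):
    """Map a vector of linear predictors to probabilities that sum to 1."""
    shift = max(eta)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in eta]
    total = sum(exps)
    return [e / total for e in exps]


def build_softmax_model():
    """Build (but do not sample) a softmax regression model."""
    # Imported here so the illustration above runs without PyMC3 installed.
    import numpy as np
    import pymc3 as pm
    import theano.tensor as tt

    T = np.asarray(T_z)[:, None]  # shape (n_obs, 1), broadcasts against shape (3,)
    F = np.asarray(F_data)

    with pm.Model() as model:
        a = pm.Normal('a', mu=0, sd=1, shape=n_component)
        b = pm.Normal('b', mu=0, sd=1, shape=n_component)
        # Row-wise softmax: p has shape (n_obs, n_component), each row sums to 1
        p = pm.Deterministic('p', tt.nnet.softmax(a + b * T))
        pm.Multinomial('F', n=1, p=p, observed=F)
    return model


# Made-up coefficients (a = 0, b = [1, 0, -1]) at the coldest temperature
example_p = softmax([1.0 * T_z[0], 0.0, -1.0 * T_z[0]])
print(example_p)
```

Here p is a deterministic function of a and b rather than a separate latent Dirichlet variable, so the sampler has fewer parameters to explore.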
Posterior Plots