Will JAGS evaluate all parent nodes of dcat, or only the one needed?


Say we have the following statement:

for (i in 1:N) {
    pi[i,1] <- ....
    pi[i,2] <- ....
    pi[i,3] <- ....
    ...
    pi[i,100] <- ...
    Y[i] ~ dcat(pi[i,])
}

Let's say that Y[1] = 5. Will JAGS evaluate all the pi[1,1:100] nodes, or only the one needed, i.e. pi[1,5]?

From my experience, it seems that JAGS inefficiently evaluates all of the parent nodes: my model sped up about 3x after I got rid of the dcat. I did have to use multiple for loops, though, one for each outcome of Y[i].

Now I realize that dcat in JAGS doesn't actually require that sum(pi[]) = 1; dcat will normalize pi[] so that it sums to 1. This means that it must evaluate all of the parent nodes.
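To spell out why normalization forces a full evaluation: since dcat normalizes its argument internally, passing unnormalized weights is equivalent to dividing by their sum yourself. A minimal sketch (w[i,], K and Y2 are hypothetical names, not from the model above):

for (i in 1:N) {
    # dcat normalizes internally: w[i,] need not sum to 1
    Y[i] ~ dcat(w[i,])
}
# ...which defines the same likelihood as normalizing explicitly:
for (i in 1:N) {
    for (k in 1:K) {
        p[i,k] <- w[i,k] / sum(w[i,])
    }
    Y2[i] ~ dcat(p[i,])
}

Computing sum(w[i,]) is exactly why every element of the weight vector must be looked at, not just the one indexed by the observed outcome.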

This is very sad. Is there any smart equivalent of dcat that will evaluate only the one parent node that is needed? What about WinBUGS and Stan?

2

There are 2 answers

2
Martyn Plummer

Your example does not quite have enough detail for me to answer, so I have added some expressions on the right-hand side:

for (i in 1:N) {
    pi[i,1] <- funx(alpha)
    pi[i,2] <- funy(alpha)
    pi[i,3] <- funz(beta)
    ...
    pi[i,100] <- funw(beta)
    Y[i] ~ dcat(pi[i,])
}

Suppose we are updating the node alpha. The sampler responsible for updating alpha needs to evaluate funx(alpha) and funy(alpha), but not funz(beta) or funw(beta) (assuming that beta is not a deterministic function of alpha). So in this case pi[i,1] and pi[i,2] are evaluated, but not pi[i,3] or pi[i,100]. Those other nodes retain their current values.

However, for the likelihood calculations we do have to dereference the current values of all the nodes pi[i,1] to pi[i,100] to calculate the sum and normalize pi. Dereferencing is cheap, but if you do it enough times then it becomes expensive. For example, if you have

for (i in 1:N) {
    for (j in 1:M) {
        pi[i,j] ~ dexp(1)
    }
    Y[i] ~ dcat(pi[i,])
}

then you have N*M*M dereferencing operations per iteration, which can soon add up: each of the N*M stochastic pi[i,j] nodes is updated once, and each update dereferences all M elements of pi[i,] to recompute the likelihood of Y[i]. With N = 1000 and M = 100, that is already 10^7 dereferences per iteration.

So I guess what you are asking for is a sampler that caches the sum of pi[i,] for the likelihood calculations and then updates it based only on the elements that have changed, avoiding the need to dereference the others. That is not available in JAGS, but it might be possible to work towards it in some future version.

0
Matt Denwood

I think you can do what you are asking by just using dbern, i.e.:

for(i in 1:N){
    pi[i,1] <- ...
    ...
    pi[i,100] <- ...

    Ones[i] ~ dbern(pi[i, Y[i]])
}

Where Ones[] is specified in the data as an N-length vector of 1s.

However, all of pi[] will still be calculated - it has to be, because it is a node in your model, and JAGS (or WinBUGS/Stan) has no way of telling which nodes you care about and which you don't. You may be able to avoid this by having one value of pi[] for each i and shifting the Y[i] index inside the right-hand side of the pi[i] equation - although, as Martyn says, your example doesn't give enough detail to determine whether this is possible.
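To make that last suggestion concrete, here is a minimal sketch of what shifting the Y[i] index inside the right-hand side could look like, assuming pi has a closed-form expression in its second index (ilogit, alpha, beta and x[,] are hypothetical placeholders, not from the question):

for (i in 1:N) {
    # One scalar node per observation: only the probability of the
    # observed category Y[i] is ever computed
    pi[i] <- ilogit(alpha + beta * x[i, Y[i]])
    Ones[i] ~ dbern(pi[i])
}

This only works when the expression for pi[i,j] can be written as a single function of the index, so that Y[i] can be substituted in directly; note also that dbern expects an actual probability, so the weights must already be normalized here rather than relying on dcat's internal normalization.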

Matt