Workaround / fallback value for tfp.distributions.Categorical.log_prob in tensorflow graph mode


Is there a way to keep tfp.distributions.Categorical.log_prob from raising an error when the input is an out-of-range label?

I am passing a batch of samples to the log_prob method, and some of them have the value n_categories + 1, which is what you get as a fallback value when you sample from a probability distribution of all zeros. Some of the probability distributions in my probs batch are all zeros**.

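```python
import tensorflow_probability as tfp

# Inside the decoder loop (a method of my model class):
dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)  # [batch, n_categories]; some rows are all zeros
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()  # out-of-range index for the all-zero rows
log_prob = distr.log_prob(pointer)  # log of the probability of choosing that action
```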

I don't care what value log_prob returns in those cases, because I will mask it out later and not use it. Is it possible to set a fallback value somehow? If not, is there any workaround to avoid the error being raised when I run this in graph mode (inside @tf.function)?

**This happens because I am doing stochastic decoding with an RNN over batches of variable-length sequences (a seq2seq task).

Answer by Brian Patton:

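If you can mask the log_prob, you could also mask the probs themselves: replace each all-zero row with a uniform distribution, say 1/n for each of the n categories. Every sample is then in range, so log_prob has nothing to fail on. Note that it's more numerically stable to use the logits parameterization of Categorical and drop the (presumably) upstream softmax activation.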