Is there a way to avoid tfp.distributions.Categorical.log_prob raising an error if the input is a label out of range?
I am passing a batch of samples to the log_prob method, and some of them have the value n_categories + 1, which is what you get as a fallback value when you sample from a probability distribution of all zeros. Some of the probability distributions in my probs batch are all zeros**.
dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action
I don't care what value I get from log_prob in those cases because I will mask it later and not use it. I'm not sure whether a fallback value can be implemented somehow. If not, is there any workaround to avoid an error being raised while I execute it in graph mode (with @tf.function)?
**This is because I am doing stochastic decoding with an RNN on batches of variable-length sequences, in a seq-to-seq task.
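If you can mask the log_prob afterwards, you can equally well mask the probs beforehand: replace each all-zero row with a valid distribution, say uniform with every entry 1 / n, so that sampling never produces an out-of-range label. Note that it is more numerically stable to use the logits parameterization of Categorical and drop the (presumably) upstream softmax activation.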