I'm trying to implement a beam search decoding strategy in a text generation model. This is the function that I am using to decode the output probabilities.
import torch

def beam_search_decoder(data, k):
    # each entry is [token_sequence, cumulative negative log-likelihood]
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate with every token in the vocabulary
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lower negative log-likelihood is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # keep only the k best candidates
        sequences = ordered[:k]
    return sequences
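To make it concrete, this is roughly how I call it on a single example, where the input is a (max_len, vocab_size) tensor of softmax probabilities (the toy tensor below is just for illustration):

import torch
import torch.nn.functional as F

# toy probabilities for one example: 5 steps, vocabulary of 10 tokens
probs = F.softmax(torch.randn(5, 10), dim=-1)

# returns the k best [token_sequence, negative_log_likelihood] pairs
for tokens, score in beam_search_decoder(probs, k=3):
    print(tokens, score.item())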
As you can see, this function is written with batch_size 1 in mind. Adding another loop over the batch dimension would push the algorithm to O(n^4), and it is already slow as it is. Is there any way to speed this function up? My model output is usually of size (32, 150, 9907), which follows the format (batch_size, max_len, vocab_size).
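For reference, this is the extra per-element loop I would have to add right now (the tensor below is a small stand-in for the real batched output, which has shape (32, 150, 9907)):

import torch
import torch.nn.functional as F

# toy batched output; the real output has shape (32, 150, 9907)
logits = torch.randn(4, 10, 20)
probs = F.softmax(logits, dim=-1)

# one Python-level call per batch element -- the loop I would like to avoid
results = [beam_search_decoder(probs[b], k=3) for b in range(probs.size(0))]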
You can use the pytorch-beam-search library (https://pypi.org/project/pytorch-beam-search/). It implements Beam Search, Greedy Search and sampling for PyTorch sequence models.
The following snippet implements a Transformer seq2seq model and uses it to generate predictions.
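The code below is adapted from the library's README; the names seq2seq.Index, seq2seq.Transformer and seq2seq.beam_search come from its documentation and may differ slightly between versions, so check the project page for the exact API.

from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index maps the vocabulary to integer ids to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)
model.fit(X, Y, epochs=100)

# Generate new predictions with beam search
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)
predictions, log_probabilities = seq2seq.beam_search(model, X_new)
output = [target_index.tensor2text(p) for p in predictions]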