Does anyone know of a way to generate a 'segment label' for a Tensor, given a unique value that represents segment boundaries within the Tensor?
For example, given a 1D input tensor where the value 1 represents a segment boundary,
x = torch.Tensor([5, 4, 1, 3, 6, 2])
the resulting segment label tensor should have the same shape, with values identifying the two segments (the boundary element itself belongs to the first segment):
segment_label = torch.Tensor([1, 1, 1, 2, 2, 2])
Likewise, for a batch of inputs, e.g. batch size = 3,
x = torch.Tensor([
    [5, 4, 1, 3, 6, 2],
    [9, 4, 5, 1, 8, 10],
    [10, 1, 5, 4, 8, 9]
])
the resulting segment label tensor (using 1 as the segment separator) should look something like this:
segment_label = torch.Tensor([
    [1, 1, 1, 2, 2, 2],
    [1, 1, 1, 1, 2, 2],
    [1, 1, 2, 2, 2, 2]
])
Context: I'm currently working with Fairseq's Transformer implementation in PyTorch for a seq2seq NLP task. I am looking for a way to incorporate BERT-like segment embeddings into the Transformer during the encoder's forward pass, rather than modifying an existing dataset used for translation tasks, such as language_pair_dataset.
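For illustration, BERT-style segment embeddings are commonly a learned lookup table indexed by the segment label and added to the token embeddings before the encoder layers. The sketch below assumes that design; `SegmentAwareEmbedding` is a hypothetical module, not part of Fairseq, and where it would plug into Fairseq's encoder forward pass depends on the Fairseq version. The labels here are computed with a cumulative sum over the separator mask.

```python
import torch
import torch.nn as nn


class SegmentAwareEmbedding(nn.Module):
    # Hypothetical module (not a Fairseq API): token embedding plus a
    # learned BERT-style segment embedding.
    def __init__(self, vocab_size: int, num_segments: int, embed_dim: int):
        super().__init__()
        self.tokens = nn.Embedding(vocab_size, embed_dim)
        # Segment labels start at 1, so index 0 is left unused and the
        # table gets num_segments + 1 rows.
        self.segments = nn.Embedding(num_segments + 1, embed_dim)

    def forward(self, src_tokens: torch.Tensor, segment_labels: torch.Tensor) -> torch.Tensor:
        # src_tokens, segment_labels: (batch, seq_len) integer tensors.
        # Returns (batch, seq_len, embed_dim).
        return self.tokens(src_tokens) + self.segments(segment_labels)


src = torch.tensor([
    [5, 4, 1, 3, 6, 2],
    [9, 4, 5, 1, 8, 10],
    [10, 1, 5, 4, 8, 9],
])
# Segment labels: 1-based, each separator (value 1) stays in the segment it closes.
is_sep = (src == 1).long()
labels = 1 + is_sep.cumsum(dim=-1) - is_sep

emb = SegmentAwareEmbedding(vocab_size=16, num_segments=2, embed_dim=8)
out = emb(src, labels)
print(out.shape)  # torch.Size([3, 6, 8])
```

The vocabulary size, segment count, and embedding width above are arbitrary example values.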
Thanks in advance!
You can use torch.cumsum to pull the trick; this results in the desired segment_label.
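One way to apply torch.cumsum here, sketched under the assumption (taken from the examples in the question) that each separator belongs to the segment it closes: cumulatively sum a mask of separator positions, subtract the mask so the count is exclusive, and add 1 so labels start at 1. The function name `segment_labels` and its `sep` parameter are illustrative only.

```python
import torch


def segment_labels(x: torch.Tensor, sep: float = 1) -> torch.Tensor:
    # Integer mask: 1 where the separator occurs, 0 elsewhere.
    is_sep = (x == sep).long()
    # cumsum counts separators seen so far (inclusive) along the last dim;
    # subtracting is_sep makes it exclusive, so each separator stays in the
    # segment it closes. Adding 1 makes labels start at 1.
    return 1 + is_sep.cumsum(dim=-1) - is_sep


x = torch.Tensor([
    [5, 4, 1, 3, 6, 2],
    [9, 4, 5, 1, 8, 10],
    [10, 1, 5, 4, 8, 9],
])
print(segment_labels(x).tolist())
# [[1, 1, 1, 2, 2, 2], [1, 1, 1, 1, 2, 2], [1, 1, 2, 2, 2, 2]]
```

Because the sum runs along dim=-1, the same function covers the 1D example and the batched (batch, seq_len) case. If the separator should instead start the next segment, dropping the `- is_sep` term gives that behavior.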