What is the role of RNNLanguageModel's forward method?

I'm reading a tutorial about character-based neural networks using the AllenNLP framework. The goal is to build a model that can complete a sentence. After the instance-building step, I want to train my model. I have the code below, but I can't understand the role of the forward function. Can anyone help, ideally with an example?

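```python
from typing import Dict

import torch
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits


class RNNLanguageModel(Model):
    def __init__(self,
                 embedder: TextFieldEmbedder,
                 hidden_size: int,
                 max_len: int,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)

        self.embedder = embedder

        # initialize a Seq2Seq encoder, LSTM
        # (sizes come from the constructor arguments, not undefined global constants)
        self.rnn = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedder.get_output_dim(), hidden_size, batch_first=True))

        # project every hidden state onto one score (logit) per vocabulary entry
        self.hidden2out = torch.nn.Linear(in_features=self.rnn.get_output_dim(),
                                          out_features=vocab.get_vocab_size('tokens'))
        self.hidden_size = hidden_size
        self.max_len = max_len

    def forward(self,
                input_tokens: Dict[str, torch.Tensor],
                output_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        '''
        This is the main process of the Model, where the actual computation happens.
        Batches of Instances are fed to the forward method.
        It takes dicts of tensors as input, with the same keys as the fields in your
        Instance (input_tokens, output_tokens); output_tokens holds the target characters.
        It returns a dictionary of outputs; here it only contains the 'loss', which the
        trainer backpropagates. Metrics are normally reported via get_metrics().
        '''
        # 1 for real tokens, 0 for padding: (batch, seq_len)
        mask = get_text_field_mask(input_tokens)
        # token ids -> vectors: (batch, seq_len, embedding_dim)
        embeddings = self.embedder(input_tokens)
        # LSTM hidden state at every position: (batch, seq_len, hidden_size)
        rnn_hidden = self.rnn(embeddings, mask)
        # scores for the next character at every position: (batch, seq_len, vocab_size)
        out_logits = self.hidden2out(rnn_hidden)
        # cross-entropy against the target characters, ignoring padded positions
        loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)

        return {'loss': loss}
```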

Answer by petew:

The forward() method is where we implement the "forward pass" of a model. This determines how inputs (your data) flow through your model to produce outputs and a loss value.
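To make the "loss value" concrete: the last line of the question's forward() calls sequence_cross_entropy_with_logits, which, as far as I understand its default settings, computes the average cross-entropy over the non-padded positions of each sequence and then averages those per-sequence losses over the batch. The sketch below re-implements that idea in plain Python on a toy batch. The helper name masked_sequence_cross_entropy and the toy numbers are made up for illustration; the real AllenNLP function works on tensors and has extra options (such as label smoothing) that are ignored here.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # numerically stable log-softmax over one vector of scores
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_sequence_cross_entropy(logits: List[List[List[float]]],
                                  targets: List[List[int]],
                                  mask: List[List[int]]) -> float:
    """Hypothetical helper: logits[b][t][v], targets[b][t], mask[b][t] in {0, 1}."""
    per_sequence_losses = []
    for seq_logits, seq_targets, seq_mask in zip(logits, targets, mask):
        total, count = 0.0, 0
        for step_logits, target, keep in zip(seq_logits, seq_targets, seq_mask):
            if keep:  # padded positions contribute nothing
                total += -log_softmax(step_logits)[target]
                count += 1
        if count:
            per_sequence_losses.append(total / count)
    # average over the sequences that contain at least one real token
    return sum(per_sequence_losses) / len(per_sequence_losses)


# toy batch: 2 sequences, 3 time steps, a vocabulary of 3 characters
logits = [
    [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]],
    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [5.0, 0.0, 0.0]],
]
targets = [[0, 1, 2], [1, 2, 0]]
mask = [[1, 1, 1], [1, 1, 0]]  # the last step of the 2nd sequence is padding

loss = masked_sequence_cross_entropy(logits, targets, mask)
print(loss)
```

Changing the scores at the padded position leaves the loss unchanged, which is exactly what the mask in forward() is for.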

Any class that inherits from PyTorch's torch.nn.Module, such as AllenNLP's Model class, must implement the forward() method.
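In practice you rarely call forward() yourself: you call the model object, and torch.nn.Module's __call__ runs forward() for you (together with any registered hooks). The toy sketch below shows this, plus roughly what a training loop such as AllenNLP's Trainer does with the {'loss': ...} dictionary that forward() returns: backpropagate the loss and take an optimizer step. The TinyModel class, its sizes, and the random data are hypothetical, chosen only for illustration.

```python
from typing import Dict

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical toy model: predicts one of 3 classes from a 4-dim input."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 3)

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        # forward pass: inputs -> scores -> loss, returned as a dict like an AllenNLP Model
        logits = self.linear(features)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {'logits': logits, 'loss': loss}


torch.manual_seed(0)
model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

features = torch.randn(8, 4)        # a batch of 8 examples
labels = torch.randint(0, 3, (8,))  # their gold classes

losses = []
for step in range(5):
    output = model(features, labels)  # __call__ dispatches to forward()
    optimizer.zero_grad()
    output['loss'].backward()         # gradients flow back through the forward computation
    optimizer.step()
    losses.append(output['loss'].item())
print(losses)
```

Calling model.forward(...) directly would compute the same values but skip hooks, which is why model(...) is the usual idiom.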

AllenNLP is ultimately just a higher-level wrapper around PyTorch, so if any of this is confusing, I suggest first getting more familiar with PyTorch itself: https://pytorch.org/tutorials/
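Because AllenNLP sits on top of PyTorch, the question's forward() can be approximated in plain PyTorch. In the rough sketch below, torch.nn.Embedding stands in for the TextFieldEmbedder, the mask assumes padding id 0 (AllenNLP's default vocabulary reserves index 0 for padding), and the masked loss is computed by hand instead of with sequence_cross_entropy_with_logits. It also assumes, as the question's use of the input mask for the targets suggests, that output_tokens is the input shifted left by one character. The class name PlainRNNLanguageModel, the sizes, and the toy ids are hypothetical.

```python
from typing import Dict

import torch


class PlainRNNLanguageModel(torch.nn.Module):
    """Hypothetical plain-PyTorch counterpart of the question's RNNLanguageModel."""

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size: int) -> None:
        super().__init__()
        # index 0 is treated as padding, mirroring AllenNLP's default vocabulary
        self.embedder = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        self.rnn = torch.nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.hidden2out = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor, output_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
        mask = (input_ids != 0).float()           # (batch, seq_len): 1 = real char, 0 = padding
        embeddings = self.embedder(input_ids)     # (batch, seq_len, embedding_size)
        # padding sits at the end, so a left-to-right LSTM's outputs at real
        # positions are unaffected by it; masking the loss is enough here
        rnn_hidden, _ = self.rnn(embeddings)      # (batch, seq_len, hidden_size)
        out_logits = self.hidden2out(rnn_hidden)  # (batch, seq_len, vocab_size)

        # per-position cross-entropy; cross_entropy expects the class dimension second
        token_loss = torch.nn.functional.cross_entropy(
            out_logits.transpose(1, 2), output_ids, reduction='none')  # (batch, seq_len)
        # average over real positions in each sequence, then over non-empty sequences
        per_sequence = (token_loss * mask).sum(1) / mask.sum(1).clamp(min=1)
        non_empty = (mask.sum(1) > 0).float().sum().clamp(min=1)
        return {'logits': out_logits, 'loss': per_sequence.sum() / non_empty}


torch.manual_seed(0)
model = PlainRNNLanguageModel(vocab_size=10, embedding_size=8, hidden_size=16)
# two sequences of character ids; the second one is padded with 0 at the end
input_ids = torch.tensor([[4, 5, 6, 7], [4, 5, 6, 0]])
output_ids = torch.tensor([[5, 6, 7, 8], [5, 6, 7, 0]])  # inputs shifted left by one
output = model(input_ids, output_ids)
print(output['logits'].shape)  # torch.Size([2, 4, 10])
print(output['loss'])
```

The wrapper classes in the question (PytorchSeq2SeqWrapper, TextFieldEmbedder) mainly save you this kind of bookkeeping with masks and tensor shapes.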