Encoder-Decoder with Hugging Face Models

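I want to create an encoder-decoder model with the following structure: a BERT encoder whose CLS embedding passes through a linear layer and is then fed into an OPT-125M decoder.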

My goal is to implement the idea from the In-Context Autoencoder (ICAE) paper (https://arxiv.org/abs/2307.06945) and test it myself.

I would like to do this with the Hugging Face Transformers library and PyTorch. It cuts down the programming effort considerably, and I would not know where to find raw implementations of OPT-125M or BERT, let alone how to implement them by hand. Hugging Face's optimizations also matter, because I want to run this on an ordinary desktop PC.

My problem is that OPT-125M expects tokenized input (integer token IDs), and I cannot find a way to bypass the tokenizer.

Does anyone know a way to feed the output of the linear layer directly into OPT-125M without tokenizing it, or an alternative to Hugging Face that performs just as well?
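A minimal sketch of one possible route, assuming a reasonably recent `transformers` version: `OPTForCausalLM.forward` accepts an `inputs_embeds` argument as an alternative to `input_ids`, so any float tensor of shape `(batch, seq_len, embed_dim)` can bypass the tokenizer and the embedding lookup. To keep it runnable offline, the sketch uses a tiny, randomly initialised `OPTConfig` (all sizes are illustrative assumptions, not OPT-125M's real ones) and checks that looking the embeddings up manually gives the same logits as passing token IDs.

```python
import torch
from transformers import OPTConfig, OPTForCausalLM

# Tiny random OPT so the sketch runs without downloading weights (sizes are
# illustrative assumptions). With the real checkpoint you would instead use
# OPTForCausalLM.from_pretrained("facebook/opt-125m").
config = OPTConfig(
    vocab_size=100,
    hidden_size=32,
    word_embed_proj_dim=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=64,
    max_position_embeddings=64,
)
model = OPTForCausalLM(config).eval()  # eval() disables dropout

input_ids = torch.tensor([[2, 10, 11, 12]])  # (batch=1, seq_len=4)

with torch.no_grad():
    # Path 1: the usual route, token IDs are looked up inside the model.
    logits_from_ids = model(input_ids=input_ids).logits

    # Path 2: do the lookup ourselves and pass float vectors via inputs_embeds.
    # Any (batch, seq_len, embed_dim) float tensor can go here, for example
    # the output of a linear layer.
    embeds = model.get_input_embeddings()(input_ids)  # (1, 4, 32)
    logits_from_embeds = model(inputs_embeds=embeds).logits

print(logits_from_embeds.shape)  # torch.Size([1, 4, 100])
print(torch.allclose(logits_from_ids, logits_from_embeds))  # True
```

Because both paths run the same decoder on identical vectors, the logits match; the only requirement for a custom tensor is that its last dimension equals `model.get_input_embeddings().embedding_dim`.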

This is the skeleton code I have written so far; it raises an error because of the wrong input to OPT:

import torch
from torch import nn
from transformers import BertTokenizer, BertModel, AutoModelForCausalLM

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, input_text):
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # CLS token embeddings

class LinearTransformation(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearTransformation, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    def forward(self, x):
        # input_ids expects integer token IDs, but x is the float output of
        # the linear layer, so this call raises an error
        output = self.model(input_ids=x)
        return output

class BertOptPipeline(nn.Module):
    def __init__(self):
        super(BertOptPipeline, self).__init__()
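        self.encoder = Encoder()
        self.decoder = Decoder()
        # BERT-base hidden size (768) -> OPT-125M token-embedding size (also 768)
        self.linear_transformation = LinearTransformation(
            768, self.decoder.model.get_input_embeddings().embedding_dim)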

    def forward(self, input_text):
        encoded = self.encoder(input_text)
        transformed = self.linear_transformation(encoded)
        print(transformed.shape)
        # Further processing may be needed here to match the decoder's input requirements
        decoded = self.decoder(transformed)
        return decoded

pipeline = BertOptPipeline()
input_text = "thank you for your help"
output = pipeline(input_text)
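For comparison, here is a sketch of the same BERT, linear layer, OPT pipeline wired through `inputs_embeds`, assuming a reconstruction objective: the projected CLS vector is prepended as a single "memory" embedding, followed by the OPT embeddings of the text to reconstruct, and `labels` uses the standard ignore index `-100` for the memory position. The class name `BertOptAutoencoder`, the argument names, and the tiny random configs (used so it runs offline) are all hypothetical choices for illustration; this is not the ICAE paper's exact method.

```python
import torch
from torch import nn
from transformers import BertConfig, BertModel, OPTConfig, OPTForCausalLM

# Tiny random models (hypothetical sizes) so the sketch runs offline. For real
# use: BertModel.from_pretrained("bert-base-uncased") and
# OPTForCausalLM.from_pretrained("facebook/opt-125m"), plus their tokenizers.
bert = BertModel(BertConfig(vocab_size=100, hidden_size=48, num_hidden_layers=2,
                            num_attention_heads=4, intermediate_size=96))
opt = OPTForCausalLM(OPTConfig(vocab_size=120, hidden_size=32, word_embed_proj_dim=32,
                               num_hidden_layers=2, num_attention_heads=4, ffn_dim=64,
                               max_position_embeddings=64))


class BertOptAutoencoder(nn.Module):
    """Hypothetical: compress text into one BERT vector, have OPT reconstruct it."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Map BERT's hidden size onto OPT's token-embedding size
        # (both are 768 for bert-base-uncased and opt-125m).
        self.proj = nn.Linear(encoder.config.hidden_size,
                              decoder.get_input_embeddings().embedding_dim)

    def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        cls = self.encoder(input_ids=enc_input_ids,
                           attention_mask=enc_attention_mask).last_hidden_state[:, 0, :]
        memory = self.proj(cls).unsqueeze(1)                          # (B, 1, D)
        target = self.decoder.get_input_embeddings()(dec_input_ids)  # (B, T, D)
        inputs_embeds = torch.cat([memory, target], dim=1)            # (B, 1+T, D)
        # Labels must align with inputs_embeds; the memory slot gets -100 so it
        # is never used as a prediction target.
        ignore = torch.full((dec_input_ids.size(0), 1), -100,
                            dtype=torch.long, device=dec_input_ids.device)
        labels = torch.cat([ignore, dec_input_ids], dim=1)
        # OPTForCausalLM shifts labels internally and returns a language-modelling loss.
        return self.decoder(inputs_embeds=inputs_embeds, labels=labels)


model = BertOptAutoencoder(bert, opt)
enc_ids = torch.randint(0, 100, (2, 7))   # stand-in for BERT token IDs
enc_mask = torch.ones_like(enc_ids)
dec_ids = torch.randint(3, 120, (2, 5))   # stand-in for OPT token IDs of the same text
out = model(enc_ids, enc_mask, dec_ids)
print(out.logits.shape)  # torch.Size([2, 6, 120])
out.loss.backward()      # gradients reach BERT, the projection and OPT
```

Since BERT and OPT use different vocabularies, real inputs need two tokenizers: BERT's for `enc_input_ids` and OPT's (for example `AutoTokenizer.from_pretrained("facebook/opt-125m")`) for `dec_input_ids`. With padded batches you would also pass an `attention_mask` to OPT and set padded label positions to `-100`; the sketch omits that for brevity.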

Thanks for your help!
