How to customize the number of encoders/decoders in a pre-trained transformer


I am fine-tuning a pre-trained transformer with the Hugging Face `transformers` library in Python to perform text summarization, and I would like to compare the performance of the fine-tuned BART model with different numbers of encoder layers. My question is: how can I customize the number of encoder layers? The default model has 12 encoder layers; what if, say, I want to keep only the first 6? I found the following documentation for BART, but I have no idea how to adapt it to my code (see below). I am new to ML and NLP, so I'd be grateful for a detailed explanation with code. Thank you!

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained(model_checkpoints)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoints)
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# preprocessing step omitted
# tokenized_data = preprocessed data

args = Seq2SeqTrainingArguments(
    'conversation-summ',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    eval_accumulation_steps=1,
    fp16=True,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['validation'],
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_rouge
)

trainer.train()
Answer by inverted_index:

Changing the number of encoder layers in a pre-trained BART model means modifying the model's architecture after it has been loaded. First, load the pre-trained model:

from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large')
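Because `facebook/bart-large` is a large download, the sketch below instead builds a tiny, randomly initialised BART from a `BartConfig` (the sizes are arbitrary assumptions, chosen only so it runs quickly offline) to show where the layers live. `model.model.encoder.layers` and `model.model.decoder.layers` are `torch.nn.ModuleList` objects whose lengths match `config.encoder_layers` and `config.decoder_layers`; for `facebook/bart-large` both of those values are 12.

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny, randomly initialised BART (illustrative sizes only, not bart-large)
config = BartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=4, decoder_layers=4,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = BartForConditionalGeneration(config)

# model.model is the BartModel; it holds the encoder and decoder stacks
encoder_layers = model.model.encoder.layers
decoder_layers = model.model.decoder.layers

print(isinstance(encoder_layers, torch.nn.ModuleList))  # True
print(len(encoder_layers), config.encoder_layers)  # 4 4
print(len(decoder_layers), config.decoder_layers)  # 4 4
```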

Next, access the model's encoder and keep only its first 6 layers. `model.model.encoder.layers` is a `torch.nn.ModuleList`, so it can be sliced like a Python list:

# Keeping only the first 6 encoder layers
model.model.encoder.layers = model.model.encoder.layers[:6]
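Slicing a `ModuleList` returns a new `ModuleList` that holds the very same layer objects, so the layers you keep retain their pre-trained weights and the dropped ones are simply discarded. The sketch below checks this on a tiny random BART (again an assumption made so it runs offline), keeping 2 of its 4 encoder layers. Note that `config.encoder_layers` still reports 4 after the slice, which is why the configuration needs updating as well.

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny, randomly initialised BART with 4 encoder layers (illustrative sizes only)
config = BartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=4, decoder_layers=4,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = BartForConditionalGeneration(config)

original_layers = list(model.model.encoder.layers)
params_before = sum(p.numel() for p in model.parameters())

# Same operation as in the answer, keeping 2 of the 4 layers of this tiny model
model.model.encoder.layers = model.model.encoder.layers[:2]

kept = model.model.encoder.layers
params_after = sum(p.numel() for p in model.parameters())

print(type(kept).__name__, len(kept))  # ModuleList 2
# The kept layers are the same module objects, so their weights are untouched
print(all(new is old for new, old in zip(kept, original_layers[:2])))  # True
print(params_after < params_before)  # True
print(model.config.encoder_layers)  # 4 (not updated automatically)
```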

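Then update the configuration so that it matches the new architecture. Otherwise, a checkpoint saved with `save_pretrained()` and reloaded with `from_pretrained()` would be rebuilt with the original number of encoder layers: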

model.config.encoder_layers = 6
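Putting the steps together, here is a minimal sketch of a helper, `truncate_bart_layers` (a hypothetical name, not part of `transformers`), that trims the encoder and, optionally, the decoder in place and keeps the config in sync. It assumes a BART-style model exposing `model.model.encoder.layers` and `model.model.decoder.layers`, and once more uses a tiny random `BartConfig` so it runs offline. It runs one forward pass and checks that a `save_pretrained`/`from_pretrained` round trip rebuilds the truncated architecture.

```python
import tempfile

import torch
from transformers import BartConfig, BartForConditionalGeneration


def truncate_bart_layers(model, num_encoder_layers, num_decoder_layers=None):
    """Keep only the first N encoder (and optionally decoder) layers, in place."""
    encoder = model.model.encoder
    if not 0 < num_encoder_layers <= len(encoder.layers):
        raise ValueError("num_encoder_layers out of range")
    encoder.layers = encoder.layers[:num_encoder_layers]
    model.config.encoder_layers = num_encoder_layers  # keep config in sync

    if num_decoder_layers is not None:
        decoder = model.model.decoder
        if not 0 < num_decoder_layers <= len(decoder.layers):
            raise ValueError("num_decoder_layers out of range")
        decoder.layers = decoder.layers[:num_decoder_layers]
        model.config.decoder_layers = num_decoder_layers
    return model


# Tiny, randomly initialised BART (illustrative sizes only)
config = BartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=4, decoder_layers=4,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = BartForConditionalGeneration(config)
truncate_bart_layers(model, num_encoder_layers=2, num_decoder_layers=3)
model.eval()

# Forward pass: BART derives decoder inputs from input_ids when none are given
input_ids = torch.tensor([[0, 5, 6, 7, 2]])
with torch.no_grad():
    logits = model(input_ids=input_ids).logits
print(logits.shape)  # torch.Size([1, 5, 100])

# Round trip: the saved config now describes the truncated model
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    reloaded = BartForConditionalGeneration.from_pretrained(tmp)
    reloaded_enc = len(reloaded.model.encoder.layers)
    reloaded_dec = len(reloaded.model.decoder.layers)
print(reloaded_enc, reloaded_dec)  # 2 3
```

With the question's code, one would call, for example, `truncate_bart_layers(model, num_encoder_layers=6)` right after `AutoModelForSeq2SeqLM.from_pretrained(model_checkpoints)` and before the collator and `Seq2SeqTrainer` are created, then fine-tune as usual. Passing `encoder_layers=6` directly to `from_pretrained` should have a similar effect, because keyword arguments that match config attributes override the loaded config, so only the weights of encoder layers 0 to 5 are used.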