Resume training from a checkpoint with different hyperparameters when training with PEFT and transformers


How can I resume training from a checkpoint with a different hyperparameter configuration when using the transformers library? In the example below, no matter what I change in the TrainingArguments, the new values are overridden by whatever was saved in the checkpoint. As far as I can tell, transformers has no supported way to change training arguments when resuming from a checkpoint. A few settings, such as the eval settings, batch size and save_steps, can be overridden by editing the checkpoint's JSON config, but other hyperparameters cannot.
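To make the JSON edit concrete, here is a minimal sketch of what I mean. It assumes the JSON file in question is the `trainer_state.json` that the Trainer writes into each checkpoint directory, and that the keys being changed already exist in it (which keys are present depends on the transformers version). The helper name `patch_trainer_state` is made up for illustration.

```python
import json
from pathlib import Path


def patch_trainer_state(checkpoint_dir, **overrides):
    """Overwrite selected keys in a checkpoint's trainer_state.json.

    Assumption: the checkpoint directory contains a trainer_state.json file.
    Only keys that already exist in the file may be overridden, so a typo
    cannot silently add a field the Trainer never reads.
    """
    state_path = Path(checkpoint_dir) / "trainer_state.json"
    state = json.loads(state_path.read_text())

    unknown = set(overrides) - set(state)
    if unknown:
        raise KeyError(f"keys not present in {state_path.name}: {sorted(unknown)}")

    state.update(overrides)
    state_path.write_text(json.dumps(state, indent=2))
    return state
```

For example, `patch_trainer_state("/content/latest_checkpoint/", save_steps=50, eval_steps=50)` would be run before calling `trainer.train(resume_from_checkpoint=...)`. This only touches the bookkeeping stored in that file; it does nothing about the saved optimizer and scheduler state, which is why it does not help with things like the learning rate.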

With a non-PEFT model, you could just save the entire model from the checkpoint, load it, and call trainer.train() on it to get this behaviour. How can I do the same with a PEFT setup?
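For reference, the non-PEFT workaround I have in mind looks roughly like the sketch below: load the full model weights from the checkpoint directory and start a fresh run with new arguments, without passing `resume_from_checkpoint`, so no saved optimizer or scheduler state is restored. The function names, the default values and the `output_dir` argument are placeholders of mine, not anything from the library.

```python
from typing import Any


def fresh_run_kwargs(output_dir: str, **overrides: Any) -> dict:
    """Build TrainingArguments keyword arguments for a fresh (non-resumed) run."""
    kwargs = {
        "output_dir": output_dir,
        "learning_rate": 5e-5,
        "warmup_steps": 100,
        "max_steps": 5000,
    }
    kwargs.update(overrides)  # new hyperparameters win over the defaults
    return kwargs


def train_from_saved_weights(checkpoint_dir, output_dir, train_dataset,
                             data_collator, **overrides):
    """Load full (non-PEFT) model weights from a checkpoint and train afresh."""
    # Imported lazily so fresh_run_kwargs can be used without transformers.
    import transformers

    # Assumption: checkpoint_dir holds a complete model (config + weights).
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    args = transformers.TrainingArguments(**fresh_run_kwargs(output_dir, **overrides))
    trainer = transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    # No resume_from_checkpoint here: only the weights carry over, while the
    # optimizer, scheduler and step counter all start from scratch.
    trainer.train()
    return trainer
```

The trade-off is that the step count and optimizer state are lost, which is acceptable for my purposes. What I can't work out is the equivalent when the checkpoint only contains adapter weights.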

from peft import prepare_model_for_kbit_training

# `model` is the quantized base model loaded earlier (not shown)
model.gradient_checkpointing_enable()
ft_model = prepare_model_for_kbit_training(model)

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "w1",
        "w2",
        "w3",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

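# Wrap the prepared model with the LoRA adapters defined in `config`
ft_model = get_peft_model(ft_model, config)
print_trainable_parameters(ft_model)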
ft_model = accelerator.prepare_model(ft_model)

import transformers
from datetime import datetime

tokenizer.pad_token = tokenizer.eos_token

learning_rate = 5e-5  
warmup_steps = 100

gradient_accumulation_steps = 2  

trainer = transformers.Trainer(
    model=ft_model,  # train the prepared PEFT model, not the bare base model
    callbacks=[upload_checkpoint_callback],
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=warmup_steps,
        per_device_train_batch_size=8,
        gradient_checkpointing=True,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_steps=5000,
        learning_rate=learning_rate,
        logging_steps=10,
        fp16=True,
        optim="paged_adamw_8bit",
        logging_dir="/content/logs",       
        save_strategy="steps",      
        save_steps=10,              
        evaluation_strategy="steps", 
        eval_steps=10,               
        load_best_model_at_end=True,
        report_to="wandb",           
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"  
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

model.config.use_cache = False
trainer.train(resume_from_checkpoint="/content/latest_checkpoint/")
