Empty value for state['args'] when I load a model with BARTModel.from_pretrained()


First, I trained the model with this script:

```bash
module load libffi
source $HOME/env38/bin/activate

TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=2

BART_PATH=models/BART_models/bart.large/model.pt
DATA_PATH=wikihow-cmlm-dataset/wikihow-cmlm-dataset-bin
SAVE_DIR=wikihow-cmlm/wikihow-cmlm-model/
mkdir -p $SAVE_DIR

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train $DATA_PATH \
    --max-epoch 3 \
    --restore-file $BART_PATH \
    --save-dir $SAVE_DIR \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```

Then, when I load the trained model with the `fairseq.models.bart.BARTModel.from_pretrained()` method, I get this error:

```
AttributeError                            Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 bart = BARTModel.from_pretrained('models/wikihow-cmlm-model',
      2                                  checkpoint_file='checkpoint3.pt',
      3                                  data_name_or_path=DATA_NAME_OR_PATH)

File ~/miniconda3/lib/python3.8/site-packages/fairseq/models/bart/model.py:115, in BARTModel.from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, bpe, **kwargs)
    104 @classmethod
    105 def from_pretrained(
    106     cls,
   (...)
    111     **kwargs,
    112 ):
    113     from fairseq import hub_utils
--> 115     x = hub_utils.from_pretrained(
    116         model_name_or_path,
    117         checkpoint_file,
    118         data_name_or_path,
    119         archive_map=cls.hub_models(),
    120         bpe=bpe,
    121         load_checkpoint_heads=True,
    122         **kwargs,
    123     )
    124     return BARTHubInterface(x["args"], x["task"], x["models"][0])

File ~/miniconda3/lib/python3.8/site-packages/fairseq/hub_utils.py:70, in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
     67 if "user_dir" in kwargs:
     68     utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
---> 70 models, args, task = checkpoint_utils.load_model_ensemble_and_task(
     71     [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
     72     arg_overrides=kwargs,
     73 )
     75 return {
     76     "args": args,
     77     "task": task,
     78     "models": models,
     79 }

File ~/miniconda3/lib/python3.8/site-packages/fairseq/checkpoint_utils.py:280, in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards)
    278 if not PathManager.exists(filename):
    279     raise IOError("Model file not found: {}".format(filename))
--> 280 state = load_checkpoint_to_cpu(filename, arg_overrides)
    281 if shard_idx == 0:
    282     args = state["args"]

File ~/miniconda3/lib/python3.8/site-packages/fairseq/checkpoint_utils.py:232, in load_checkpoint_to_cpu(path, arg_overrides)
    230 if arg_overrides is not None:
    231     for arg_name, arg_val in arg_overrides.items():
--> 232         setattr(args, arg_name, arg_val)
    233 state = _upgrade_state_dict(state)
    234 return state

AttributeError: 'NoneType' object has no attribute 'bpe'
```
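The last two frames show the mechanism: `load_checkpoint_to_cpu` takes `args` from the checkpoint and then calls `setattr` for every entry in `arg_overrides`, and `BARTModel.from_pretrained` forwards `bpe` into those overrides. A minimal sketch of that loop (a simplified stand-in written for illustration, not fairseq's actual code; `apply_overrides` is a hypothetical name and `"gpt2"` is just an example value) shows why a `None` value for `args` produces exactly this message, presumably naming `bpe` because it is the first override applied:

```python
import argparse
from typing import Any, Dict, Optional


def apply_overrides(args: Optional[argparse.Namespace],
                    overrides: Optional[Dict[str, Any]]) -> Optional[argparse.Namespace]:
    """Simplified stand-in for the override loop at checkpoint_utils.py:230-232.

    Hypothetical helper for illustration only; fairseq's real function does more.
    """
    if overrides is not None:
        for arg_name, arg_val in overrides.items():
            # If args is None (as with the failing checkpoint), this setattr
            # raises AttributeError on the first override, here 'bpe'.
            setattr(args, arg_name, arg_val)
    return args


# A checkpoint whose 'args' entry is a Namespace accepts the overrides.
ok = apply_overrides(argparse.Namespace(arch="bart_large"), {"bpe": "gpt2"})
print(ok.bpe)  # gpt2

# A checkpoint whose 'args' entry is None fails the same way as the traceback.
try:
    apply_overrides(None, {"bpe": "gpt2"})
except AttributeError as err:
    print(err)
```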

I then investigated what was going wrong and found that, when I load the checkpoint file, the state dictionary does contain the key 'args', but its value is None.
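A minimal sketch of that inspection, assuming the checkpoint is an ordinary `torch.save` dictionary. The helper names are hypothetical, and the `cfg` key is an assumption worth checking rather than a confirmed cause: newer fairseq releases appear to store the training configuration under `cfg` and leave `args` as `None`, so the sketch reports which of the two keys is populated.

```python
from typing import Any, Dict


def load_checkpoint_state(path: str) -> Dict[str, Any]:
    """Load a checkpoint file onto the CPU as a plain dict (hypothetical helper)."""
    import torch  # imported here so the summary helper below works without torch

    # On PyTorch 2.6 and later, torch.load defaults to weights_only=True, which
    # may reject full training checkpoints; for a checkpoint you trust, passing
    # weights_only=False is then needed.
    return torch.load(path, map_location="cpu")


def summarize_config_keys(state: Dict[str, Any]) -> Dict[str, Any]:
    """Report where (if anywhere) a checkpoint dict keeps its training config."""
    return {
        "keys": sorted(state.keys()),
        "has_args_key": "args" in state,
        "args_is_none": state.get("args") is None,
        "has_cfg": state.get("cfg") is not None,
    }


# Example usage (paths from this post), comparing the failing and working checkpoints:
# print(summarize_config_keys(load_checkpoint_state("models/wikihow-cmlm-model/checkpoint3.pt")))
# print(summarize_config_keys(load_checkpoint_state("models/bart.large/model.pt")))
```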

This code reproduces the error:

```python
from fairseq.models.bart import BARTModel

CMLM_MODEL_PATH = 'models/wikihow-cmlm-model'
MLM_MODEL_PATH = 'models/bart.large'

# DATA_NAME_OR_PATH = 'models/wikihow-bin'

bart = BARTModel.from_pretrained(CMLM_MODEL_PATH,
                                 checkpoint_file='checkpoint3.pt',
                                 data_name_or_path=CMLM_MODEL_PATH)
```

and the error above is raised.

How can I fix this? Thank you for any help.

I used exactly the same approach to load a similar model, and it works. The contents of the two directories, MLM_MODEL_PATH and CMLM_MODEL_PATH, are identical except for the names of the model .pt files.

```python
prior_bart = BARTModel.from_pretrained(MLM_MODEL_PATH,
                                       checkpoint_file='model.pt',
                                       data_name_or_path=MLM_MODEL_PATH)
```

So I think the reason is that the arguments were not saved when I trained the model. However, I checked the fairseq-train documentation and could not find any option for saving the arguments that I might have left out.
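One difference between the two steps is visible in the post itself: the training script activates `$HOME/env38`, while the traceback comes from a fairseq installed under `~/miniconda3`. If the two environments have different fairseq releases, the checkpoint layout may not match what the loading code expects (an assumption, not something confirmed here). A minimal sketch for printing the installed versions, to run once in each environment; `installed_versions` is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


# Run once with $HOME/env38 activated and once with ~/miniconda3's Python.
print(installed_versions(["fairseq", "torch"]))
```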

Thanks in advance for any replies.
