How can I use decaying learning rate in DeepSpeed?


I am training dolly2.0.

When I do so, I get the following output from the terminal:

[Screenshot: dolly2.0 training output]

If I run the same training with DeepSpeed, I notice that the learning rate does not change:

[Screenshot: DeepSpeed training output]

Why doesn't the learning rate decay?


This is the DeepSpeed config I use:

{
    "fp16": {
      "enabled": false
    },
    "bf16": {
      "enabled": true
    },
    "optimizer": {
      "type": "AdamW",
      "params": {
        "lr": "auto",
        "betas": "auto",
        "eps": "auto",
        "weight_decay": "auto"
      }
    },
    "scheduler": {
      "type": "WarmupLR",
      "params": {
        "warmup_min_lr": "auto",
        "warmup_max_lr": "auto",
        "warmup_num_steps": "auto"
      }
    },
    "zero_optimization": {
      "stage": 3,
      "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
      },
      "offload_param": {
          "device": "cpu",
          "pin_memory": true
      },
      "overlap_comm": true,
      "contiguous_gradients": true,
      "sub_group_size": 1e9,
      "reduce_bucket_size": "auto",
      "stage3_prefetch_bucket_size": "auto",
      "stage3_param_persistence_threshold": "auto",
      "stage3_max_live_parameters": 1e9,
      "stage3_max_reuse_distance": 1e9,
      "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
  }

1 Answer

Answered by DoneForAiur:

According to the docs, you can use a decaying learning rate like this:

"optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  }
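For the decay itself, the same docs also describe a WarmupDecayLR scheduler, which ramps the learning rate up during warmup and then decays it linearly to zero over total_num_steps. A minimal sketch that could replace the WarmupLR block in the question's config (the step counts here are illustrative, not tuned values):

"scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000,
      "total_num_steps": 10000
    }
  }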

Keep in mind that, according to this GitHub answer, the default behaviour is to decrease the learning rate every epoch, not every step.
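By contrast, when the scheduler is defined in the DeepSpeed config itself, the DeepSpeed docs note that the engine advances it on every model_engine.step() call, so the learning rate changes per step. Below is a minimal sketch of that behaviour, assuming a toy linear model and illustrative config values (run it under the deepspeed launcher):

import torch
import deepspeed

# Illustrative config: WarmupDecayLR ramps the LR up over warmup_num_steps,
# then decays it linearly to zero by total_num_steps.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 1e-3, "weight_decay": 0.01},
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 1e-3,
            "warmup_num_steps": 100,
            "total_num_steps": 1000,
        },
    },
}

model = torch.nn.Linear(10, 1)  # toy stand-in for the real model
engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

for step in range(1000):
    x = torch.randn(8, 10, device=engine.device)
    y = torch.randn(8, 1, device=engine.device)
    loss = torch.nn.functional.mse_loss(engine(x), y)
    engine.backward(loss)
    engine.step()  # also steps the config-defined scheduler, once per step
    if step % 100 == 0:
        print(step, scheduler.get_last_lr())  # LR warms up, then decays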