PyTorch - store weight updates for momentum

I am trying to implement SGD with momentum. From my understanding, the update should look like this:

parameters -= (lr * (p.grad*0.1 + p_delta_prev*0.9))

My question is how I should store the previous deltas from every update.

Here is what I have in my update function:

#we now want to do the update with momentum
#momentum takes derivative, multiplies it by 0.1, then takes the previous update,
#multiplies it by 0.9 and we add the two together
#alpha = 0.1, beta = 0.9;  p-=grad*0.1 + p*0.9
def update(x,y,lr):
    wd = 1e-5
    y_hat = model(x)
    # weight decay
    w2 = 0.
    for p in model.parameters(): w2 += (p**2).sum()
    # add to regular loss
    loss = loss_func(y_hat, y) + w2*wd
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            #p.grad is the slope of the line of that parameter
            #current_p-previous_p to get difference
            p_update = (lr * (p.grad*0.1 + p*0.9))
            p.sub_(p_update)
            p.grad.zero_()
    return loss.item()

Here the p*0.9 should be replaced by p_delta_prev. But how should I store these deltas for every parameter? If I save them to a tensor, wouldn't I effectively be copying the weight deltas to memory, making my model twice the size? What would be a good way to accomplish this? I do not want to use a built-in function that does it for me. I did look into PyTorch's sgd.py and it looks like it stores the state.
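
For what it's worth, one extra buffer per parameter is inherent to momentum: torch.optim.SGD keeps a momentum_buffer the same shape as each parameter in its per-parameter state, so together the buffers take roughly as much memory as the parameters themselves. A minimal sketch of that pattern, keyed by the parameter object itself (the names state and sgd_momentum_step are mine, not from the lecture):

import torch

# one velocity buffer per parameter, keyed by the parameter object --
# the same idea torch.optim.SGD uses with its per-parameter state
state = {}

def sgd_momentum_step(params, lr, beta=0.9):
    with torch.no_grad():
        for p in params:
            if p.grad is None:
                continue
            if p not in state:
                # one extra tensor per parameter; this cost is unavoidable with momentum
                state[p] = torch.zeros_like(p)
            buf = state[p]
            buf.mul_(beta).add_(p.grad)   # v = beta*v + grad  (classic momentum)
            p.sub_(lr * buf)              # p = p - lr*v
            p.grad.zero_()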

EDIT: I have updated the code:

#we now want to do the update with momentum
#momentum takes derivative, multiplies it by 0.1, then takes the previous update,
#multiplies it by 0.9 and we add the two together
#alpha = 0.1, beta = 0.9;  p-=grad*0.1 + p*0.9
p_delta = {}
def update(x,y,lr):
    wd = 1e-5
    y_hat = model(x)
    # weight decay
    w2 = 0.
    for p in model.parameters(): w2 += (p**2).sum()
    # add to regular loss
    loss = loss_func(y_hat, y) + w2*wd
    loss.backward()
    with torch.no_grad():
        for i, p in enumerate(model.parameters()):
            #p.grad is the slope of the line of that parameter
            if i not in p_delta:  #check if key exists
                p_delta[i] = torch.zeros_like(p)
            p_update = (lr * p.grad) + (p_delta[i] * 0.9)
            p_delta[i] = p_update.clone()
            p.sub_(p_update)
            p.grad.zero_()
            print(p_delta[i])
    return loss.item()

I think the code in the Excel spreadsheet is incorrect. Jeremy seems to show lr * ((p.grad*0.1) + (p_delta[i]*0.9)), but many tutorials seem to show (lr * p.grad) + (p_delta[i]*0.9). If we implement Jeremy's version, the loss actually decreases more slowly than with vanilla GD. The relevant part of the video is here: https://youtu.be/CJKnDu2dxOE?t=6581
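
If it helps, the difference between the two forms is mostly a rescaling of the learning rate: in lr * (0.1*grad + 0.9*delta) the gradient enters each step scaled by lr*0.1, whereas in lr*grad + 0.9*delta it enters scaled by lr, so with the same lr the first variant takes roughly 10x smaller steps, which would explain the slower loss. A rough sketch of the two variants side by side (the function names are mine, not from the lecture):

def classic_momentum_update(p, grad, delta, lr, beta=0.9):
    # delta accumulates lr-scaled gradients: delta = beta*delta + lr*grad
    delta = beta * delta + lr * grad
    return p - delta, delta

def ema_momentum_update(p, grad, delta, lr, beta=0.9):
    # delta is an exponential moving average of the raw gradient,
    # delta = beta*delta + (1-beta)*grad, and the step is lr*delta
    delta = beta * delta + (1 - beta) * grad
    return p - lr * delta, delta

# With beta = 0.9 the second form only feeds in 0.1*grad per step, so lr needs
# to be roughly 10x larger for it to match the first form's step size.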
