Weighted MSE custom loss function in Keras


I'm working with time-series data and predicting 60 days ahead.

I'm currently using mean squared error as my loss function, and the results are poor.

I want to implement a weighted mean squared error such that the early outputs are much more important than later ones.

Weighted mean squared error formula:

[formula image not available]
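The original formula image did not survive. Judging from the loop in the code below (weight `1/(i+1)` for a zero-based index `i`, then division by the number of terms), the intended loss appears to be a weighted MSE with weight 1/i on day i. This is a reconstruction from the code, not the original image, and assumes n = 60 output days:

```latex
\mathrm{WMSE} = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{i} \left( y_i - \hat{y}_i \right)^2, \qquad n = 60
```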

So I need some way to iterate over a tensor's elements with an index, since I need to walk through the predicted and the true values at the same time, and then write the result to a tensor with a single element. Both tensors have shape (?, 60), but in practice they are (1, 60).

Nothing I've tried works. Here's the code for the broken version:

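from keras import backend as K

def weighted_mse(y_true, y_pred):
    wmse = K.cast(0.0, 'float')

    # K.shape(...)[0] is the batch axis (the '?' in (?, 60)), not the 60 days
    size = K.shape(y_true)[0]
    # K.eval() tries to evaluate the symbolic shape while the graph is being built,
    # before any targets are fed, which raises the placeholder error below
    for i in range(0, K.eval(size)):
        wmse += 1/(i+1)*K.square((y_true[i]-y_pred)[i])

    wmse /= K.eval(size)
    return wmse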

I am currently getting this error as a result:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'dense_2_target' with dtype float
 [[Node: dense_2_target = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Having read the replies to similar posts, I don't think a mask can accomplish the task, and looping over the elements of one tensor wouldn't work either, since I wouldn't be able to access the corresponding element in the other tensor.
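The worry about reaching "the corresponding element in the other tensor" goes away once the computation is elementwise. The sketch below uses NumPy as a stand-in for Keras tensors (the batch size of 4 and the random data are arbitrary). It shows that subtracting two same-shaped arrays already pairs element `[b, i]` with `[b, i]`, exactly like an explicit indexed loop, and that a per-day weight vector of shape (60,) broadcasts across the batch axis:

```python
import numpy as np

rng = np.random.default_rng(0)
y_true = rng.normal(size=(4, 60))  # batch of 4 samples, 60 days each (illustrative data)
y_pred = rng.normal(size=(4, 60))

# Explicit double loop: the shared index (b, i) reaches the matching element in both arrays
loop_sq = np.empty_like(y_true)
for b in range(y_true.shape[0]):
    for i in range(y_true.shape[1]):
        loop_sq[b, i] = (y_true[b, i] - y_pred[b, i]) ** 2

# Elementwise version: same-shaped arrays are paired position by position, no index needed
vec_sq = (y_true - y_pred) ** 2

# Weights 1, 1/2, ..., 1/60 of shape (60,) broadcast over the batch axis
weights = 1.0 / np.arange(1, 61)
weighted_sq = weights * vec_sq
```

The same reasoning carries over to backend tensor ops, which is why no Python-level loop over tensor elements is needed.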

Any suggestions would be appreciated.

Daniel Möller (best answer)

You can use this approach:

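from keras import backend as K

def weighted_mse(yTrue, yPred):

    ones = K.ones_like(yTrue[0, :])  # a simple vector of ones shaped as (60,)
    idx = K.cumsum(ones)             # [1, 2, ..., 60], similar to range(1, 61)

    # weight each day's squared error by 1/idx, so earlier days count more,
    # then average over both the batch and the 60 days
    return K.mean((1 / idx) * K.square(yTrue - yPred))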
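To sanity-check the values this loss produces without building a model, here is a NumPy mirror of it. This is a sketch under the assumption that `np.ones_like`, `np.cumsum` and `np.mean` behave like their Keras backend counterparts for this computation; `weighted_mse_np` is a hypothetical name:

```python
import numpy as np


def weighted_mse_np(y_true, y_pred):
    """NumPy mirror of the Keras loss above; inputs are (samples, classes) arrays."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ones = np.ones_like(y_true[0, :])  # vector of ones, shape (classes,)
    idx = np.cumsum(ones)              # 1.0, 2.0, ..., classes
    return np.mean((1 / idx) * np.square(y_true - y_pred))


# Usage: the same function handles 60-day outputs and a 3-column output
rng = np.random.default_rng(42)
loss_60 = weighted_mse_np(rng.normal(size=(2, 60)), rng.normal(size=(2, 60)))
loss_3 = weighted_mse_np(np.zeros((2, 3)), np.ones((2, 3)))  # (1 + 1/2 + 1/3) / 3 = 11/18
```

With all squared errors equal to 1, the result is just the average weight, which makes the 1/idx weighting easy to verify by hand.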

Using ones_like with cumsum makes the weights adapt to the width of the output, so you can use this loss function with any (samples, classes) output, not just 60 days.


Hint: always use backend functions when working with tensors. You can use slices, but avoid iterating.