How can I create a custom loss function in Keras? (Custom Weighted Binary Cross Entropy)


I'm building a fully convolutional neural network that, given an input image, identifies special zones in it (black, 0) as well as the background (white, 255). My targets are binarized images (values 0 or 255), and I'd like to balance my two semantic classes. There are 1.8 times more "special" zone pixels (0) than background pixels (255), so I want to counterbalance this by penalizing errors on the background more heavily, to avoid predictions that consist only of special zones.

I tried to follow a few threads on the subject. It doesn't seem very hard, but I'm stuck on my implementation and I don't really know why.

Each time, my implementation passes the compile stage, but it returns an error in the fitting step.
Here's what I tried so far:

import keras.backend as kb

def custom_binary_crossentropy(y_true, y_pred):
    """
    Used to reequilibrate the data, as there is more black (0., articles), than white (255., non-articles) on the pages.
    """
    if y_true >= 128:  # Half the 0-255 range
        return -1.8 * kb.log(y_pred / 255.)
    else:
        return -kb.log(1 - (y_pred / 255.))

But it returned:

InvalidArgumentError:  The second input must be a scalar, but it has shape [16,256,256]
     [[{{node gradient_tape/custom_binary_crossentropy/cond/StatelessIf/gradient_tape/custom_binary_crossentropy/weighted_loss/Mul/_17}}]] [Op:__inference_train_function_24327]

Function call stack:
train_function

I don't really understand this error.

Before that, I had tried:

def custom_binary_crossentropy(y_true, y_pred):
    """
    Used to reequilibrate the data, as there is more black (0., articles), than white (255., non-articles) on the pages.
    """
    if y_true >=128:   # Half the 0-255 range
        return 1.8*keras.losses.BinaryCrossentropy(y_true, y_pred)
    else:
        return keras.losses.BinaryCrossentropy(y_true, y_pred)

But I got:

TypeError: in user code:

    /Users/axeldurand/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    <ipython-input-67-7b6815236f63>:6 custom_binary_crossentropy  *
        return -1.8*keras.losses.BinaryCrossentropy(y_true, y_pred)

    TypeError: unsupported operand type(s) for *: 'float' and 'BinaryCrossentropy'

I'm a bit confused. Keras usually makes things so easy, so I must be missing something simple, but I don't see what.

There are 2 answers

Durand (best answer)

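Thanks a lot @qmeeus, you showed me the path to success!
I did not know the difference between keras.losses.BinaryCrossentropy and keras.losses.binary_crossentropy, but it is a major one: the former is a class that has to be instantiated before it can be called, while the latter is a plain function that directly returns the loss tensor.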

Here's how I did it:

import tensorflow as tf
from tensorflow import keras

def custom_binary_crossentropy(y_true, y_pred):
    """
    Used to reequilibrate the data, as there is more black (0., articles),
    than white (255. (recalibrated to 1.), non-articles) on the pages.
    """
    # I put 0 so that the shape is (batch_size, 256, 256)
    # and not (batch_size, 256, 256, 1)
    is_white = y_true[:, :, :, 0] >= 0.5
    white_error = 1.8 * keras.losses.binary_crossentropy(y_true, y_pred)
    black_error = keras.losses.binary_crossentropy(y_true, y_pred)
    # Returns the right loss for each type of error.
    # We do make twice the calculation but I did not find a better way for now
    return tf.where(is_white, white_error, black_error)

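I did not know about tf.where, but it is extremely useful: it picks, element by element, a value from one tensor or the other depending on a boolean mask. I came across it in Aurélien Géron's excellent book, Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow.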
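The function above computes the cross-entropy twice. One possible way around that, sketched here as an assumption rather than as part of the original answer, is to compute it once and multiply by a per-pixel weight map built with tf.where. The name weighted_binary_crossentropy and the white_weight parameter are hypothetical, and targets are assumed to be scaled to [0, 1] with shape (batch_size, H, W, 1).

```python
import tensorflow as tf

def weighted_binary_crossentropy(y_true, y_pred, white_weight=1.8):
    """Per-pixel BCE, with white (>= 0.5) target pixels weighted more heavily."""
    # Functional BCE averages over the last axis: (batch, H, W, 1) -> (batch, H, W)
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    # Boolean mask of white target pixels, shape (batch, H, W)
    is_white = y_true[..., 0] >= 0.5
    # Weight map: white_weight where the target is white, 1.0 elsewhere
    weights = tf.where(is_white, white_weight, 1.0)
    return weights * bce

# Tiny hand-made batch: one 2x2 image with a single channel
y_true = tf.constant([[[[1.0], [0.0]], [[1.0], [0.0]]]])
y_pred = tf.constant([[[[0.9], [0.2]], [[0.6], [0.4]]]])

weighted = weighted_binary_crossentropy(y_true, y_pred)
unweighted = tf.keras.losses.binary_crossentropy(y_true, y_pred)
```

Because Keras calls a loss with only (y_true, y_pred), the default white_weight is what gets used when this function is passed to model.compile.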

Then simply use it as follows:

# Compiling using this function
model.compile(optimizer="rmsprop", loss=custom_binary_crossentropy)

Then fit your model with your data and favourite hyperparameters, and you're good to go!
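Since the loss above compares y_true against 0.5, the 0-255 targets from the question need to be rescaled to [0, 1] before fitting. A minimal sketch of that step, assuming the masks are held in a uint8 NumPy array (the names masks and y_train, and the random data, are hypothetical):

```python
import numpy as np

# Hypothetical batch of binarized target masks with values 0 or 255,
# shape (batch_size, 256, 256)
masks = np.random.choice([0, 255], size=(4, 256, 256)).astype("uint8")

# Scale to [0, 1] and add a trailing channel axis so the targets have
# shape (batch_size, 256, 256, 1), matching what the loss indexes with [..., 0]
y_train = (masks.astype("float32") / 255.0)[..., np.newaxis]
```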

qmeeus

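You are using keras.losses.BinaryCrossentropy the wrong way. It is a class, so keras.losses.BinaryCrossentropy(y_true, y_pred) constructs a loss object instead of computing a value, which is why multiplying it by a float raises the TypeError. What you actually want is the functional version of this loss, tf.keras.losses.binary_crossentropy.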

See https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy and https://www.tensorflow.org/api_docs/python/tf/keras/losses/binary_crossentropy
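To make the difference concrete, here is a small sketch with hand-made tensors (the values and shapes are illustrative assumptions, not the asker's data). The function returns one loss value per pixel, while the class must be instantiated first and, with its default reduction, returns a single scalar.

```python
import tensorflow as tf

# One 2x2 single-channel image: targets in {0, 1}, predictions in (0, 1)
y_true = tf.constant([[[[1.0], [0.0]], [[1.0], [0.0]]]])  # shape (1, 2, 2, 1)
y_pred = tf.constant([[[[0.9], [0.2]], [[0.6], [0.4]]]])

# Functional form: averages over the last axis only,
# so the result keeps one value per pixel, shape (1, 2, 2)
per_pixel = tf.keras.losses.binary_crossentropy(y_true, y_pred)

# Class form: build the loss object first, then call it.
# The default reduction averages everything down to a scalar.
loss_object = tf.keras.losses.BinaryCrossentropy()
scalar = loss_object(y_true, y_pred)
```

The per-pixel output is what makes it possible to reweight individual pixels inside a custom loss before Keras reduces the result.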