How exactly Tensor masking and indexing should be done in Tensorflow?

353 views Asked by At

I have been using TF for 2 years now and at each project, I have lots of non-sense error popping out for masking which usually are not helpful and don't indicate what actually is wrong. or worst than that, the result is wrong but no error. I always test the code outside the training loop with dummy data and it's fine. but in the training (calling fit), I don't understand what TensorFlow expects exactly. just for one example, can someone experienced please tell me why this code does not work for a binary cross-entropy, the result is wrong and model does not converge but no error in this case:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

while this works correctly:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

and for a categorical example, the opposite is true. I can't use the mask as an index like y_pred[mask], or y_pred[mask[0]], or using tf.squeeze() and so on. but using tf.gather_nd() works. I always try all the combination that I think is possible, I just don't get it why something so simple should be this hard and painful. is Pytorch like this too? I'm happy to switch if you know Pytorch doesn't have similar annoying details.

EDIT 1: They work correctly outside training loop, or graph mode to be more exact.

y_pred = tf.random.uniform(shape=[10,], minval=0, maxval=1, dtype='float32')
y_true = tf.random.uniform(shape=[10,], minval=0, maxval=2, dtype='int32')

# first method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

# instantiate
mbxe = MaskedBXE()
print(f'first snippet: {mbxe(y_true, y_pred).numpy()}')


# second method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
# instantiate
mbxe = MaskedBXE()
print(f'second snippet: {mbxe(y_true, y_pred).numpy()}')

first snippet: 1.2907861471176147

second snippet: 1.2907861471176147

EDIT 2: After printing losses in graph mode as @jdehesa suggeted, they differ, which they shouldn't:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        # first
        y_t = tf.squeeze(y_true)
        mask = tf.where(y_t!=2)
        y_t = tf.gather_nd(y_t, mask)
        y_p = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        first_loss =  tf.reduce_mean(loss)
        tf.print('first:')
        tf.print(first_loss, summarize=-1)
        # second
        mask = tf.where(y_true!=2, True, False)
        y_t = y_true[mask]
        y_p = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        second_loss = tf.reduce_mean(loss)
        tf.print('second:')
        tf.print(second_loss, summarize=-1)
        return second_loss

first:

0.814215422

second:

0.787778914

first:

0.779697835

second:

0.802924752

. . .

1

There are 1 answers

2
javidcf On BEST ANSWER

I think the problem is that you are inadvertently doing broadcasted operations in the first version, which is giving you the wrong result. This will happen if you batches have shape (?, 1), because of the tf.squeeze operation. Note the shapes in this example

import tensorflow as tf

# Make random y_true and y_pred with shape (10, 1)
tf.random.set_seed(10)
y_true = tf.dtypes.cast(tf.random.uniform((10, 1), 0, 3, dtype=tf.int32), tf.float32)
y_pred = tf.random.uniform((10, 1), 0, 1, dtype=tf.float32)

# first
y_t = tf.squeeze(y_true)
mask = tf.where(y_t != 2)
y_t = tf.gather_nd(y_t, mask)
tf.print(tf.shape(y_t))
# [7]
y_p = tf.gather_nd(y_pred, mask)
tf.print(tf.shape(y_p))
# [7 1]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
first_loss =  tf.reduce_mean(loss)
tf.print(tf.shape(loss), summarize=-1)
# [7]
tf.print(first_loss, summarize=-1)
# 0.884061277

# second
mask = tf.where(y_true!=2, True, False)
y_t = y_true[mask]
tf.print(tf.shape(y_t))
# [7]
y_p = y_pred[mask]
tf.print(tf.shape(y_p))
# [7]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
tf.print(tf.shape(loss), summarize=-1)
# []
second_loss = tf.reduce_mean(loss)
tf.print(second_loss, summarize=-1)
# 1.15896356

In the first version, both y_t and y_p become broadcasted into 7x7 tensors so the cross-entropy is basically computed "all vs all", and then averaged. In the second case the cross-entropy is only calculated for each pair of corresponding values, which is the correct thing to do.

If you simply remove the tf.squeeze operation in the example above the result is corrected.