I'm implementing a custom tf.keras.layers.Layer that needs to support masking.

Consider the following scenario:

embedded = tf.keras.layers.Embedding(input_dim=vocab_size + 1, 
x = MyCustomKerasLayers(embedded)

Now, per the documentation:

mask_zero: Whether or not the input value 0 is a special "padding" value that should be masked out. This is useful when using recurrent layers which may take variable length input. If this is True then all subsequent layers in the model need to support masking or an exception will be raised. If mask_zero is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1).

What does that mean in practice? Looking through TensorFlow's custom layers guide and the tf.keras.layers.Layer documentation, it is not clear what should be done to support masking:

  1. How do I support masking?

  2. How do I access the mask from the past layer?

  3. Assuming an input of shape (batch, time, channels) or (batch, time), would the masks look different? What would their shapes be?

  4. How do I pass it on to the next layer?

1 Answer

bluesummers On Best Solutions
  1. To support masking, implement the compute_mask method in the custom layer.

  2. To access the mask, add a mask argument to the call method's signature, right after inputs (e.g. call(self, inputs, mask=None)); Keras then passes in the mask produced by the previous layer.

  3. This cannot be determined in general: the preceding layer is responsible for computing the mask, so its shape depends on that layer.

  4. Once compute_mask is implemented, the mask is passed to the next layer automatically.
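As an illustration of point 3 for the Embedding case in the question, here is a minimal sketch of what the mask looks like when mask_zero=True. The sizes (vocab_size, n_dims) and the token ids are made-up values for the example, not taken from the question. With these assumptions, the mask is a boolean tensor with the same shape as the integer input, (batch, time), while the embedded output has shape (batch, time, channels).

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
vocab_size, n_dims = 9, 4
embedding = tf.keras.layers.Embedding(
    input_dim=vocab_size + 1, output_dim=n_dims, mask_zero=True
)

# Two sequences padded with 0 to a common length of 4.
token_ids = tf.constant([[3, 5, 0, 0], [7, 2, 9, 0]])

embedded = embedding(token_ids)           # (batch, time, channels)
mask = embedding.compute_mask(token_ids)  # (batch, time), dtype bool

print(embedded.shape)
print(mask.shape)
print(mask.numpy().tolist())  # True where the token id is non-zero
```

Other mask-producing layers may use a different convention, so it is worth printing the mask once rather than assuming its shape.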


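import tensorflow as tf

class MyCustomKerasLayers(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        # Layer-specific constructor arguments would go here.
        super().__init__(**kwargs)

    def compute_mask(self, inputs, mask=None):
        # Pass the mask received from the previous layer on to the next layer,
        # or modify it here if this layer changes the shape of the input.
        return mask

    def call(self, inputs, mask=None):
        # `mask` holds the mask passed from the previous layer (None if there is none).
        # Placeholder body: the real layer logic would go here.
        return inputs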

Notice that this example simply passes the mask on. If the layer outputs a shape different from the one it received, adjust the mask in compute_mask so that the correct one is passed on.
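To make the shape-changing case concrete, here is a hedged sketch of a layer that removes the time axis. MaskedMean is a hypothetical name, not part of Keras. It assumes an input of shape (batch, time, channels) and a boolean mask of shape (batch, time), uses the mask inside call to ignore padded steps, and returns None from compute_mask because the output no longer has a time axis to mask.

```python
import tensorflow as tf


class MaskedMean(tf.keras.layers.Layer):
    """Hypothetical layer: averages over the time axis, skipping masked steps."""

    def call(self, inputs, mask=None):
        if mask is None:
            # No mask available: fall back to a plain mean over time.
            return tf.reduce_mean(inputs, axis=1)
        # (batch, time) bool -> (batch, time, 1) float so it broadcasts over channels.
        weights = tf.cast(mask, inputs.dtype)[:, :, tf.newaxis]
        total = tf.reduce_sum(inputs * weights, axis=1)
        # Guard against division by zero for fully masked sequences.
        count = tf.maximum(tf.reduce_sum(weights, axis=1), 1.0)
        return total / count

    def compute_mask(self, inputs, mask=None):
        # The time axis is gone after pooling, so no mask is passed on.
        return None


layer = MaskedMean()
x = tf.constant([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # (1, 3, 2)
m = tf.constant([[True, True, False]])                        # last step is padding
out = layer(x, mask=m)
print(out.numpy().tolist())
```

Here the mask is passed explicitly to keep the example self-contained; inside a model, it would come from the preceding layer as described above.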