tf.numpy_function has None output shape using Lambda layer


I am trying to add a custom NumPy layer that will serve as an activation function in my model. I use a Lambda layer and tf.numpy_function. However, TensorFlow cannot infer the output shape of this layer and reports it as None. Is there any way to specify the output shape of this layer so that its output can be used by subsequent layers? The logic of this layer can't be implemented with built-in tf/keras functions.

I face this issue even with a simple example:

import numpy as np
import tensorflow as tf
# Public Keras API (tensorflow.python.keras is private and removed in recent TF releases)
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Lambda


def my_numpy_func(x):
    # my function logic will be here
    return x


@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_function(input):
    y = tf.numpy_function(my_numpy_func, [input], tf.float32)
    return y


inp = Input(shape=(28, 28, 1))
my_layer = Lambda(tf_function)(inp)
model = Model(inp, my_layer)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
lambda (Lambda)              None                      0         
=================================================================
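A minimal sketch of a commonly used fix, assuming TensorFlow 2.x with the public tf.keras API and a NumPy function that preserves the input shape. tf.numpy_function cannot know what shape the Python callback returns, so its result has an unknown static shape; calling set_shape on that result restores it. The function is also no longer wrapped in @tf.function(input_signature=[tf.TensorSpec(None, ...)]), because a None spec discards the (None, 28, 28, 1) shape coming from Input. my_numpy_func is a placeholder and numpy_activation is a hypothetical name.

```python
import numpy as np
import tensorflow as tf


def my_numpy_func(x):
    # Placeholder for the real NumPy logic; must return float32 to match Tout.
    return np.asarray(x, dtype=np.float32)


def numpy_activation(x):
    # tf.numpy_function returns a tensor whose static shape is unknown.
    y = tf.numpy_function(my_numpy_func, [x], tf.float32)
    # Assumption: my_numpy_func preserves the input shape, so copy the
    # input's static shape (batch dimension stays None) onto the output.
    y.set_shape(x.shape)
    return y


inp = tf.keras.Input(shape=(28, 28, 1))
out = tf.keras.layers.Lambda(numpy_activation)(inp)
model = tf.keras.Model(inp, out)
model.summary()  # the Lambda row should now show (None, 28, 28, 1)
```

If the real function changes the shape, pass the shape it actually produces to set_shape instead of x.shape.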

So if I add more layers, such as MaxPooling2D, they can't consume the output of the Lambda layer because its shape is None.
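An alternative sketch, assuming TensorFlow 2.x tf.keras: wrap the NumPy activation in a small Layer subclass, so the shape fix lives in one reusable place and downstream layers such as MaxPooling2D receive a known shape. NumpyActivation is a hypothetical name, and my_numpy_func here is a stand-in (a NumPy ReLU), not the real logic. The set_shape call covers Keras versions that infer shapes by tracing call, and compute_output_shape covers versions that ask the layer for its output shape directly.

```python
import numpy as np
import tensorflow as tf


def my_numpy_func(x):
    # Hypothetical stand-in activation: NumPy ReLU, cast back to float32.
    return np.maximum(x, 0.0).astype(np.float32)


class NumpyActivation(tf.keras.layers.Layer):
    """Hypothetical layer applying my_numpy_func element-wise."""

    def call(self, inputs):
        y = tf.numpy_function(my_numpy_func, [inputs], tf.float32)
        # Restore the static shape lost by numpy_function
        # (assumes my_numpy_func is shape-preserving).
        y.set_shape(inputs.shape)
        return y

    def compute_output_shape(self, input_shape):
        # Element-wise activation: output shape equals input shape.
        return input_shape


inp = tf.keras.Input(shape=(28, 28, 1))
act = NumpyActivation()(inp)
pooled = tf.keras.layers.MaxPooling2D(pool_size=2)(act)  # now gets a known shape
model = tf.keras.Model(inp, pooled)
```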

I tried reshaping the output of tf_function with tf.reshape(y, input.shape) and tf.reshape(y, (None, 2048, 1)), and expanding its dimensions with tf.expand_dims(y, axis=-1), but got the following error:

Error: Cannot convert a partially known TensorShape (None, 26, 26, 32) to a Tensor.
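This error appears when tf.reshape receives a partially known TensorShape: a shape containing None cannot be converted into a concrete shape tensor. tf.reshape marks an inferred dimension with -1 instead, and tf.ensure_shape accepts None for unknown dimensions. A sketch of both, assuming TensorFlow 2.x, an input of shape (None, 28, 28, 1), and a placeholder identity my_numpy_func; the function names reshaped and ensured are hypothetical.

```python
import numpy as np
import tensorflow as tf


def my_numpy_func(x):
    # Placeholder NumPy logic (identity), returning float32 to match Tout.
    return np.asarray(x, dtype=np.float32)


@tf.function(input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32)])
def reshaped(x):
    y = tf.numpy_function(my_numpy_func, [x], tf.float32)
    # Use -1 (not None) for the unknown batch dimension.
    return tf.reshape(y, [-1, 28, 28, 1])


@tf.function(input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32)])
def ensured(x):
    y = tf.numpy_function(my_numpy_func, [x], tf.float32)
    # tf.ensure_shape accepts None and also validates the shape at run time.
    return tf.ensure_shape(y, [None, 28, 28, 1])
```

Note that the reshape variant silently reinterprets the elements if the NumPy function returns a different shape, while tf.ensure_shape raises an error at run time in that case.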

I also tried to explicitly specify the shape of the input parameter of tf_function with @tf.function(input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32)]), but it didn't help.
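A small diagnostic sketch, assuming TensorFlow 2.x, of why a shaped input_signature alone does not help: the spec fixes the shape of x, but tf.numpy_function does not propagate it to its output. Recording y.shape during tracing shows it unknown until set_shape is applied. traced_shapes is a hypothetical helper dict and my_numpy_func is a placeholder.

```python
import numpy as np
import tensorflow as tf


def my_numpy_func(x):
    # Placeholder NumPy logic (identity), returning float32 to match Tout.
    return np.asarray(x, dtype=np.float32)


traced_shapes = {}  # filled in while the tf.function is being traced


@tf.function(input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32)])
def tf_function(x):
    y = tf.numpy_function(my_numpy_func, [x], tf.float32)
    traced_shapes["before"] = y.shape  # unknown: the input spec is not propagated
    y.set_shape(x.shape)               # x.shape is (None, 28, 28, 1) from the spec
    traced_shapes["after"] = y.shape
    return y


tf_function.get_concrete_function()  # triggers tracing
print(traced_shapes["before"], traced_shapes["after"])
```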
