Custom GRU implementation performing very slowly


I am working on customizing the GRU layer to suit my specific requirements. To achieve this, I am implementing a custom GRU layer following the architecture and implementation of the GRU layer in Keras.
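For reference, these are the GRU equations I am trying to reproduce (the reset_after variant that Keras uses by default, where $x_t$ is the input at step $t$, $h_{t-1}$ the previous hidden state, $\sigma$ the recurrent activation, and $\odot$ elementwise multiplication):

$$z_t = \sigma(x_t W_z + b_z + h_{t-1} U_z + \hat{b}_z)$$
$$r_t = \sigma(x_t W_r + b_r + h_{t-1} U_r + \hat{b}_r)$$
$$\tilde{h}_t = \tanh(x_t W_h + b_h + r_t \odot (h_{t-1} U_h + \hat{b}_h))$$
$$h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t$$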

However, I have noticed that the custom GRU layer on its own takes around 5 times longer to execute than the original Keras GRU layer, and when I integrate it into my model the execution time grows to around 25 times that of the original. This is how I implemented the custom GRU cell:

import tensorflow as tf

class CustomGRU(tf.keras.layers.Layer):
    def __init__(self, units, activation='tanh', recurrent_activation='sigmoid',
                 kernel_regularizer=None, recurrent_regularizer=None,
                 bias_regularizer=None, constraints=None, **kwargs):
        super(CustomGRU, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.recurrent_activation = tf.keras.activations.get(recurrent_activation)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = tf.keras.regularizers.get(recurrent_regularizer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.constraint = constraints

    def build(self, input_shape):
        self.input_dim = input_shape[-1]
        #self.default_caching_device = rnn_utils.caching_device(self)

        # Input kernel: weights for the z, r and h gates, concatenated
        self.W = self.add_weight(shape=(self.input_dim, self.units * 3),
                                 initializer='glorot_uniform',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.constraint)

        # Recurrent kernel: recurrent weights for the three gates, concatenated
        self.U = self.add_weight(shape=(self.units, self.units * 3),
                                 initializer='orthogonal',
                                 regularizer=self.recurrent_regularizer,
                                 constraint=self.constraint)

        # Row 0 holds the input bias, row 1 the recurrent bias (reset_after style)
        self.bias = self.add_weight(shape=(2, 3 * self.units),
                                    initializer='zeros',
                                    regularizer=self.bias_regularizer,
                                    constraint=self.constraint)

    def call(self, inputs, states):
        h_tm1 = states[0]  # previous hidden state

        # Split the stacked bias variable into input and recurrent biases
        input_bias, recurrent_bias = tf.unstack(self.bias)

        # Input projections for the update (z), reset (r) and candidate (h) gates
        x_z = tf.linalg.matmul(inputs, self.W[:, :self.units]) + input_bias[:self.units]
        x_r = tf.linalg.matmul(inputs, self.W[:, self.units:self.units * 2]) + input_bias[self.units:self.units * 2]
        x_h = tf.linalg.matmul(inputs, self.W[:, self.units * 2:]) + input_bias[self.units * 2:]

        # Recurrent projections for the same three gates
        recurrent_z = tf.linalg.matmul(h_tm1, self.U[:, :self.units]) + recurrent_bias[:self.units]
        recurrent_r = tf.linalg.matmul(h_tm1, self.U[:, self.units:self.units * 2]) + recurrent_bias[self.units:self.units * 2]
        recurrent_h = tf.linalg.matmul(h_tm1, self.U[:, self.units * 2:]) + recurrent_bias[self.units * 2:]

        z = self.recurrent_activation(x_z + recurrent_z)
        r = self.recurrent_activation(x_r + recurrent_r)

        # Reset gate applied after the recurrent matmul (reset_after variant)
        recurrent_h = r * recurrent_h
        hh = self.activation(x_h + recurrent_h)

        # Previous and candidate state mixed by the update gate
        h = z * h_tm1 + (1 - z) * hh

        # A cell used with tf.keras.layers.RNN must return (output, new_states)
        return h, [h]

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        # Derive the batch size from the inputs when available,
        # otherwise fall back to the batch_size argument
        if inputs is not None:
            batch_size = tf.shape(inputs)[0]
        return tf.zeros((batch_size, self.units))

    @property
    def state_size(self):
        return self.units

This is how I call the cell through tf.keras.layers.RNN:

custom_gru_layer = CustomGRU(4, activation='tanh', recurrent_activation='sigmoid')
batch_size = tf.shape(input_data)[0]
initial_state = tf.zeros((batch_size, custom_gru_layer.units))
rnn_layer = tf.keras.layers.RNN(custom_gru_layer, return_sequences=True, return_state=True)
custom_gru_output, state = rnn_layer(input_data, initial_state=initial_state)
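
A minimal, self-contained way to reproduce the timing comparison looks like this (input_data here is a random placeholder tensor; the shapes are illustrative, not my real data):

import time
import tensorflow as tf

# Placeholder input: (batch, timesteps, features); values are illustrative only
input_data = tf.random.normal((32, 100, 8))

custom_rnn = tf.keras.layers.RNN(CustomGRU(4), return_sequences=True)
builtin_gru = tf.keras.layers.GRU(4, return_sequences=True)

# Warm both layers up once so one-time build/tracing cost is excluded
_ = custom_rnn(input_data)
_ = builtin_gru(input_data)

start = time.perf_counter()
for _ in range(100):
    _ = custom_rnn(input_data)
print("custom GRU  :", time.perf_counter() - start)

start = time.perf_counter()
for _ in range(100):
    _ = builtin_gru(input_data)
print("built-in GRU:", time.perf_counter() - start)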

What are the possible reasons for getting such poor performance?
