How to provide two inputs to a TensorFlow model using ragged tensors?

292 views Asked by At

I am trying to create a model with two inputs. The model is very simple, containing only one LSTM layer per input. The problem is that I want to provide lists of different lengths as inputs — and, for a given sample, the list in the first input can have a different length from the list in the second input. For that, I am using ragged tensors, but training fails.

# Minimal reproduction: two variable-length sequence inputs, one LSTM per
# input, outputs concatenated and fed through Dense layers, trained on ragged
# tensors.
# NOTE(review): the snippet relies on names it does not import here - tf, pd,
# np and glorot_uniform (presumably tensorflow, pandas, numpy and
# tf.keras.initializers.glorot_uniform) - confirm.

# 8 template rows: in row k (k = 0..7), col_1 holds k+1 zeros and col_2 holds
# (8-k) copies of the value k. The two lengths always sum to 9, so they are
# never equal within a row.
ds = pd.DataFrame({"col_1":[[0],[0,0],[0,0,0],[0,0,0,0],[0,0,0,0,0],[0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0]],"col_2":[8*[0],7*[1],6*[2],5*[3],4*[4],3*[5],2*[6],1*[7]]})
# Repeat each template row 1250 times (8 * 1250 = 10000 rows).
ds = ds.loc[ds.index.repeat(1250)].reset_index(drop=True)
# Shuffle all rows with a fixed seed so runs are reproducible.
ds = ds.sample(frac=1, random_state=43).reset_index(drop=True)

# Branch 1: ragged input with shape (batch, time, 1), where time varies per sample.
feat_1_inputs = [tf.keras.layers.Input(batch_shape=(None,None,1),ragged=True,name="col_1")]
# return_sequences=True keeps one 10-dim output per time step, so the result is
# still ragged and has the row lengths of col_1.
feat_1 = tf.keras.layers.LSTM(10, return_sequences=True, return_state=False, stateful=False)(feat_1_inputs[0])

# Branch 2: same structure as branch 1, but the row lengths follow col_2.
feat_2_inputs = [tf.keras.layers.Input(batch_shape=(None,None,1),ragged=True,name="col_2")]
feat_2 = tf.keras.layers.LSTM(10, return_sequences=True, return_state=False, stateful=False)(feat_2_inputs[0])

# This is where training fails. Concatenate joins the two per-timestep outputs
# along the last (feature) axis by default. For ragged tensors that requires
# identical row splits, i.e. both inputs must have the same length in every
# sample. Here col_1 has k+1 steps and col_2 has 8-k steps, so the splits never
# match (traceback: "Inputs must have identical ragged splits", x splits
# [0 8 11...] vs y splits [0 1 7...]).
# NOTE(review): one possible way around this is return_sequences=False in both
# LSTMs, so each branch yields a fixed-size (batch, 10) vector before
# concatenation. That changes the model to one prediction per sample, so y
# would have to become a per-sample label - confirm the intended target.
concat_inputs = tf.keras.layers.Concatenate()([feat_1, feat_2])
output = tf.keras.layers.Dense(10, activation='relu',kernel_initializer=glorot_uniform())(concat_inputs)
output = tf.keras.layers.Dense(10, kernel_initializer=glorot_uniform())(output)
# Softmax over 10 classes. The label values in col_2 (0..7) fall inside that range.
output = tf.keras.layers.Activation(activation='softmax', dtype='float32')(output)

# Inputs are passed in the order [col_1, col_2].
model = tf.keras.Model(feat_1_inputs + feat_2_inputs, output)
# The sparse categorical loss expects integer class ids as targets, not one-hot vectors.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=tf.keras.losses.sparse_categorical_crossentropy)

# Build the ragged int64 tensors. expand_dims(axis=-1) adds the trailing
# feature dimension of size 1 that the Input layers declare.
col_1_data = [tf.expand_dims(tf.ragged.constant(ds['col_1'].values,dtype=np.int64),axis=-1)]
col_2_data = tf.expand_dims(tf.ragged.constant(ds['col_2'].values,dtype=np.int64),axis=-1)
# col_1_data becomes the list [col_1, col_2] that is used as the model input x.
col_1_data.append(col_2_data)

# NOTE(review): the same col_2 tensor is fed both as the second input and as
# the target y, which leaks the label into the input - confirm this is
# intended. Its per-sample lengths also follow col_2, so the per-timestep
# output would have to match col_2's row lengths.
model.fit(x=col_1_data,y=col_2_data,epochs=10)

Error:

Epoch 1/10
Traceback (most recent call last):
  File "/home/user/.config/JetBrains/PyCharmCE2021.2/scratches/scratch_19.py", line 33, in <module>
    model.fit(x=col_1_data,y=col_2_data,epochs=10)
  File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "/home/user/.config/JetBrains/PyCharmCE2021.2/scratches/scratch_19.py", line 33, in <module>
      model.fit(x=col_1_data,y=col_2_data,epochs=10)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/layers/merge.py", line 183, in call
      return self._merge_function(inputs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/layers/merge.py", line 531, in _merge_function
      return backend.concatenate(inputs, axis=self.axis)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/backend.py", line 3311, in concatenate
      return tf.concat(tensors, axis)
        Node: 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert'
        assertion failed: [Inputs must have identical ragged splits] [Condition x == y did not hold element-wise:] [x (model/lstm/RaggedFromTensor/concat:0) = ] [0 8 11...] [y (model/lstm_1/RaggedFromTensor/concat:0) = ] [0 1 7...]
             [[{{node model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert}}]] [Op:__inference_train_function_9256]

If, in every row, both columns contain lists of the same length, it works fine. Is there a way to use ragged tensors when the two inputs have lists of different lengths within the same sample?

I am using TensorFlow 2.8.

0

There are 0 answers