Ragged tensor slicing with tf.newaxis does not work. tf.expand_dims does though


I have a ragged tensor seq of shape (num_sentences, num_words, word_dim). If I try to add a new axis (dimension) at the end with the slice seq[..., tf.newaxis], it fails with the error below.

tf.expand_dims(seq, -1) works however.

Is there an explanation for this? I'd also prefer the slice syntax because it is much more readable: when the first argument of expand_dims is a complicated expression, the axis argument ends up at the very end, where other axis arguments usually accumulate.

  File "main.py", line 159, in __init__
    seq = seq[..., tf.newaxis]
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 205, in wrapper
    result = dispatch(wrapper, args, kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 118, in dispatch
    result = dispatcher.handle(args, kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1538, in handle
    return SlicingOpLambda(self.op)(*args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 951, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1090, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1521, in _call_wrapper
    return original_call(*new_args, **new_kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1327, in _call_wrapper
    return self._call_wrapper(*args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1359, in _call_wrapper
    result = self.function(*args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 963, in _slice_helper
    tensor = ops.convert_to_tensor(tensor)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py", line 163, in wrapped
    return func(*args, **kwargs)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1540, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 339, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 264, in constant
    return _constant_impl(value, dtype, shape, name, verify_shape=False,
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 281, in _constant_impl
    tensor_util.make_tensor_proto(
  File "/home/adam/.local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_util.py", line 551, in make_tensor_proto
    raise TypeError("Failed to convert object of type %s to Tensor. "
TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 64), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(None,), dtype=int64)). Consider casting elements to a supported type.

EDIT: replicating the error:

import tensorflow as tf

inputs = tf.keras.layers.Input([None, 10], dtype=tf.float32, ragged=True)

inputs[..., tf.newaxis]
# Raises the TypeError shown above

tf.expand_dims(inputs, -1)
# Out: <KerasTensor: type_spec=RaggedTensorSpec(TensorShape([None, None, 10, 1]), tf.float32, 1, tf.int64) (created by layer 'tf.expand_dims')>
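One possible way to keep the call site readable, sketched here under the assumption of TensorFlow 2.x in eager mode, is a small wrapper around tf.expand_dims. The name append_axis is hypothetical and not part of the TensorFlow API. tf.expand_dims has a RaggedTensor implementation, so the wrapper accepts ragged inputs, and the axis argument no longer trails a long first argument.

```python
import tensorflow as tf


def append_axis(t):
    """Hypothetical helper: add a trailing size-1 axis, like t[..., tf.newaxis].

    Delegates to tf.expand_dims, which also accepts tf.RaggedTensor inputs.
    """
    return tf.expand_dims(t, axis=-1)


# Small eager ragged tensor: 2 "sentences" with 2 and 1 "words", word_dim = 1.
rt = tf.ragged.constant([[[1.0], [2.0]], [[3.0]]], ragged_rank=1)

# Shape goes from (2, None, 1) to (2, None, 1, 1); the ragged dimension is kept.
expanded = append_axis(rt)
print(rt.shape, expanded.shape)
```

At a call site this reads as append_axis(some_long_expression). It does not explain why the slice syntax fails on the Keras ragged input; it only sidesteps it by going through the same tf.expand_dims call that already works in the reproduction.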