Tensor shape seems to disappear when indexing result of tf.shape(tensor)


When I index the result of tf.shape(tensor), where tensor is a symbolic Keras tensor, the inferred values unexpectedly turn into None. For example, I ran this code:

>>> from ray.rllib.models.utils import try_import_tf
>>> tf1, tf, tfv = try_import_tf() 
>>> tf.compat.v1.enable_eager_execution()  
>>> inp = tf.keras.layers.Input(shape=([19, 33, 1]), name='input')
>>> tf.shape(inp)
<KerasTensor: shape=(4,) dtype=int32 inferred_value=[None, 19, 33, 1] (created by layer 'tf.compat.v1.shape')>

The result is as expected. However, when I then run the following:

>>> tf.shape(inp)[0]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem')>
>>> tf.shape(inp)[1]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_1')>
>>> tf.shape(inp)[2]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_2')>
>>> tf.shape(inp)[3]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_3')>

All of the inferred values are None, even for dimensions 1 to 3, whose sizes (19, 33 and 1) are known statically. What's going on here? Is this expected behaviour?
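Some background on what the REPL output shows (a sketch, not taken from the original post): the inferred_value of a KerasTensor is static information known when the graph is built, while tf.shape returns a tensor whose actual values only exist when the model runs. The snippet below illustrates that distinction with a plain tf.function instead of a Keras Input, assuming TensorFlow 2.x; the batch dimension is declared as None in the input_signature, and traced_static_shapes and dynamic_shape are illustrative names only.

```python
import tensorflow as tf

# Appended to while the function is traced, not on every call.
traced_static_shapes = []

@tf.function(input_signature=[tf.TensorSpec([None, 19, 33, 1], tf.float32)])
def dynamic_shape(x):
    # Static shape: known at trace time; the batch dimension is None here.
    traced_static_shapes.append(x.shape.as_list())
    # Dynamic shape: an int32 tensor whose values are only known at run time.
    return tf.shape(x)

result = dynamic_shape(tf.zeros([2, 19, 33, 1]))
print(traced_static_shapes[0])  # [None, 19, 33, 1]
print(result.numpy().tolist())  # [2, 19, 33, 1]
```

A None in the static shape therefore means "unknown until run time", not a wrong value: the dynamic shape still reports 2 for the batch dimension once real data flows through.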

Answer by AudioBubble:

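The given code works as expected with TensorFlow 2.8.0: indexing the result of tf.shape keeps the known static dimensions (19, 33 and 1), and only the batch dimension is None.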

>>> import tensorflow as tf
>>> print(tf.__version__)
2.8.0
>>> inp = tf.keras.layers.Input(shape=([19, 33, 1]), name='input')
>>> tf.shape(inp)
<KerasTensor: shape=(4,) dtype=int32 inferred_value=[None, 19, 33, 1] (created by layer 'tf.compat.v1.shape_5')>
>>> tf.shape(inp)[0]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem')>
>>> tf.shape(inp)[1]
<KerasTensor: shape=() dtype=int32 inferred_value=[19] (created by layer 'tf.__operators__.getitem_1')>
>>> tf.shape(inp)[2]
<KerasTensor: shape=() dtype=int32 inferred_value=[33] (created by layer 'tf.__operators__.getitem_2')>
>>> tf.shape(inp)[3]
<KerasTensor: shape=() dtype=int32 inferred_value=[1] (created by layer 'tf.__operators__.getitem_3')>
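If your code needs the dimensions in a way that does not depend on how a particular TensorFlow/Keras version propagates inferred_value, one common workaround (an assumption on my part, not part of the answer above) is to read known dimensions from the static x.shape and fall back to tf.shape(x) only for the unknown ones. shape_list and flatten_per_example are hypothetical names, and the sketch assumes TensorFlow 2.x with plain tf.Tensor inputs rather than Keras Input tensors.

```python
import tensorflow as tf

def shape_list(x):
    """Hypothetical helper (not a TensorFlow API): return x's dims as a list,
    using Python ints where the static shape is known and scalar tensors
    from tf.shape(x) where it is not (typically the batch dimension)."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 19, 33, 1], tf.float32)])
def flatten_per_example(x):
    batch, h, w, c = shape_list(x)
    # batch is a scalar tensor; h, w and c are plain Python ints (19, 33, 1),
    # so h * w * c is computed in Python and stays static.
    return tf.reshape(x, [batch, h * w * c])

out = flatten_per_example(tf.zeros([2, 19, 33, 1]))
print(out.shape)  # (2, 627)
```

Keeping Python ints wherever the static shape is known means later operations still see those sizes at trace time, while the tensor-valued batch dimension is resolved only when the function runs.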