TensorFlow: How to use a RaggedTensor as an index into a dense tensor?


I have a 2D RaggedTensor whose rows hold the column indices I want to pick from the corresponding row of a dense tensor. For example, these indices:

[
    [0,4],
    [1,2,3],
    [5]
]

applied to this tensor:

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

should give:

[
    [200,20],
    [315,401,20],
    [105]
]
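To pin down the intended semantics, here is a plain-Python reference using nested lists (no TensorFlow, so it is not a solution to the efficiency question). The helper name `gather_rows` is hypothetical, chosen only for illustration.

```python
# Plain-Python reference for the desired result (no TensorFlow).
# For each row r, pick tensor[r][i] for every index i in indices[r].
indices = [[0, 4], [1, 2, 3], [5]]
tensor = [
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

def gather_rows(tensor, indices):
    """Hypothetical helper: for each row, return that row's elements at the given column indices."""
    return [[row[i] for i in idx] for row, idx in zip(tensor, indices)]

result = gather_rows(tensor, indices)
print(result)  # [[200, 20], [315, 401, 20], [105]]
```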

How can I do this as efficiently as possible, preferably using only TensorFlow ops? I believe tf.gather_nd can accept RaggedTensors, but I can't figure out how to use it here.
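Since the question mentions gather_nd, here is a sketch of one way to make it work, assuming TensorFlow 2.x in eager mode. It is a swapped-in alternative, not the approach used in the answer: flatten the ragged indices, pair each with its row id to form (row, column) coordinates, gather those from the dense tensor with tf.gather_nd, then reattach the original ragged row structure.

```python
import tensorflow as tf

indices = tf.ragged.constant([[0, 4], [1, 2, 3], [5]])
tensor = tf.constant([
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
])

# value_rowids() gives, for every flat index value, the row it belongs to
# (here [0, 0, 1, 1, 1, 2]). It is int64, so cast the column indices to match.
row_ids = indices.value_rowids()
col_ids = tf.cast(indices.values, tf.int64)

# Build (row, col) coordinate pairs, shape [num_values, 2], and look them up
# in the dense tensor.
coords = tf.stack([row_ids, col_ids], axis=1)
flat_values = tf.gather_nd(tensor, coords)

# Re-wrap the gathered flat values with the same row partitioning as `indices`.
result = indices.with_values(flat_values)
print(result.to_list())  # [[200, 20], [315, 401, 20], [105]]
```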

Accepted answer by Lescurel:

You can use tf.gather with the batch_dims keyword argument set to 1, so that each row of indices gathers only from the corresponding row of tensor:

>>> tf.gather(tensor, indices, batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
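For reference, a self-contained version of the same call, assuming TensorFlow 2.x in eager mode and building both inputs from the question's example data:

```python
import tensorflow as tf

# The 2D ragged index tensor: row r lists the columns to take from tensor[r].
indices = tf.ragged.constant([[0, 4], [1, 2, 3], [5]])

# The dense tensor to gather from, shape [3, 6].
tensor = tf.constant([
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
])

# batch_dims=1 treats the first dimension of both arguments as a batch:
# row r of `indices` indexes only into row r of `tensor`.
result = tf.gather(tensor, indices, batch_dims=1)
print(result.to_list())  # [[200, 20], [315, 401, 20], [105]]
```

Without batch_dims=1, tf.gather would treat every index as a row number of tensor, which is not what is wanted here.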