tensorflow DenseHashTable lookup multi-dimensional keys


I want to use a DenseHashTable to look up string tensors, just like in this answer. The keys are of type tf.string and the values are embeddings of dtype tf.float32. But when the keys are multi-dimensional, an error occurs.

import tensorflow as tf

keys = ["Fritz", "Franz", "Fred"]
values = [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]]
table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.string, value_dtype=tf.float32,
    empty_key="0", deleted_key="-1", default_value=[-1, -1, -1, -1])
table.insert(keys, values)
table.lookup(['Franz', 'Emil'])  # keys of shape (2,): works, returns shape (2, 4)
table.lookup([['Franz', 'Emil'], ['Emil', 'Fred']])  # 2-D keys, shape (batch_size, 2): throws an error

How can I make it work just like tf.nn.embedding_lookup, given that the keys are tf.string values rather than array indices?



Best answer, by mhenning

The problem is that the table expects a flat (1-D) list of keys, not a nested list of keys. Granted, the phrase "Can be a tensor of any shape." in the description of keys in the docs is a bit confusing.
What you can do is flatten your keys, look them up, and reshape the result afterwards:
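A minimal reproduction of that contract, assuming TensorFlow 2.x in eager mode: a 1-D batch of N keys comes back as an (N, 4) tensor (missing keys get default_value), while a 2-D batch is rejected when the op runs. Catching tf.errors.InvalidArgumentError is an assumption about how the kernel's key-shape check surfaces in eager mode; the exact error message may differ between versions.

```python
import tensorflow as tf

table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.string, value_dtype=tf.float32,
    empty_key="0", deleted_key="-1", default_value=[-1.0, -1.0, -1.0, -1.0])
table.insert(["Fritz", "Franz", "Fred"],
             [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]])

# 1-D keys: one 4-dim value row per key; "Emil" is missing, so it gets default_value.
one_d = table.lookup(["Franz", "Emil"])
print(one_d.shape)

# 2-D keys: assumed to fail the kernel's key-shape check at runtime.
try:
    table.lookup([["Franz", "Emil"], ["Emil", "Fred"]])
    two_d_failed = False
except tf.errors.InvalidArgumentError as err:
    two_d_failed = True
    print(err.message)
```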

# `table` is the DenseHashTable from the question
keys = [['Franz', 'Emil'], ['Emil', 'Fred']]
keys = tf.convert_to_tensor(keys)  # convert to a tensor to get its shape
key_shape = keys.shape  # shape: (2, 2)
x = table.lookup(tf.reshape(keys, [-1]))  # flattened to 4 keys; result shape: (4, 4)

x = tf.reshape(x, key_shape + x.shape[-1:])  # shape: (2, 2, 4)
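The same flatten/lookup/reshape idea can be wrapped in a small reusable function. This is a sketch under assumptions: the name `lookup_nd` is hypothetical (not a TensorFlow API), values are assumed to be 1-D vectors, and the output shape is built with tf.shape instead of the static .shape so it also works inside tf.function when the batch dimension is unknown.

```python
import tensorflow as tf


def lookup_nd(table, keys):
    """Hypothetical helper: look up string keys of any rank in a DenseHashTable.

    Returns a tensor of shape keys.shape + [value_dim], similar to what
    tf.nn.embedding_lookup returns for integer ids.
    """
    keys = tf.convert_to_tensor(keys, dtype=tf.string)
    flat_values = table.lookup(tf.reshape(keys, [-1]))  # (num_keys, value_dim)
    # Dynamic shapes: original key dims followed by the value dimension.
    out_shape = tf.concat([tf.shape(keys), tf.shape(flat_values)[-1:]], axis=0)
    return tf.reshape(flat_values, out_shape)


table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.string, value_dtype=tf.float32,
    empty_key="0", deleted_key="-1", default_value=[-1.0, -1.0, -1.0, -1.0])
table.insert(["Fritz", "Franz", "Fred"],
             [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]])

# Eager call with a (2, 2) batch of keys.
emb = lookup_nd(table, [["Franz", "Emil"], ["Emil", "Fred"]])


# Graph-mode call where the batch size is only known at runtime.
@tf.function(input_signature=[tf.TensorSpec([None, 2], tf.string)])
def embed(batch):
    return lookup_nd(table, batch)


batch_out = embed(tf.constant([["Fred", "Fritz"], ["Emil", "Franz"], ["Emil", "Emil"]]))
print(emb.shape, batch_out.shape)
```

With eager tensors the static-shape version from the answer gives the same result; the dynamic version only matters when some dimension is None at trace time.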