Is this possible with tf.tensor_scatter_nd_add?


I'm having trouble getting a simple use of tf.tensor_scatter_nd_add to work:

B = tf.tensor_scatter_nd_add(A, indices, updates)

Tensor A has shape (1, 4, 4):

A = [[[1. 1. 1. 1.],
      [1. 1. 1. 1.],
      [1. 1. 1. 1.],
      [1. 1. 1. 1.]]]

The desired result is tensor B:

B = [[[1. 1. 1. 1.],
      [1. 2. 3. 1.],
      [1. 4. 5. 1.],
      [1. 1. 1. 1.]]]

That is, I want to add this smaller tensor to only the four inner elements of A:

updates = [[[1, 2],
            [3, 4]]]

I'm using TensorFlow 2.1.0. I've tried several ways of constructing indices, but every call to tensor_scatter_nd_add returns an error saying the inner dimensions don't match.

Does the updates tensor need to be the same shape as A?



Accepted answer by Poe Dator:

Planaria,

Try passing updates with shape (n,) and indices with shape (n, 3), where n is the number of elements you want to change. Each row of indices should point to one individual cell that you want to change:

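import tensorflow as tf

A = tf.ones((1, 4, 4), dtype=tf.float32)
updates = tf.constant([1., 2., 3., 4.])  # shape (4,): one value per changed cell
# shape (4, 3): each row is a full (batch, row, col) index into A
indices = tf.constant([[0, 1, 1], [0, 1, 2], [0, 2, 1], [0, 2, 2]])
tf.tensor_scatter_nd_add(A, indices, updates)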

<tf.Tensor: shape=(1, 4, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 2., 3., 1.],
        [1., 4., 5., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>
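The "inner dimensions" error comes from the shape rule tensor_scatter_nd_add enforces: updates.shape must equal indices.shape[:-1] + A.shape[indices.shape[-1]:]. With full-depth indices of shape (n, 3) into a rank-3 A, that leaves updates with shape (n,), which is why the 2x2 patch has to be flattened. For larger patches, writing the index list by hand gets tedious, so here is a minimal pure-Python sketch that builds it. The helpers `patch_indices` and `scatter_nd_add_reference` are hypothetical names for illustration, not TensorFlow functions; the second one only models the element-wise addition (duplicate indices accumulate) so the result can be checked without TensorFlow installed.

```python
from itertools import product


def patch_indices(batch, row_start, col_start, patch):
    """Hypothetical helper: flatten a 2-D patch into (indices, updates)
    for a (batch, rows, cols) tensor, one full-depth index per element."""
    indices, updates = [], []
    for r, c in product(range(len(patch)), range(len(patch[0]))):
        indices.append([batch, row_start + r, col_start + c])
        updates.append(patch[r][c])
    return indices, updates


def scatter_nd_add_reference(tensor, indices, updates):
    """Pure-Python model of tensor_scatter_nd_add for full-depth indices.
    Returns a new nested list; the input tensor is left unchanged."""
    out = [[row[:] for row in mat] for mat in tensor]
    for (b, r, c), u in zip(indices, updates):
        out[b][r][c] += u  # repeated indices add up, as in TensorFlow
    return out


A = [[[1.0] * 4 for _ in range(4)]]
indices, updates = patch_indices(0, 1, 1, [[1.0, 2.0], [3.0, 4.0]])
B = scatter_nd_add_reference(A, indices, updates)
print(indices)  # [[0, 1, 1], [0, 1, 2], [0, 2, 1], [0, 2, 2]]
print(updates)  # [1.0, 2.0, 3.0, 4.0]
```

The generated lists can be wrapped with tf.constant and passed to tf.tensor_scatter_nd_add exactly as in the answer above. By the same shape rule, you could also keep the patch 2-D: indices of shape (2, 2, 3), i.e. [[[0, 1, 1], [0, 1, 2]], [[0, 2, 1], [0, 2, 2]]], pair with updates of shape (2, 2). Either way, updates must have the same dtype as A (float32 here), so integer literals like [[1, 2], [3, 4]] need to be floats.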