Translate tensor horizontally, back-fill with zeros

Given a 2D tensor

T = [[1, 2, 3],
     [4, 5, 6]]

and a 1D tensor containing horizontal shifts, say, s = [0, -2, 1], how can I obtain the following 3D tensor R?

R[0] = T

R[1] = [[3, 0, 0],  # shifted two to the left,
        [6, 0, 0]]  # padding the rest with zeros

R[2] = [[0, 1, 2],  # shifted one to the right,
        [0, 4, 5]]  # padding the rest with zeros

I know about tf.contrib.image.translate, but that isn't differentiable, so I am looking for some elegant combination of padding/slicing/looping/concatenating operations that accomplishes the same thing.
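Before comparing differentiable TensorFlow versions, it can help to pin down the intended semantics with a plain loop. The following is a minimal NumPy sketch, not a TensorFlow solution: the function name `shift_reference` is hypothetical, the code is not differentiable, and it only serves as a ground truth to check other outputs against. It assumes positive shifts move columns to the right and negative shifts to the left, as in the example above.

```python
import numpy as np

def shift_reference(T, s):
    """Hypothetical helper: R[b] = T shifted horizontally by s[b] columns, zero-filled.

    Positive shifts move columns to the right, negative shifts to the left.
    """
    T = np.asarray(T)
    H, W = T.shape
    R = np.zeros((len(s), H, W), dtype=T.dtype)
    for b, k in enumerate(s):
        for j in range(W):
            src = j - k  # column of T that lands in column j after shifting by k
            if 0 <= src < W:
                R[b, :, j] = T[:, src]
    return R

T = np.array([[1, 2, 3], [4, 5, 6]])
print(shift_reference(T, [0, -2, 1]))
```

Shifts whose magnitude is at least the width simply produce an all-zero slice, which is also what the zero-padding formulations give.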

Best Solution

I have only come up with two ways, both using tf.map_fn(). The first method is to pad T with zeros on both sides and then slice it.

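import tensorflow as tf

T = tf.constant([[1, 2, 3],[4, 5, 6]],dtype=tf.float32)
s = tf.constant([0, -2, 1])

# Zero columns needed on each side (clamped at 0 so all-negative or all-positive s still work)
left = tf.maximum(tf.reduce_max(s), 0)
right = tf.maximum(-tf.reduce_min(s), 0)

# Pad T horizontally with zeros: [zeros(left) | T | zeros(right)]
tmp_slice = tf.pad(T, [[0, 0], [left, right]])

# For each shift x, cut out a window of the original width
result = tf.map_fn(lambda x: tmp_slice[:,left-x:left-x+tf.shape(T)[1]],s,dtype=T.dtype)
grad = tf.gradients(ys=result, xs=T)  # shows that result is differentiable w.r.t. T

with tf.Session() as sess:
    print(sess.run(result))
    print(sess.run(grad))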

# print
[[[1. 2. 3.]
  [4. 5. 6.]]

 [[3. 0. 0.]
  [6. 0. 0.]]

 [[0. 1. 2.]
  [0. 4. 5.]]]
[array([[2., 2., 2.],
       [2., 2., 2.]], dtype=float32)]
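The index arithmetic of the padding approach can be checked in NumPy, where np.pad stands in for the zero padding of T. This is a minimal sketch under that assumption (the function name `shift_pad_slice` is hypothetical). The pad widths are clamped at zero: without the clamp, an s whose entries are all negative would ask for a negative number of left-padding columns.

```python
import numpy as np

def shift_pad_slice(T, s):
    """Hypothetical helper: zero-pad T on both sides, then cut one width-W window per shift."""
    T = np.asarray(T)
    W = T.shape[1]
    s = np.asarray(s)
    left = max(int(s.max()), 0)    # room needed for the largest right shift
    right = max(-int(s.min()), 0)  # room needed for the largest left shift
    padded = np.pad(T, ((0, 0), (left, right)), mode="constant")  # zeros on both sides
    # A shift of x starts the window x columns before T's original first column
    return np.stack([padded[:, left - x:left - x + W] for x in s])

T = np.array([[1, 2, 3], [4, 5, 6]])
print(shift_pad_slice(T, [0, -2, 1]))
```

Only padding and slicing are involved, so in TensorFlow the gradient with respect to T is just the sum of the matching windows, consistent with the all-2 gradient printed above.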

The second method is to build a matching mask with tf.sequence_mask and tf.roll(), then pick the values with tf.where().

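import tensorflow as tf

T = tf.constant([[1, 2, 3],[4, 5, 6]],dtype=tf.float32)
s = tf.constant([0, -2, 1])

def mask_f(x):
    # Keep (width - |x|) columns per row; for right shifts move that block to the right end
    indices = tf.tile([x], (tf.shape(T)[0],))
    mask = tf.sequence_mask(tf.shape(T)[1] - tf.abs(indices), tf.shape(T)[1])
    mask = tf.roll(mask, shift=tf.maximum(0, x), axis=1)
    # Rotate T cyclically, then zero out the columns that wrapped around
    return tf.where(mask, tf.roll(T, shift=x, axis=1), tf.zeros_like(T))

result = tf.map_fn(mask_f, s, dtype=T.dtype)
grad = tf.gradients(ys=result, xs=T)

with tf.Session() as sess:
    print(sess.run(result))
    print(sess.run(grad))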

# print
[[[1. 2. 3.]
  [4. 5. 6.]]

 [[3. 0. 0.]
  [6. 0. 0.]]

 [[0. 1. 2.]
  [0. 4. 5.]]]
[array([[2., 2., 2.],
       [2., 2., 2.]], dtype=float32)]
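The roll-and-mask idea can be traced step by step in NumPy, with np.roll, a comparison against np.arange, and np.where standing in for tf.roll, tf.sequence_mask and tf.where. This is a minimal sketch for checking the logic (the function name `shift_roll_mask` is hypothetical); it is not the TensorFlow code itself.

```python
import numpy as np

def shift_roll_mask(T, s):
    """Hypothetical helper: rotate T cyclically by x, then zero the columns that wrapped around."""
    T = np.asarray(T)
    W = T.shape[1]
    out = []
    for x in s:
        # Valid columns after a left shift: the first W - |x| (one row of a sequence mask)
        keep = np.arange(W) < W - abs(x)
        # For a right shift the valid block sits at the right end instead
        keep = np.roll(keep, max(0, x))
        rolled = np.roll(T, x, axis=1)         # counterpart of tf.roll(T, shift=x, axis=1)
        out.append(np.where(keep, rolled, 0))  # counterpart of tf.where(mask, rolled, zeros)
    return np.stack(out)

T = np.array([[1, 2, 3], [4, 5, 6]])
print(shift_roll_mask(T, [0, -2, 1]))
```

When |x| is at least the width, W - |x| is not positive, so the mask is all False and the slice comes out as zeros.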

Update

I found a new way to achieve it. In essence, a horizontal shift is T multiplied by an offset identity matrix, so we can use np.eye() with its k (diagonal offset) argument to create the factor.
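To see why the offset identity works: np.eye(W, k=k) has ones at positions (i, i + k), so in T @ E column j of the product picks up column j - k of T, and columns whose source would fall outside T come out as zero. A quick NumPy check with the example shifts:

```python
import numpy as np

T = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
W = T.shape[1]

for k in (0, -2, 1):
    E = np.eye(W, k=k)  # ones at (i, i + k): column i of T is sent to column i + k
    print(k, (T @ E).tolist())
```

Because the shift is a plain matrix product, the gradient with respect to T flows through tf.matmul without any special handling.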

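import tensorflow as tf
import numpy as np

T = tf.constant([[1, 2, 3],[4, 5, 6]],dtype=tf.float32)
s = tf.constant([0, -2, 1])

# One copy of T per shift: shape (len(s), H, W)
new_T = tf.tile(tf.expand_dims(T,axis=0),[tf.shape(s)[0],1,1])
# One W x W offset identity np.eye(W, k=shift) per shift, built in Python via tf.py_func
s_factor = tf.map_fn(lambda x: tf.py_func(lambda y: np.eye(T.get_shape().as_list()[-1],k=y),[x],tf.float64),s,tf.float64)

# Right-multiplying by the offset identity shifts the columns and zero-fills the rest
result = tf.matmul(new_T,tf.cast(s_factor,new_T.dtype))

with tf.Session() as sess:
    print(sess.run(result))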
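The py_func call can probably be avoided, because the whole stack of offset identities can be built by broadcasting: entry (b, i, j) is 1 exactly when j - i == s[b]. Below is a minimal NumPy sketch of that construction (the function name `shift_matrices` is hypothetical). My assumption, not tested here, is that the same expression translates directly to tf.range, tf.reshape, tf.equal and tf.cast, which would keep everything inside the graph.

```python
import numpy as np

def shift_matrices(s, W):
    """Hypothetical helper: stack of W x W matrices whose slice b equals np.eye(W, k=s[b])."""
    s = np.asarray(s).reshape(-1, 1, 1)  # (B, 1, 1)
    i = np.arange(W).reshape(1, -1, 1)   # row index
    j = np.arange(W).reshape(1, 1, -1)   # column index
    return (j - i == s).astype(float)    # 1 where column == row + shift

T = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
s = [0, -2, 1]
E = shift_matrices(s, T.shape[1])
R = np.matmul(T, E)  # (H, W) @ (B, W, W) broadcasts to (B, H, W)
print(R)
```

NumPy broadcasts the 2-D T against the stack, so no explicit tiling is needed here; older TensorFlow versions of tf.matmul may still need the tiled new_T.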