Can this nested for-loop be rewritten using TensorFlow functions to allow for gradient calculation?


I wrote a function that sums only certain q-values from a tensor, those being the values corresponding to previous actions taken. I need this function to be auto-differentiable, but my current implementation uses a numpy array with nested for-loops, so the TensorFlow computation graph cannot track it and I am getting the error:

ValueError: No gradients provided for any variable: ['critic_network/fc1/kernel:0', 'critic_network/fc1/bias:0', 'critic_network/fc2/kernel:0', 'critic_network/fc2/bias:0', 'critic_network/q/kernel:0', 'critic_network/q/bias:0'].

Here is the function in question:

# Get q-values for the actions taken at the sampled states (= q)
critic1_reshaped = tf.reshape( self.critic_1(states), [BATCH_SIZE, NUM_BOTS, NUM_NODES] )  # critic values shape = (64, 132) => (64, 12, 11) reshaped
q1 = np.zeros(BATCH_SIZE)
for i, batch in enumerate(actions):  # actions shape = (BATCH_SIZE, 7, 2); each sample is a list of 7 [group, node] pairs
    for action in batch:
        group = action[0]
        node = action[1]
        value = critic1_reshaped[i, group, node-1]
        q1[i] += value

Structure-wise, the actions (shape=(64,7,2)) tensor contains BATCH_SIZE=64 samples, each sample i being of the form:

actions[i] = [[g0, n0], [g1, n1], [g2, n2], [g3, n3], [g4, n4], [g5, n5], [g6, n6]]

The critic1_reshaped (shape=(64,12,11)) tensor also contains BATCH_SIZE=64 samples, each indexed first by group g and then by node n. Here is an example of group g at sample i:

critic1_reshaped[i][g] = [n0, n1, n2, n3, n4, n5, n6, n7, n8, n9, n10]


Essentially, for each [g, n] pair in actions[i], I want to look up the value at critic1_reshaped[i][g][n] and sum the results (so in total, 7 values are summed per sample). This should be done for every sample, resulting in an output tensor of shape=(64,).

I was messing around trying to make this into a list-comprehension or using reduce_sum(), but TensorFlow wasn't playing nice when trying to index using another tensor.

Any ideas?
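
For reference, here is a minimal sketch of one way the lookup described above might be expressed without leaving the graph, using tf.gather_nd with batch_dims=1 followed by tf.reduce_sum. The helper name sum_selected_q, the random stand-in data, and the tf.Variable standing in for the critic output are all hypothetical. It mirrors the node-1 offset from the loop and assumes nodes are numbered 1..NUM_NODES, so that no index goes negative.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE, NUM_BOTS, NUM_NODES = 64, 12, 11  # shapes taken from the question
NUM_ACTIONS = 7                               # [group, node] pairs per sample


def sum_selected_q(critic_reshaped, actions):
    """Hypothetical helper: sum critic_reshaped[i, g, n - 1] over the pairs in actions[i]."""
    actions = tf.cast(actions, tf.int32)
    # Mirror the loop's `node - 1` offset (assumes nodes are numbered from 1).
    indices = actions - tf.constant([0, 1], dtype=tf.int32)
    # batch_dims=1 pairs row i of `indices` with sample i of `critic_reshaped`.
    picked = tf.gather_nd(critic_reshaped, indices, batch_dims=1)  # (BATCH_SIZE, 7)
    return tf.reduce_sum(picked, axis=1)                           # (BATCH_SIZE,)


# Stand-in data: random actions and a Variable in place of self.critic_1(states).
rng = np.random.default_rng(0)
groups = rng.integers(0, NUM_BOTS, size=(BATCH_SIZE, NUM_ACTIONS, 1))
nodes = rng.integers(1, NUM_NODES + 1, size=(BATCH_SIZE, NUM_ACTIONS, 1))
actions = np.concatenate([groups, nodes], axis=-1)  # (64, 7, 2)

critic_out = tf.Variable(tf.random.normal([BATCH_SIZE, NUM_BOTS * NUM_NODES]))
with tf.GradientTape() as tape:
    critic_reshaped = tf.reshape(critic_out, [BATCH_SIZE, NUM_BOTS, NUM_NODES])
    q1 = sum_selected_q(critic_reshaped, actions)
    loss = tf.reduce_mean(q1)

# The gradient is defined because every step stays a TensorFlow op.
grad = tape.gradient(loss, critic_out)
```

The original loop loses the gradient because q1 is a NumPy array: copying tensor values into it detaches them from the tape. Keeping the gather and the sum as TensorFlow ops avoids that, though whether the node-1 offset is wanted depends on how the nodes are numbered in the actual data.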
