I have an issue computing the mean squared error in the critic loss function of a DDPG agent. The error message indicates a shape mismatch between the td_targets and q_values tensors in the agent's critic loss function.

Here is the relevant code snippet:

# Create the agent
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    td_errors_loss_fn=common.element_wise_squared_loss,
)

# Initialize the replay buffer
replay_buffer = replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=self.ddpg_agent.collect_data_spec,
    batch_size=1,
    max_length=5000)

# Add experiences to the replay buffer
experience = trajectory.from_transition(time_step, action_step, next_time_step)
replay_buffer.add_batch(experience)

# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=self.batch_size,  # self.batch_size = 32
    num_steps=2,
    num_parallel_calls=3,
    single_deterministic_pass=False
).prefetch(3)

# Train the agent
iterator = iter(dataset)
experience_set, _ = next(iterator)
loss = self.ddpg_agent.train(experience_set)

When I run the code, it fails during the loss calculation with this error:

  File "main.py", line 138, in <module>
    main()
  File "main.py", line 109, in main
    a2c.train_agent()
  File "a2c.py", line 41, in train_agent
    self.agent.train_agent()
  File "agent.py", line 161, in train_agent
    loss = self.ddpg_agent.train(experience_set)
  File "tf_agents\agents\tf_agent.py", line 330, in train
    loss_info = self._train_fn(
  File "tf_agents\utils\common.py", line 188, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 247, in _train
    critic_loss = self.critic_loss(time_steps, actions, next_time_steps,
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 343, in critic_loss
    critic_loss = self._td_errors_loss_fn(td_targets, q_values)
  File "tf_agents\utils\common.py", line 1139, in element_wise_squared_loss
    return tf.compat.v1.losses.mean_squared_error(
  File "tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "tensorflow\python\framework\tensor_shape.py", line 1361, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (32, 1) and (32, 32) are incompatible

I checked all the spec shapes, the experience shapes, and the output shapes of my actor and critic networks. They all seem correct, and the actor and critic output layers produce the expected shape of (32, 1), where 32 is the batch size. The mismatch is between td_targets and q_values in the loss function in tf_agents\agents\ddpg\ddpg_agent.py:

TD Targets shape: (32, 32)
Q Values shape: (32, 1)
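
For reference, a (32, 32) shape is what NumPy-style broadcasting produces when a rank-1 tensor of shape (32,) is combined with a rank-2 tensor of shape (32, 1). The sketch below reproduces that with NumPy only. The array names, the value of gamma, and the assumption that rewards and discounts are shaped (32,) while the target critic returns (32, 1) are illustrative guesses, not something confirmed from the code above.

```python
import numpy as np

batch_size = 32
gamma = 0.99  # hypothetical discount factor

# Assumed shapes: per-sample rewards and discounts are rank 1,
# while the target critic output carries a trailing axis of size 1.
rewards = np.ones((batch_size,), dtype=np.float32)
discounts = np.ones((batch_size,), dtype=np.float32)
target_q_values = np.zeros((batch_size, 1), dtype=np.float32)

# (32,) combined with (32, 1) broadcasts to (32, 32).
td_targets = rewards + gamma * discounts * target_q_values
print(td_targets.shape)

# Dropping the trailing axis keeps every term at shape (32,).
td_targets_aligned = rewards + gamma * discounts * np.squeeze(target_q_values, axis=-1)
print(td_targets_aligned.shape)
```

If this is what happens here, the shapes would only line up when the critic's output has shape (32,) rather than (32, 1).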

Can someone advise me what I am missing here?

Answer by Peter Renz (best answer)

I solved the problem by passing a different loss function when the DDPG agent is initialized:

# Create the agent
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    td_errors_loss_fn=tf.keras.losses.MeanSquaredError(),
)
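
One thing worth checking with this change: a loss that reduces to a scalar can accept shapes (32, 32) and (32, 1) by broadcasting them, so the error may disappear while the tensors stay misaligned. The NumPy sketch below is an illustration under assumptions, not TF-Agents code. `mse` is a hypothetical helper that mimics a mean-squared reduction, gamma is made up, and the random values stand in for real rewards and Q-values. It compares the loss on broadcast shapes with the loss after the trailing axis is removed.

```python
import numpy as np


def mse(y_true, y_pred):
    """Hypothetical stand-in for a mean-squared-error loss reduced to a scalar."""
    return float(np.mean(np.square(y_true - y_pred)))


rng = np.random.default_rng(0)
batch_size = 32
gamma = 0.99  # hypothetical discount factor

rewards = rng.normal(size=(batch_size,))            # assumed shape (32,)
target_q_values = rng.normal(size=(batch_size, 1))  # assumed shape (32, 1)
q_values = rng.normal(size=(batch_size, 1))         # critic output, shape (32, 1)

# Misaligned case: (32,) + (32, 1) broadcasts to (32, 32), and the
# reduction silently broadcasts q_values against it as well.
td_targets = rewards + gamma * target_q_values
loss_broadcast = mse(td_targets, q_values)

# Aligned case: one TD target per Q-value, both of shape (32,).
td_targets_aligned = rewards + gamma * target_q_values[:, 0]
loss_aligned = mse(td_targets_aligned, q_values[:, 0])

print(td_targets.shape, td_targets_aligned.shape)
print(loss_broadcast, loss_aligned)
```

The two printed losses differ, so if the critic really returns (32, 1), it may be safer to remove the trailing axis from its output than to rely on the loss function to absorb the mismatch.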