While reading the MATD3 implementation in the agilerl library, I didn't quite understand the details of the actor parameter update.
```python
if len(self.scores) % self.policy_freq == 0:
    if self.arch == "mlp":
        if self.accelerator is not None:
            with actor.no_sync():
                action = actor(states[agent_id])
        else:
            action = actor(states[agent_id])
        if not self.discrete_actions:
            action = torch.where(
                action > 0,
                action * self.max_action[idx][0],
                action * -self.min_action[idx][0],
            )
        detached_actions = copy.deepcopy(actions)
        detached_actions[agent_id] = action
        input_combined = torch.cat(
            list(states.values()) + list(detached_actions.values()), 1
        )
        if self.accelerator is not None:
            with critic_1.no_sync():
                actor_loss = -critic_1(input_combined).mean()
        else:
            actor_loss = -critic_1(input_combined).mean()
```
I don't quite understand the line **`detached_actions[agent_id] = action`**. Why does it replace only the current agent's action inside the loop instead of updating all agents' actions?
I asked Claude, and the answer it gave was: "This ensures that the actions of all agents except the current one are detached (separated from the computation graph) and will not affect the gradient computation and updates." I still don't quite understand this.
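To check my understanding of the autograd side, here is a minimal standalone sketch (not AgileRL code; the agent names, shapes, and the linear layer standing in for the actor are all made up) of how, after the deepcopy, only the recomputed action stays connected to the current actor's parameters:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

actor_a = nn.Linear(4, 2)   # stand-in for the actor being updated
state_a = torch.randn(8, 4)

# Actions as sampled from the replay buffer: plain tensors with no
# autograd history, one entry per (hypothetical) agent.
actions = {"agent_a": torch.randn(8, 2), "agent_b": torch.randn(8, 2)}
detached_actions = copy.deepcopy(actions)

# Only the current agent's action is recomputed through its actor,
# so only this entry is connected to actor_a's parameters.
detached_actions["agent_a"] = actor_a(state_a)

# Stand-in for actor_loss = -critic_1(input_combined).mean()
critic_input = torch.cat(list(detached_actions.values()), dim=1)
actor_loss = -critic_input.mean()
actor_loss.backward()

# The gradient reaches actor_a only through detached_actions["agent_a"];
# the other agents' entries contribute nothing to the graph.
print(actor_a.weight.grad)  # non-None: actor_a receives a gradient
```

If I run this, the backward pass only updates `actor_a`, which seems to match what Claude said, but I'd like confirmation that this is the intended reason for that line.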
This code is at line 689 of the matd3 file in the agilerl library.