The error in question is raised in the learning loop, when loss.backward() is called. Here's my code:
# evaluate the target critic on the next states, zeroing terminal entries
critic_value_ = self.agents[agent].target_critic.forward(states_, new_actions).flatten()
critic_value_[dones[:, 0]] = 0.0
critic_value = self.agents[agent].critic.forward(states, old_actions).flatten()

# TD target and critic loss
target = rewards[:, agent_idx] + self.agents[agent].gamma * critic_value_
loss = self.agents[agent].critic.loss(target, critic_value)

self.agents[agent].critic.optimizer.zero_grad()
# T.autograd.set_detect_anomaly(True)
loss.backward(retain_graph=True)
self.agents[agent].critic.optimizer.step()
The network being trained looks like this:
import os
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CriticNetwork(nn.Module):
    def __init__(self, beta, input_dims, fc1_dims, fc2_dims,
                 n_agents, n_actions, name, chkpt_dir):
        super(CriticNetwork, self).__init__()

        self.chkpt_file = os.path.join(chkpt_dir, name)

        self.fc1 = nn.Linear(input_dims + n_agents * n_actions, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.q = nn.Linear(fc2_dims, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        # self.double()
        self.to(self.device)

    def forward(self, state, action):
        x = F.relu(self.fc1(T.cat([state, action], dim=1)))
        x = F.relu(self.fc2(x))
        q = self.q(x)
        return q
This is the error I get:
Traceback (most recent call last):
File "PycharmProjects/MARL/source/main.py", line 86, in <module>
maddpg_agents.learn(memory)
File "PycharmProjects/MARL/source/maddpg.py", line 84, in learn
loss.backward(retain_graph=True)
File "PycharmProjects/MARL/venv/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "PycharmProjects/MARL/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Double
Object details during the run, just before the error was raised:
loss: tensor(2.1452, dtype=torch.float64, grad_fn=<MseLossBackward0>)
target: tensor([-1.9751, -1.8311, -1.1488, ..., -1.9725, -0.5983, -1.1475],
dtype=torch.float64, grad_fn=<AddBackward0>)
critic_value: tensor([0.2046, 0.2089, 0.1132, ..., 0.1629, 0.1426, 0.0946],
grad_fn=<ReshapeAliasBackward0>)
Trying to cast the tensors to float didn't work. Trying workarounds such as those suggested here (for the similar error in the opposite direction, "Found dtype Double but expected Float") didn't work either (of course, I applied them in the correct direction...).
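For reference, the casts I tried looked roughly like this (a sketch from memory, not the exact code):

# casting the double-precision target down to float32 before the loss
loss = self.agents[agent].critic.loss(target.float(), critic_value)

# and, per the linked workaround, casting the critic output up to float64 instead
loss = self.agents[agent].critic.loss(target, critic_value.double())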