I am planning to use TD3 with a MultiInputPolicy, which accepts Dict-type observations, for my custom multi-agent environment. Constructing the model fails with the following traceback:
...train.py", line 114, in <module>
model = TD3(
^^^^
File "D:\anaconda3\Lib\site-packages\stable_baselines3\td3\td3.py", line 137, in __init__
self._setup_model()
File "D:\anaconda3\Lib\site-packages\stable_baselines3\td3\td3.py", line 140, in _setup_model
super()._setup_model()
File "D:\anaconda3\Lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 199, in _setup_model
self.policy = self.policy_class(
^^^^^^^^^^^^^^^^^^
File "D:\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: TD3Policy.forward() takes from 2 to 3 positional arguments but 4 were given
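In case the mechanism helps diagnosis: torch routes calls on an nn.Module instance to forward(), so a TypeError of this shape can arise when an already-constructed policy object sits where SB3 expects a policy class. A minimal pure-Python sketch of that mechanism (the class below is a stand-in for illustration, not SB3's actual TD3Policy):

```python
class TD3Policy:
    # Stand-in with the same forward() arity the traceback reports:
    # 2 to 3 positional arguments, counting self.
    def forward(self, obs, deterministic=False):
        return obs

    def __call__(self, *args, **kwargs):
        # torch.nn.Module.__call__ ultimately dispatches to forward() like this
        return self.forward(*args, **kwargs)

policy_class = TD3Policy()  # an *instance*, not the class

# _setup_model calls self.policy_class(obs_space, act_space, lr_schedule, ...).
# With an instance in that slot, the three constructor arguments reach
# forward() instead (plus self, making 4 positional arguments):
try:
    policy_class("obs_space", "act_space", "lr_schedule")
except TypeError as err:
    print(err)  # ...takes from 2 to 3 positional arguments but 4 were given
```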
Relevant model and policy definitions:
model = TD3(
policy=policy,
env=env,
...
)
I tried substituting the env with a known-working gym environment ('Pendulum-v1') and got the same error, so I moved on to investigating the policy definition:
policy = MultiInputPolicy(
env.observation_space,
env.action_space,
lr_schedule,
...
)
That brought me back to the environment itself. Is something wrong with my observation and action spaces? Please advise.
...
self.action_space = Box(
    0.0, +1.0, shape=(len(self.actions),), dtype=np.float32
)
self.observation_space = Dict(
{
"a": Box(
-2.0,
+1.0,
shape=(2 * r1 + 1, r2 + 1),
dtype=np.float32,
),
"b": Box(
-1.0,
1.0,
shape=(2 * r1 + 1, r2 + 1),
dtype=np.int32,
),
"c": Box(
-1.0,
100.0,
shape=(2 * r1 + 1, r2 + 1),
dtype=np.float32,
),
}
)
...