I am planning to use TD3 with MultiInputPolicy, which accepts Dict observations, for my custom multi-agent environment. Creating the model fails with the following traceback:

...train.py", line 114, in <module>
    model = TD3(
            ^^^^
  File "D:\anaconda3\Lib\site-packages\stable_baselines3\td3\td3.py", line 137, in __init__
    self._setup_model()
  File "D:\anaconda3\Lib\site-packages\stable_baselines3\td3\td3.py", line 140, in _setup_model
    super()._setup_model()
  File "D:\anaconda3\Lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 199, in _setup_model
    self.policy = self.policy_class(
                  ^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: TD3Policy.forward() takes from 2 to 3 positional arguments but 4 were given

Relevant model and policy definitions:

model = TD3(
    policy=policy,
    env=env,
    ...
)

I tried replacing the env with a Gym environment known to work with TD3 ('Pendulum-v1'), and that produced the same error, so I moved on to investigating the policy definition:

policy = MultiInputPolicy(
    env.observation_space,
    env.action_space,
    lr_schedule,
    ...
)

And this brought me back to the environment. Is something wrong with my observation and action space? Please advise.

        ...
        self.action_space = Box(
            0.0, +1.0, shape=(len(self.actions.keys()),), dtype=np.float32
        )

        self.observation_space = Dict(
            {
                "a": Box(
                    -2.0,
                    +1.0,
                    shape=(2 * r1 + 1, r2 + 1),
                    dtype=np.float32,
                ),
                "b": Box(
                    -1.0,
                    1.0,
                    shape=(2 * r1 + 1, r2 + 1),
                    dtype=np.int32,
                ),
                "c": Box(
                    -1.0,
                    100.0,
                    shape=(2 * r1 + 1, r2 + 1),
                    dtype=np.float32,
                ),
            }
        )
        ...
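
One possible reading of the traceback, offered as a sketch and not a confirmed diagnosis: the frames show `self.policy_class(` going through torch's `_call_impl` into `forward`, which is what would happen if `policy` were an already-constructed policy object and not a class. Since the error also appeared with 'Pendulum-v1', the spaces above may not be involved at all. The stand-in classes below (`FakeModule`, `FakePolicy` and `FakeAlgorithm` are hypothetical and use no stable_baselines3 or torch code) reproduce the same kind of TypeError in plain Python.

```python
# Hypothetical stand-ins only; nothing here is real stable_baselines3 or torch code.


class FakeModule:
    """Mimics torch.nn.Module: calling an instance dispatches to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class FakePolicy(FakeModule):
    """Stand-in for a policy whose forward() takes one or two arguments."""

    def __init__(self, observation_space, action_space, lr_schedule):
        self.observation_space = observation_space
        self.action_space = action_space
        self.lr_schedule = lr_schedule

    def forward(self, observation, deterministic=False):
        return observation


class FakeAlgorithm:
    """Stand-in for an algorithm that builds its own policy from a class."""

    def __init__(self, policy, spaces):
        self.policy_class = policy
        obs_space, act_space = spaces
        # Mirrors the `self.policy = self.policy_class(` frame in the traceback.
        self.policy = self.policy_class(obs_space, act_space, lambda _: 1e-3)


spaces = ("obs_space", "action_space")

# Passing the class: the algorithm constructs the policy itself.
ok = FakeAlgorithm(FakePolicy, spaces)

# Passing an instance: the "constructor call" becomes a forward() call
# with three positional arguments, which forward() does not accept.
try:
    FakeAlgorithm(FakePolicy(*spaces, None), spaces)
    message = ""
except TypeError as err:
    message = str(err)

print(message)
```

If that is what is happening here, the thing to try would be passing the policy class itself, or the string "MultiInputPolicy", as the `policy` argument to `TD3`, and moving any extra constructor arguments into `policy_kwargs`.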