Handling batch_size in a TorchRL environment

I am struggling with batch_size in a TorchRL environment. I created an environment that passes the check_env test. Here is an example of the environment, with the _step method simplified but with the relevant shape operations kept intact.

from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
# make_composite_from_td is the helper defined in the TorchRL pendulum
# tutorial, not a library import.

class FlyEnv(EnvBase):
  
    batch_locked = False 
    
    # nb_joints and time_step are module-level constants in the original script
    def __init__(self, nb_joints=nb_joints, td_params=None, seed=None, device="cpu"):
        self.nb_joints = nb_joints
        self.dt = time_step
        if td_params is None:
            td_params = self.gen_params()

        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    def gen_params(self, batch_size=None) -> TensorDictBase:
        if batch_size is None:
            batch_size = []

        obs_min = torch.concat(
            (-torch.pi*torch.ones(self.nb_joints),
             -10*torch.ones(self.nb_joints)),
             dim=-1)

        td = TensorDict(
            {
                "params": TensorDict(
                    {
                        "obs_min":obs_min,
                        "obs_max":-obs_min,
                        "act_min":obs_min,
                        "act_max":-obs_in
                    },
                    [],
                )
            },
            [],
        )

        if batch_size:
            td = td.expand(batch_size).contiguous()
        return td
    
    def _make_spec(self, td_params):

        self.observation_spec = CompositeSpec(
            observation=BoundedTensorSpec(
                low=td_params["params", "obs_min"],
                high=td_params["params", "obs_max"],
                shape=(2*self.nb_joints,),
                dtype=torch.float32,
            ),
            params=make_composite_from_td(td_params["params"]),
            shape=td_params.shape,
        )

        # since the environment is stateless, we expect the previous output as input.
        # For this, ``EnvBase`` expects some state_spec to be available
        self.state_spec = self.observation_spec.clone()

        # the action spec will be automatically wrapped in input_spec when
        # `self.action_spec = spec` is called
        self.action_spec = BoundedTensorSpec(
            low=td_params["params", "act_min"],
            high=td_params["params", "act_max"],
            shape=(*td_params.shape, 2*self.nb_joints),
            dtype=torch.float32,
        )
        
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))

    def _step(self, tensordict):
    
        action = tensordict["action"]
        new_obs = action

        reward = torch.sum(action, dim=-1)
        reward = reward.view(*tensordict.shape, 1)
        done = torch.zeros_like(reward, dtype=torch.bool)

        out = TensorDict(
            {
                "observation": new_obs,
                "params": tensordict["params"],
                "reward": reward,
                "done": done,
            },
            tensordict.shape,
        )
        return out

    def _reset(self, tensordict):
        if tensordict is None or tensordict.is_empty():
            tensordict = self.gen_params(batch_size=self.batch_size)

        obs_max = tensordict["params", "obs_max"]
        obs_min = tensordict["params", "obs_min"]

        size = (*tensordict.shape, 2*self.nb_joints)

        # for non batch-locked environments, the input ``tensordict`` shape dictates the number
        # of simulators run simultaneously. In other contexts, the initial
        # random state's shape will depend upon the environment batch-size instead.
        obs = (
            torch.rand(size, generator=self.rng, device=self.device)
            * (obs_max - obs_min)
            + obs_min
        )

        out = TensorDict(
            {
                "observation": obs,
                "params": tensordict["params"],
            },
            batch_size=tensordict.shape,
            device="cpu",
        )
        return out
        
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

Then, in the TorchRL pipeline, I use this environment with an actor that takes inputs of shape [*B, F] and outputs [*B, F], where B is the batch shape. However, with a SyncDataCollector, the line for i, tensordict_data in enumerate(collector): produces the following error:

  File "c:\Users\samje\Documents\EPFL\Cours\Semester project 2\Code\RL_copy.py", line 160, in <module>
    for i, tensordict_data in enumerate(collector):
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\collectors\collectors.py", line 952, in iterator
    tensordict_out = self.rollout()
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\_utils.py", line 469, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\collectors\collectors.py", line 1069, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\common.py", line 2576, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\common.py", line 1409, in step
    next_tensordict = self._step(tensordict)
  File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\transforms\transforms.py", line 738, in _step
    next_tensordict = self.base_env._step(tensordict_in)
  File "c:\Users\samje\Documents\EPFL\Cours\Semester project 2\Code\environment_copy_copy.py", line 168, in _step
    reward = reward.view(*tensordict.shape, 1)
RuntimeError: shape '[1]' is invalid for input of size 4

This is due to the fact that the action (the output of the policy network) has shape [B, F] whereas tensordict.shape is torch.Size([]), so the reward comes out with shape [B] and cannot be viewed as [*tensordict.shape, 1].
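
For illustration, here is a minimal standalone reproduction of the failing view, with the shapes taken from the traceback above:

    import torch

    td_shape = torch.Size([])  # the env was created with batch_size=[]
    reward = torch.zeros(4)    # ...but the policy produced a batch of 4 rewards
    reward.view(*td_shape, 1)  # RuntimeError: shape '[1]' is invalid for input of size 4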

I tried to manually set the batch_size in the environment's various tensordicts/variables, which solves this issue but fails later in GAE (from torchrl.objectives.value): the latter processes tensordicts of shape [B, T, F], and the new environment does not handle the batch size [B, T]...
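
For context, a collector stacks T time steps for each of B environments, so modules such as GAE see tensordicts with two leading batch dimensions. A minimal sketch of that layout (B, T and F are placeholder sizes):

    import torch
    from tensordict import TensorDict

    B, T, F = 5, 4, 6  # placeholder sizes: envs, time steps, features
    data = TensorDict({"observation": torch.zeros(B, T, F)}, batch_size=[B, T])
    print(data.batch_size)  # torch.Size([5, 4])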

My question: is there a simple way to handle these batch sizes in a TorchRL environment?

Please let me know if anything is missing; I tried to keep this minimal, as it is already long enough.

1 Answer

Answered by vmoens:

A couple of things while looking at the code:

  • All your specs must have a leading dim that corresponds to the batch size. It's restrictive, but that's the price to pay for clarity: otherwise we would rely on broadcasting, and it'd be error-prone (it used not to be the case, and the code base was very messy!). So, for instance, you should have
        self.observation_spec = CompositeSpec(
            observation=BoundedTensorSpec(
                low=td_params["params", "obs_min"],
                high=td_params["params", "obs_max"],
                shape=(*td_params.shape, 2 * self.nb_joints,),
                dtype=torch.float32,
            ),
            params=make_composite_from_td(td_params["params"]),
            shape=td_params.shape,
        )
  • I would pass the batch size in the constructor if that makes sense (a usage sketch follows the snippet below):
    def __init__(self, nb_joints=nb_joints, td_params=None, seed=None, batch_size=None, device="cpu"):
        self.nb_joints = nb_joints
        self.dt = time_step
        if td_params is None:
            td_params = self.gen_params(batch_size)

        super().__init__(device=device, batch_size=batch_size)
        self._make_spec(td_params)

        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
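
A usage sketch, assuming the corrected constructor above (nb_joints=2 and the batch sizes are placeholder values):

    # unbatched: everything keeps an empty batch size
    env = FlyEnv(nb_joints=2)
    print(env.reset().batch_size)     # expected: torch.Size([])

    # batched: specs and tensordicts carry the leading batch dim
    env = FlyEnv(nb_joints=2, batch_size=[5])
    print(env.reset().batch_size)     # expected: torch.Size([5])
    print(env.rollout(4).batch_size)  # expected: torch.Size([5, 4])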

If you want a dynamic batch size, it is supported (I see you set batch_locked = False, which is what is needed there).

  • With the above corrections, I see that these two things work, which is already cool:
# Without batch
env = FlyEnv()
env.reset()
print(env.rollout(4))

# With batch
tdreset = env.reset().expand(5)
env.rand_step(tdreset)
print(tdreset)

For the collector, we need to be able to tell the environment what the batch size is going to be. It's not an easy task, and we should find a way to streamline it... To put it gently, compatibility between collectors and envs that are not batch-locked is limited. env.rollout with a tensordict=smth argument should work, though (see the sketch below)!
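
A sketch of that workaround, assuming the unbatched env above (the expand size 5 and the 4 steps are arbitrary):

    env = FlyEnv()                           # built with batch_size=[]
    td_init = env.reset().expand(5).clone()  # request 5 parallel simulators
    out = env.rollout(max_steps=4, tensordict=td_init, auto_reset=False)
    print(out.batch_size)                    # expected: torch.Size([5, 4])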

I made a PR to solve this issue: https://github.com/pytorch/rl/pull/2030