Stable-Baselines3 type error in _predict with custom environment & policy


I'm integrating a custom environment and a custom policy into Stable-Baselines3 (SB3), and I ran into an issue while setting up the _predict functionality. Calling _predict manually with a standard Python dict observation works fine. However, SB3 requires the environment's spaces to be defined with gym.spaces, and I suspect SB3 uses those spaces for something internally. So my environment returns observations as a plain Python dict and defines the observation and action spaces with gym.spaces.

When I train with SB3's Proximal Policy Optimization (PPO), though, an error is raised: it looks like PPO passes the <class 'gymnasium.spaces.dict.Dict'> space itself to the policy instead of the actual dict observation. Perhaps I need to embed the observation into the observation space, but I don't think that is possible. Should I be using an SB3 features extractor in some fashion? Or is there something I have forgotten to define that would ensure the actual observation is used rather than the gym.space?
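For context, one alternative I considered is dropping the custom BasePolicy subclass and instead plugging a custom features extractor into SB3's built-in "MultiInputPolicy". Below is a rough, untested sketch of what I had in mind; CustomExtractor is just a placeholder name of mine and features_dim=16 is an arbitrary choice, and I haven't verified that this avoids the type issue I'm asking about:

import numpy as np
import gymnasium as gym
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 16):
        super().__init__(observation_space, features_dim)
        n_input = int(np.prod(observation_space["layout"].shape))
        # Flatten the 'layout' entry of the dict observation and project it down.
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(n_input, features_dim), nn.ReLU())

    def forward(self, observations):
        # Here 'observations' should be a dict of batched tensors, not a gym.spaces.Dict.
        return self.net(observations["layout"].float())


# model = PPO("MultiInputPolicy", env,
#             policy_kwargs=dict(features_extractor_class=CustomExtractor),
#             verbose=1)

My full minimal example and the error it produces are below.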

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy


class CustomEnv(gym.Env):

    def __init__(self, nRows, nCols):
        super().__init__()

        # init
        self.nRows = nRows
        self.nCols = nCols
        self.iter = 0
        self.done = False
        self.truncated = False

        # action space
        self.action_space = gym.spaces.Discrete(self.nRows * self.nCols)

        # observation space
        self.observation_space = gym.spaces.Dict({
            'layout': gym.spaces.Box(low=0, high=255, shape=(self.nRows, self.nCols), dtype=np.uint8),
            'mask': gym.spaces.Box(low=0, high=255, shape=(self.nRows * self.nCols,), dtype=np.uint8)
        })

        # actual observation
        self.observation = {'layout': np.zeros((self.nRows, self.nCols), dtype=np.uint8),
                            'mask': np.zeros(self.nRows * self.nCols, dtype=np.uint8)}

    def step(self, action):
        self.iter += 1
        reward = 0
        print("action: ", action)
        layout = self.observation["layout"].flatten()
        mask = self.observation["mask"]

        if layout[action] == 0:
            layout[action] = 1
            mask[action] = 1
            reward = 1
        else:
            reward = -1

        self.observation = {'layout': np.reshape(layout, (self.nRows, self.nCols)),
                            'mask': mask}

        if self.iter > self.nRows * self.nCols:
            self.done = True
            self.truncated = True

        info = {}

        return self.observation, reward, self.done, self.truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        # reset
        self.iter = 0
        self.done = False
        self.truncated = False

        self.observation = {'layout': np.zeros((self.nRows, self.nCols), dtype=np.uint8),
                            'mask': np.zeros(self.nRows * self.nCols, dtype=np.uint8)}
        info = {}

        return self.observation, info

    def render(self):
        pass

    def close(self):
        pass


class CustomPolicy(BasePolicy):
    def __init__(self, observation_space, action_space):
        super(CustomPolicy, self).__init__(observation_space, action_space)

        input_size = np.shape(observation_space["layout"])[0] * np.shape(observation_space["layout"])[1]
        output_size = action_space.n

        self.l1 = nn.Linear(input_size, 5)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(5, output_size)

    def forward(self, obs, r=None, deterministic: bool = False, **kwargs):
        print(type(obs))
        x = torch.Tensor(obs["layout"].flatten())
        output = self.l1(x)
        output = self.relu(output)
        output = self.l2(output)
        return output

    def _predict(self, obs, deterministic: bool = False):
        # Forward pass through the network
        action_logits = self.forward(obs, deterministic=deterministic)
        if deterministic:
            # For deterministic action selection, take the action with maximum probability
            action = torch.argmax(action_logits)
        else:
            # For stochastic action selection, sample from the action distribution
            action_probs = F.softmax(action_logits, dim=0)
            action_dist = torch.distributions.Categorical(probs=action_probs)
            action = action_dist.sample()
        return action, None  # Don't need log-prob.


# Create an instance of your custom environment
env = CustomEnv(3, 3)
custom_policy = CustomPolicy(env.observation_space, env.action_space)

# Running a small test
action = env.action_space.sample()
observation, reward, done, truncated, _ = env.step(action)

action, _ = custom_policy._predict(observation, False)
print(action)

action, _ = custom_policy._predict(observation, True)
print(action)

# train PPO with your custom environment and custom policy
model = PPO(policy=custom_policy, env=env, verbose=1)
model.learn(total_timesteps=1000)

# eval
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")


-------------THE ERROR-------------------
    x = torch.Tensor(obs["layout"].flatten())
AttributeError: 'Box' object has no attribute 'flatten'
