I'm in the process of integrating a custom environment and policy into Stable-Baselines3 (SB3), and I ran into an issue while setting up the _predict functionality. I can call _predict manually with a standard Python dict as the observation, but SB3 requires the environment's observation and action spaces to be defined with gym.spaces, and I suspect SB3 uses those spaces internally. So my environment returns observations as a plain Python dictionary, while the spaces themselves are declared with gym.spaces. However, when I train with SB3's Proximal Policy Optimization (PPO), an error arises: PPO appears to pass the <class 'gymnasium.spaces.dict.Dict'> object itself instead of an actual dict observation. Perhaps I need to embed the observation into the observation space, but I don't think that's possible. Should I be using the SB3 features extractor in some fashion? Or is there something I have forgotten to define so that the actual observation is used rather than the gym.space?
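To illustrate what I mean by the space being passed instead of the observation: indexing the Dict space returns a Box (which has no .flatten()), while only a sampled or actual observation contains NumPy arrays. A small standalone snippet (not part of my project, it just mirrors my spaces):

import numpy as np
import gymnasium as gym

space = gym.spaces.Dict({
    'layout': gym.spaces.Box(low=0, high=255, shape=(3, 3), dtype=np.uint8),
    'mask': gym.spaces.Box(low=0, high=255, shape=(9,), dtype=np.uint8),
})

print(type(space['layout']))           # <class 'gymnasium.spaces.box.Box'> -> no .flatten()
print(type(space.sample()['layout']))  # <class 'numpy.ndarray'>            -> has .flatten()

Here is my full code: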
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
class CustomEnv(gym.Env):
    def __init__(self, nRows, nCols):
        super().__init__()
        # init
        self.nRows = nRows
        self.nCols = nCols
        self.iter = 0
        self.done = False
        self.truncated = False
        # action space
        self.action_space = gym.spaces.Discrete(self.nRows * self.nCols)
        # observation space
        self.observation_space = gym.spaces.Dict({
            'layout': gym.spaces.Box(low=0, high=255, shape=(self.nRows, self.nCols), dtype=np.uint8),
            'mask': gym.spaces.Box(low=0, high=255, shape=(self.nRows * self.nCols,), dtype=np.uint8)
        })
        # actual observation
        self.observation = {'layout': np.zeros((self.nRows, self.nCols), dtype=np.uint8),
                            'mask': np.zeros(self.nRows * self.nCols, dtype=np.uint8)}

    def step(self, action):
        self.iter += 1
        reward = 0
        print("action: ", action)
        layout = self.observation["layout"].flatten()
        mask = self.observation["mask"]
        if layout[action] == 0:
            layout[action] = 1
            mask[action] = 1
            reward = 1
        else:
            reward = -1
        self.observation = {'layout': np.reshape(layout, (self.nRows, self.nCols)),
                            'mask': mask}
        if self.iter > self.nRows * self.nCols:
            self.done = True
            self.truncated = True
        info = {}
        return self.observation, reward, self.done, self.truncated, info

    def reset(self, seed=None, options=None):
        # reset
        self.iter = 0
        self.done = False
        self.truncated = False
        self.observation = {'layout': np.zeros((self.nRows, self.nCols), dtype=np.uint8),
                            'mask': np.zeros(self.nRows * self.nCols, dtype=np.uint8)}
        info = {}
        return self.observation, info

    def render(self):
        pass

    def close(self):
        pass
class CustomPolicy(BasePolicy):
    def __init__(self, observation_space, action_space):
        super().__init__(observation_space, action_space)
        # small MLP over the flattened layout
        input_size = np.shape(observation_space["layout"])[0] * np.shape(observation_space["layout"])[1]
        output_size = action_space.n
        self.l1 = nn.Linear(input_size, 5)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(5, output_size)

    def forward(self, obs, r=None, deterministic: bool = False, **kwargs):
        print(type(obs))
        x = torch.Tensor(obs["layout"].flatten())
        output = self.l1(x)
        output = self.relu(output)
        output = self.l2(output)
        return output

    def _predict(self, obs, deterministic: bool = False):
        # Forward pass through the network
        action_logits = self.forward(obs, deterministic=deterministic)
        if deterministic:
            # For deterministic action selection, take the action with maximum probability
            action = torch.argmax(action_logits)
        else:
            # For stochastic action selection, sample from the action distribution
            action_probs = F.softmax(action_logits, dim=0)
            action_dist = torch.distributions.Categorical(probs=action_probs)
            action = action_dist.sample()
        return action, None  # Don't need log-prob.
# Create an instance of your custom environment
env = CustomEnv(3, 3)
custom_policy = CustomPolicy(env.observation_space, env.action_space)
# Running a small test
action = env.action_space.sample()
observation, reward, done, truncated, _ = env.step(action)
action, _ = custom_policy._predict(observation, False)
print(action)
action, _ = custom_policy._predict(observation, True)
print(action)
# train PPO with your custom environment and custom policy
model = PPO(policy=custom_policy, env=env, verbose=1).learn(total_timesteps=1000)
# eval
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")
-------------THE ERROR-------------------
x = torch.Tensor(obs["layout"].flatten())
AttributeError: 'Box' object has no attribute flatten
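In case it clarifies what I meant about the features extractor, this is roughly the direction I was considering: a custom feature extractor for the Dict observation, plugged into the built-in MultiInputPolicy instead of my CustomPolicy. It follows the custom feature extractor pattern from the SB3 docs, but it is an untested sketch; DictExtractor and the 32-dim feature size are just placeholders I picked.

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class DictExtractor(BaseFeaturesExtractor):
    """Flatten and concatenate the 'layout' and 'mask' entries of the Dict observation."""
    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 32):
        super().__init__(observation_space, features_dim)
        n_layout = int(np.prod(observation_space["layout"].shape))
        n_mask = int(np.prod(observation_space["mask"].shape))
        self.net = nn.Sequential(nn.Linear(n_layout + n_mask, features_dim), nn.ReLU())

    def forward(self, observations) -> torch.Tensor:
        # Here observations is a dict of batched tensors, not a gym.spaces.Dict
        layout = torch.flatten(observations["layout"].float(), start_dim=1)
        mask = observations["mask"].float()
        return self.net(torch.cat([layout, mask], dim=1))

env = CustomEnv(3, 3)
model = PPO(
    "MultiInputPolicy",  # SB3's built-in policy for Dict observation spaces
    env,
    policy_kwargs=dict(
        features_extractor_class=DictExtractor,
        features_extractor_kwargs=dict(features_dim=32),
    ),
    verbose=1,
)
model.learn(total_timesteps=1000)

Would this be the intended way to handle a Dict observation, or is there a way to make my CustomPolicy receive the actual observation dict?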