Difficulty with Stable Baselines VecFrameStack function for observation space pre-processing


I am very new to reinforcement learning.

I have been following a couple of tutorials but have hit a snag that I have not been able to resolve after a couple of hours.

So far I have:

- Imported the game and set up the environment
- Pre-processed the action and observation space (grayscale, resize, etc.)
- Attempted hyperparameter tuning with SB3 and Optuna

I am specifically having difficulty with the VecFrameStack function, which produces:

AssertionError: VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces

As far as I can tell, my observation space has been set up as a gym.spaces.Box, which is exactly what the assertion asks for. I am very confused!
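
For reference, this is roughly how I checked the space (a minimal snippet using the StreetFighter class from the full listing below):

env = StreetFighter()
# Both lines report gym's Box, e.g. <class 'gym.spaces.box.Box'> and Box(0, 255, (84, 84, 1), uint8)
print(type(env.observation_space))
print(env.observation_space)
env.close()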

Any help would be really appreciated!

Code samples below:

import retro
import gym
from gym import spaces
import time
import os
from gym import Env
from gym.spaces import MultiBinary, Box
import numpy as np
import cv2
from matplotlib import pyplot as plt
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor 
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

class StreetFighter(Env): 
    def __init__(self):
        super().__init__()
        self.observation_space = Box(low = 0, high = 255, shape = (84,84,1), dtype = np.uint8)
        self.action_space = MultiBinary(12)
        self.game = retro.make(game = 'StreetFighterIISpecialChampionEdition-Genesis', use_restricted_actions = retro.Actions.FILTERED)
    
    def reset(self): 
        obs = self.game.reset()
        obs = self.preprocess(obs)
        self.previous_frame = obs
        self.score = 0
        return obs
    
    def preprocess(self, observation):
        gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        resize = cv2.resize(gray, (84,84), interpolation = cv2.INTER_CUBIC)
        channels = np.reshape(resize, (84,84,1))
        return channels
    
    def step(self, action):
        obs, reward, done, info = self.game.step(action)
        obs = self.preprocess(obs)
        
        frame_delta = obs - self.previous_frame #Subtract previous from current
        self.previous_frame = obs #Update previous frame prior to taking next step

        reward = info['score'] - self.score
        self.score = info['score'] #Update current score prior to taking next step

        return frame_delta, reward, done, info
    
    def render(self, *args, **kwargs):
        self.game.render()

    def close(self):
        self.game.close()

LOG_DIR = './logs/'
OPT_DIR = './opt/'

def optimize_ppo(trial): 
    return {
        'n_steps':trial.suggest_int('n_steps', 2048, 8192),
        'gamma':trial.suggest_loguniform('gamma', 0.8, 0.9999),
        'learning_rate':trial.suggest_loguniform('learning_rate', 1e-5, 1e-4),
        'clip_range':trial.suggest_uniform('clip_range', 0.1, 0.4),
        'gae_lambda':trial.suggest_uniform('gae_lambda', 0.8, 0.99)
    }

def optimize_agent(trial):
    try:
        model_params = optimize_ppo(trial) 

        env = StreetFighter()
        env = Monitor(env, LOG_DIR)
        env = DummyVecEnv([lambda: env])
        env = VecFrameStack(env, 4, channels_order="last")

        model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
        model.learn(total_timesteps=300000) # Timesteps per trial - can adjust depending on how long each trial takes

        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
        env.close()
        
        SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
        model.save(SAVE_PATH)

        return mean_reward

    except Exception as e:
        print(e)
        return -1000 # Penalise failed trials so the study can continue

study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=100, n_jobs=1) # Number of trials to run

1 Answer

Answered by Devils Den

I ran into the same error a while ago and managed to figure it out. It seems to be a compatibility issue with the stable-baselines3 dependency versions. Pinning the package to an older release fixed it for me:

pip install stable-baselines3==1.3.0

This worked for me, hope it works for you too.
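
If it still fails after pinning, it may be worth confirming which versions are actually being imported; presumably the isinstance check inside VecFrameStack is seeing a Box class from a different gym install than the one the environment was built against. A quick check (assuming pip-installed packages):

import gym
import stable_baselines3 as sb3

# Print the versions gym and SB3 resolve to at import time
print("gym:", gym.__version__)
print("stable-baselines3:", sb3.__version__)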