import pandas as pd
import psycopg2
import numpy as np
import gymnasium
from gymnasium import spaces
import stable_baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

print(f"Stable-Baselines3 version: {stable_baselines3.__version__}")
class CustomEnv(gymnasium.Env):
    def __init__(self, dataset, columns):
        super().__init__()
        if dataset is None:
            raise ValueError("The dataset must be provided.")
        if columns is None:
            raise ValueError("The columns must be provided.")
        if not isinstance(dataset, pd.DataFrame):
            raise ValueError("The dataset must be a DataFrame.")
        self.dataset = dataset
        self.columns = columns
        self.initial_balance = 10000.0  # Initial balance for trading
        self.current_step = 0  # Current step in the dataset
        self.balance = self.initial_balance
        self.holding = 0  # Number of units of the asset held by the agent
        # Use "MPN5P" as the price column
        self.price_column = "MPN5P"
        self.current_price = self.dataset[self.price_column].iloc[self.current_step]  # Current price of the asset
        # Actions: 0 = buy, 1 = sell, 2 = hold
        self.action_space = spaces.Discrete(3)
        # Observation: [raw price, normalized balance]; the price can exceed 1, so the upper bound is left open
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(2,), dtype=np.float32)
    def step(self, action):
        # Execute the action and update the environment state
        self._take_action(action)
        self.current_step += 1
        if self.current_step >= len(self.dataset):
            self.current_step = 0  # Wrap around if the end of the dataset is reached
        self.current_price = self.dataset[self.price_column].iloc[self.current_step]  # Update current price
        # Calculate the reward based on the new state and the agent's action
        reward = self._compute_reward(action)
        terminated = False  # For simplicity, episodes never terminate
        truncated = False
        info = {}  # Additional information (optional)
        # Gymnasium's step API returns a 5-tuple
        return self._get_observation(), reward, terminated, truncated, info
    def reset(self, seed=None, options=None):
        # Reset the environment state (Gymnasium passes seed and options keyword arguments)
        super().reset(seed=seed)
        self.current_step = 0
        self.balance = self.initial_balance
        self.holding = 0
        self.current_price = self.dataset[self.price_column].iloc[self.current_step]
        observation = self._get_observation()  # NumPy array with shape (2,)
        info = {}  # Empty dictionary for info (if not used)
        print(f"Observation: {observation}")
        print(f"Info: {info}")
        return observation, info  # Return both the observation and the reset info
    def _take_action(self, action):
        # Execute the action (buy, sell, or hold)
        if action == 0:  # Buy one unit if the balance allows it
            if self.balance >= self.current_price:
                self.holding += 1
                self.balance -= self.current_price
        elif action == 1:  # Sell one unit if any are held
            if self.holding > 0:
                self.holding -= 1
                self.balance += self.current_price
        # action == 2 is hold: do nothing
    def _compute_reward(self, action):
        # Define reward computation (e.g., based on profit/loss)
        reward = 0
        if action == 0:  # Buy
            reward -= self.current_price
        elif action == 1:  # Sell
            reward += self.current_price
        return reward
    def _get_observation(self):
        current_price = self.dataset[self.price_column].iloc[self.current_step]
        normalized_balance = self.balance / self.initial_balance
        observation = np.array([current_price, normalized_balance], dtype=np.float32)
        return observation
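
Before training, I run a quick sanity check with Stable-Baselines3's env_checker. This is a minimal sketch: the DataFrame below is a made-up placeholder with only the "MPN5P" column, standing in for my real data.

# Sanity check with a placeholder DataFrame (my real data has more columns)
from stable_baselines3.common.env_checker import check_env

df = pd.DataFrame({"MPN5P": np.random.uniform(50.0, 150.0, size=500)})
env = CustomEnv(df, columns=["MPN5P"])
check_env(env)  # Raises if the environment does not follow the Gymnasium API
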
This is my environment initialization code. I never use MultiDiscrete, yet I am getting a NameError from somewhere; I suspect it comes from my stable_baselines3 2.2.1 install.
I installed a newer version of stable_baselines3, but the code keeps importing the older one for some reason.
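
To confirm which installation is actually being imported, I print the interpreter and package paths (nothing here is specific to my setup; it is just a generic check):

import sys
import stable_baselines3
import gymnasium

print(sys.executable)                 # The Python interpreter that is running
print(stable_baselines3.__version__)  # The version that actually got imported
print(stable_baselines3.__file__)     # Where that package lives on disk
print(gymnasium.__version__)
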
I expected the newer version to get rid of this glitch and let me train the RL model with PPO smoothly. I even edited dummy_vec_env.py to expect a single environment instead of several, since I am only using one vectorized environment; the correct wrapping I am aiming for is sketched below.
Earlier I was thrown an error telling me to use a vectorized environment, which is why I wrapped my environment in DummyVecEnv.
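
For reference, this is roughly how I build the vectorized environment; as far as I understand, DummyVecEnv expects a list of functions that each construct one environment, so no edits to dummy_vec_env.py should be needed (df below is a placeholder for my actual DataFrame):

def make_env():
    # df is a placeholder for my actual DataFrame containing the "MPN5P" column
    return CustomEnv(df, columns=["MPN5P"])

vec_env = DummyVecEnv([make_env])  # One callable -> a vectorized env with a single copy
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10_000)

My understanding is that recent Stable-Baselines3 releases also wrap a plain Gymnasium env in a DummyVecEnv automatically, so passing CustomEnv(df, columns=["MPN5P"]) directly to PPO may work as well.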