RuntimeError: Input type (unsigned char) and bias type (float) should be the same


I'm using PyTorch, CUDA, and PyCharm to program a DQN agent for Gymnasium's Tetris environment. The error appears as soon as my agent tries to decide on an action; something goes wrong with this line:

action_values = self.net(state, model="current")

Here's my code for my DQN:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage



class QNetwork(nn.Module):
    def __init__(self, state_shape, output_size):
        super(QNetwork, self).__init__()
        stacked_frames, h, w = state_shape
        self.current = self.__build_network(stacked_frames, output_size)

        self.target = self.__build_network(stacked_frames, output_size)
        self.target.load_state_dict(self.current.state_dict())
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == 'current':
            return self.current(input)
        elif model == 'target':
            return self.target(input)

    def __build_network(self, stacked_frames, output_size):
        return nn.Sequential(
            nn.Conv2d(stacked_frames, 64, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Linear(256, output_size),
        ).float()

class DQNAgent:
    def __init__(self, state_shape, action_size, save_dir):
        self.state_shape = state_shape
        self.action_size = action_size
        self.save_dir = save_dir

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = QNetwork(self.state_shape, self.action_size).float()
        self.net = self.net.to(self.device)

        self.exploration_rate = 1
        self.gamma = .9
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0
        self.save_every = 5e5
        self.burnin = 1e4  # min. experiences before training
        self.learn_every = 3  # no. of experiences between updates to Q_online
        self.sync_every = 1e4  # no. of experiences between Q_target & Q_online sync

        self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device('cpu')))
        self.batch_size = 64

        self.optimizer = optim.Adam(self.net.parameters(), lr=0.00025)
        self.loss_fn = nn.SmoothL1Loss()

    def choose_action(self, state):
        if np.random.rand() < self.exploration_rate:
            action = np.random.randint(self.action_size)

        else:
            state = state[0].__array__() if isinstance(state, tuple) else state.__array__()
            state = torch.tensor(state, device=self.device).unsqueeze(0)
            action_values = self.net(state, model="current")
            action = torch.argmax(action_values).item()
            #action = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action

    def cache(self, state, next_state, action, reward, done):
        def first_if_tuple(x):
            return x[0] if isinstance(x, tuple) else x

        state = first_if_tuple(state).__array__()
        next_state = first_if_tuple(next_state).__array__()

        state = torch.tensor(state)
        next_state = torch.tensor(next_state)
        action = torch.tensor([action])
        reward = torch.tensor([reward])
        done = torch.tensor([done])

        # self.memory.append((state, next_state, action, reward, done,))
        self.memory.add(
            TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done},
                       batch_size=[]))

    def recall(self):
        batch = self.memory.sample(self.batch_size).to(self.device)
        state, next_state, action, reward, done = (batch.get(key) for key in ('state', 'next_state', 'action', 'reward', 'done'))
        return state, next_state, action, reward, done

    def td_estimate(self, state, action):
        current_q = self.net(state, model="current")[
            np.arange(0, self.batch_size), action
        ]
        return current_q

    def td_target(self, reward, next_state, done):
        next_state_Q = self.net(next_state, model="current")
        best_action = torch.argmax(next_state_Q)
        next_Q = self.net(next_state, model="target")[np.arange(0, self.batch_size), best_action]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_actual_Q(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.current.state_dict())

    def save_model(self):
        save_path = (
            self.save_dir / f"tetris_net_{self.curr_step//self.save_every}.chkpt"
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate), save_path
        )
        print(f"Model saved to {save_path}")

    def learn(self):
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save_model()

        if self.curr_step < self.burnin:
            return None, None

        if self.curr_step % self.learn_every != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        # Backpropagate loss through Q_online
        loss = self.update_actual_Q(td_est, td_tgt)

        return (td_est.mean().item(), loss)

And here's the code that runs the agent in the Tetris environment:

import datetime
from pathlib import Path

import gymnasium as gym
import torch.cuda
from gymnasium.wrappers import FrameStack, GrayScaleObservation
from DeepQAgentTest import DQNAgent, MetricLogger
from collections import namedtuple


env = gym.make("ALE/Tetris-v5", obs_type="rgb")
env = GrayScaleObservation(env)
env = FrameStack(env, 5)

observation_space = env.observation_space
action_space = env.action_space
num_of_actions = env.action_space.n
print(num_of_actions)
print(observation_space.shape)

env.reset()


use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")

save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

agent = DQNAgent(observation_space.shape, num_of_actions, save_dir)


episodes = 100000
for e in range(episodes):

    state = env.reset()

    while True:

        # Run agent on the state
        action = agent.choose_action(state)

        # Agent performs action
        next_state, reward, done, trunc, info = env.step(action)

        # Remember
        agent.cache(state, next_state, action, reward, done)

        # Learn
        q, loss = agent.learn()

        # Logging
        #logger.log_step(reward, loss, q)

        # Update state
        state = next_state

        # Check if end of game
        if done:
            break

The error message:

Traceback (most recent call last):
  File "C:\Users\flipp\PycharmProjects\RLFinal\Tetris Test.py", line 49, in <module>
    action = agent.choose_action(state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\flipp\PycharmProjects\RLFinal\DeepQAgentTest.py", line 66, in choose_action
    action_values = self.net(state, model="current")
RuntimeError: Input type (unsigned char) and bias type (float) should be the same

I'm on Python 3.11, torch 2.11, and Gymnasium 0.29.1.

1 Answer

Answered by vmoens:

The error comes from the fact that your observation has a uint8 dtype, not float.

Try converting your state to float, e.g.

state = state / 255

which will also normalize it to the [0, 1] range (better than raw [0, 255] integers).
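
For completeness, here is a minimal sketch of where that conversion could go in the posted agent. The helper name to_float_tensor is not from the original code, just an illustration: it casts the stacked uint8 frames to float32 and scales them to [0, 1] before they reach the first Conv2d. Note that dividing the NumPy array by 255 on its own would produce float64, which the float32 convolution weights also reject, so an explicit float32 cast is the safer route.

import numpy as np
import torch


def to_float_tensor(state, device):
    # env.reset() returns (obs, info); FrameStack observations expose
    # __array__(), which yields a uint8 ndarray of shape (stacked_frames, H, W).
    state = state[0].__array__() if isinstance(state, tuple) else state.__array__()
    # Cast explicitly to float32 and scale to [0, 1].
    state = np.asarray(state, dtype=np.float32) / 255.0
    return torch.tensor(state, device=device).unsqueeze(0)

# In choose_action, the exploitation branch would then read roughly:
#     state = to_float_tensor(state, self.device)
#     action_values = self.net(state, model="current")
# and the same cast would be applied to state and next_state in cache, so the
# batches sampled in recall arrive as float in td_estimate and td_target.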