How to not use too much RAM with deep q learning?

581 views Asked by At

I have run this code but after 10 seconds, the session crashed because the code uses too much ram. What is the problem that causes this issue and how to fix it?

My assumption is that the class PreProcess is taken all RAMs or it might be a stochastic gradient descent algorithm? I literally have no idea.

[NOTE: I used google collab with Keras and tensorflow]

Here is the code:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import random
import gym
import time
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.animation import PillowWriter

#input_shape = (4,80,80)
class replay_buffer:
    def __init__(self, mem_size=25000, input_shape=(210,160,3)):
        self.mem_size = mem_size
        self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.state_memory = np.zeros((self.mem_size,*input_shape), dtype=np.float32)
        self.next_state_memory = np.zeros((self.mem_size,*input_shape), dtype=np.float32)
        self.terminal_state_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.mem = 0


    def store_transition(self, state, next_state, action, reward, terminal_state):
        index = self.mem % self.mem_size 
        print("index: {}".format(index))
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.state_memory[index] = state
        self.next_state_memory[index] = next_state
        self.terminal_state_memory[index] = terminal_state
        self.mem += 1

    def sample_buffer(self):
        indexs = np.random.choice(range(len(self.terminal_state_memory)), size = 32)
        action_sample = np.array(self.action_memory[x] for x in indexs)
        reward_sample = np.array(self.reward_memory[x] for x in indexs)
        state_sample = np.array(self.state_memory[x] for x in indexs)
        next_state_sample = np.array(self.next_state_memory[x] for x in indexs)
        terminal_state_sample = np.array(self.terminal_state_memory[x] for x in indexs)

        return action_sample, reward_sample, state_sample, next_state_sample, terminal_state_sample

class dqn_network(tf.keras.Model):
    def __init__(self):
        super(dqn_network,self).__init__()
        self.input_layer = tf.keras.Input(shape=(84,84,4,))
        self.first_hidden_layer = tf.keras.layers.Conv2D(16,8, strides=4, activation="relu")
        self.second_hidden_layer = tf.keras.layers.Conv2D(32,4, strides=2, activation="relu")
        self.dense_layer = tf.keras.layers.Dense(256, activation="relu")
        self.output_layer = tf.keras.layers.Dense(4, activation="linear")
    
    def __call__(self):
        layer1 = self.first_hidden_layer(self.input_layer)
        layer2 = self.second_hidden_layer(layer1)
        layer3 = Flatten()(layer2)
        layer4 = self.dense_layer(layer3)
        layer5 = self.output_layer(layer4)
        model = tf.keras.Model(inputs=layer1, outputs=layer5)
        model.compile(optimizer= tf.keras.optimizers.Adam(lr=1e-3) , loss= tf.keras.losses.mean_squared_error)
        return model


class agent(object):
    def __init__(self, epsilon=1, max_epsilon=1, min_epsilon=0.1, update_target=10000,timestep=0, batch_size=32):
        super(agent, self).__init__()        
        self.epsilon = epsilon
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.target_network = dqn_network()
        self.Q_network = dqn_network()
        self.update_target = update_target
        self.timestep = timestep
        self.experience_relay = replay_buffer()
        self.batch_size = batch_size

    def update_timestep(self, newtimestep):
        self.timestep = newtimestep

    def update_target_network(self):
        if self.timestep != 0 and self.update_target % self.timestep == 0:
            self.target_network.set_weights(self.Q_network.get_weights())
    
    def greedy_policy(self):
        if random.uniform(0,1) < self.epsilon:
            return np.random.choice(env.action_space.n)
        else:
            q_values = self.target_network.predict(state[np.newaxis])
            return np.argmax(q_values[0])


    def store_transition(self, state, next_state, action, reward, terminal_state):
        self.experience_relay.store_transition(state, next_state, action, reward, terminal_state)

    def annealing_epsilon(self):
        interval = self.max_epsilon - self.min_epsilon
        self.epsilon -= interval / 100000
        if self.epsilon < 0.1:
            self.epsilon = self.min_epsilon

    def training(self):
        if self.timestep % 4 == 0 and self.experience_relay.mem > self.batch_size:
            actions, rewards, states, next_states, dones = self.experience_relay.sample_buffer()
            print("next state: {}".format(next_states))
            np_states = np.expand_dims(next_states, axis=0)
            print("np_states: {}".format(np_states))

            #tf_states = tf.convert_to_tensor(next_states)
            self.update_target_network()
            next_q_value = self.Q_network.predict(next_states)

            q_targets = rewards + (1-dones)*gamma*np.max(next_q_value, axis = 1)
            mask = tf.one_hot(actions, env.action_space.n)

            with tf.GradientTape() as tape:
                total_q_value = dqn_network(states)
                q_values = tf.reduce_sum(mask*total_q_value, axis=1, keepdims=True)

                loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(q_targets, q_values))
            
            grad = tape.gradient(loss, self.Q_network.trainable.variables)
            tf.keras.optimizers.Adam.apply_gradients(zip(grad, self.Q_network.trainable.variables))


class PreProcess(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(PreProcess, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0,high=255,shape=(84,84,1), dtype= np.uint8)
            
    def greyscale(self):
        return np.mean(self.observation_space, axis=2)

class model:
    def __init__(self):
        self.frame_buffer = []
    
    def add_img(self, img):
        self.frame_buffer.append(img)

    def create_gif(self, filepath=None):  # here here here
        plt.figure(figsize=(self.frame_buffer[0].shape[1] / 72, self.frame_buffer[0].shape[0] / 72), dpi = 72)
        patch = plt.imshow(self.frame_buffer[0])
        plt.axis('off')
        
        def animate(i):
            patch.set_data(self.frame_buffer[i])
        
        ani = animation.FuncAnimation(plt.gcf(), animate, frames = len(self.frame_buffer))
        if filepath:
            writergif = animation.PillowWriter(fps=20)
            ani.save(filepath, writer = writergif)
            print("file saved")
    

if __name__ == "__main__":
    env = gym.make("BreakoutDeterministic-v4")
    PreProcess(env)

    dqn = agent()

    target_update = 10000
    MaxTimestep = 100000
    episode_num = 0
    frame_num = 0
    state = env.reset()

    while True:
        image_file = model()
        start = time.time()

        for timestep in range(MaxTimestep):
            frame_num +=1
            action = dqn.greedy_policy()
            dqn.annealing_epsilon()

            next_state, reward, done, info = env.step(action)

            dqn.update_timestep(timestep)

            dqn.store_transition(state, next_state, action, reward, done)
            
            img = env.render("rgb_array")
            image_file.add_img(img)

            if (target_update % frame_num) == 0:
                dqn.training()
            
            if done or (timestep == MaxTimestep-1):
                end = time.time()
                print("[episode: {}, time taken: {:.5f} sec, timestep: {}]".format(episode_num + 1 , end-start, timestep))

                if episode_num % 10 == 0:
                    image_file.create_gif(filepath= r"./drive/My Drive/GIF-breakout-v1/episode{}.gif")
                    
                    print("[[episode: {}, time taken: {:.5f} sec, timestep: {}]]".format(episode_num + 1 , end-start, timestep))
                    break
                break

        episode_num += 1

Thank you in advance :]

1

There are 1 answers

2
Tom Dörr On

You could try to reduce the size of the replay buffer from 25000 to 250 to see if that's the issue.