Implementing SARSA(lambda) - Gridworld - in Julia


Could you explain to me what is wrong in this code? I am trying to implement SARSA(λ) with eligibility traces.
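
For reference, the tabular SARSA(λ) update with accumulating traces that I am trying to reproduce (using the same symbols as in the code below: β is the discount factor, λ the trace-decay parameter, α the learning rate) is:

δ = r + β·Q(S′, A′) - Q(S, A)
E(S, A) ← E(S, A) + 1
Q(s, a) ← Q(s, a) + α·δ·E(s, a)   for every stored pair (s, a)
E(s, a) ← β·λ·E(s, a)             for every stored pair (s, a)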

using ReinforcementLearningBase, GridWorlds
using PyPlot

world = GridWorlds.GridRoomsDirectedModule.GridRoomsDirected();
env = GridWorlds.RLBaseEnv(world)

mutable struct Agent
    env::AbstractEnv
    algo::Symbol
    ϵ::Float64 # exploration coefficient
    ϵ_decay::Float64
    ϵ_min::Float64
    λ::Float64 # trace-decay parameter lambda
    β::Float64 # discount factor
    α::Float64 # learning rate
    Q::Dict # state => vector of per-action values
    score::Int # number of times the agent reached the goal
    steps_per_episode::Vector{Float64} # average number of steps per episode
    E::Dict # state => vector of per-action eligibility traces
end

function Agent(env, algo; ϵ = 1.0, ϵ_decay = 0.9975, ϵ_min = 0.005, λ=0.9,
        β = 0.99, α = 0.1) 
    if algo != :SARSA && algo != :Qlearning
        @error "unknown algorithm"
    end
    Agent(env, algo,
        ϵ, ϵ_decay, ϵ_min,λ, β, α, 
        Dict(), 0, [0.0,],Dict())
end

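# learn!: on the first visit to S it creates the Q/E rows and seeds them;
# otherwise it computes the TD error and updates the current (S, A) pair.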
function learn!(agent, S, A, r, S′,A′)
    if !haskey(agent.Q, S)
        agent.E[S] = zeros(length(action_space(agent.env)))
        agent.Q[S] = zeros(length(action_space(agent.env)))
        agent.Q[S][A] = r
        agent.E[S][A]=1
    else
        Q_S′ = 0.0
        haskey(agent.Q, S′) && (Q_S′ += agent.Q[S′][A′])
        Δ = r + agent.β * agent.Q[S′][A′] - agent.Q[S][A]
        agent.E[S][A]=agent.β*agent.λ*agent.E[S][A]+1
        agent.Q[S][A] += agent.α * Δ*agent.E[S][A]

    end
end
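
# run_learning!: ε-greedy interaction loop with a burn-in phase (pure
# exploration for the first 10% of steps); optionally records text frames
# and writes them to "<animated>.txt" at the end.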
function run_learning!(agent, steps; burning = true, animated = nothing) 
    step = 1.0
    steps_per_episode = 1.0
    episode = 1.0
    if !isnothing(animated)
        global str = ""
        global str = str * "FRAME_START_DELIMITER"
        global str = str * "step: $(step)\n"
        global str = str * "episode: $(episode)\n"
        global str = str * repr(MIME"text/plain"(), env)
        global str = str * "\ntotal_reward: 0"
    end
    while step <= steps
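        # ε-greedy action selection; always explore during burn-in or for unseen states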
        if (burning && step < 0.1*steps) || rand() < agent.ϵ || !haskey(agent.Q, state(agent.env))
            A = rand(1:length(action_space(agent.env)))
        else 
            A = argmax(agent.Q[state(agent.env)])
        end
        S = deepcopy(state(agent.env))
        agent.env(action_space(agent.env)[A])
        r = reward(agent.env)
        S′ = deepcopy(state(agent.env))
        if agent.algo == :SARSA
            if (burning && step < 0.1 * steps) || rand() < agent.ϵ || !haskey(agent.Q, state(agent.env))
                A′ = rand(1:length(action_space(agent.env)))
            else 
                A′ = argmax(agent.Q[state(agent.env)])
            end
            learn!(agent, S, A, r, S′,A′)
        else
            learn!(agent, S, A, r, S′)
        end
        if !isnothing(animated) 
            global str = str * "FRAME_START_DELIMITER"
            global str = str * "step: $(step)\n"
            global str = str * "episode: $(episode)\n"
            global str = str * repr(MIME"text/plain"(), env)
            global str = str * "\ntotal_reward: $(agent.score)"
        end
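        # end of episode: decay ϵ, update the running statistics, reset the environment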
        if is_terminated(agent.env)
            eps = agent.ϵ * agent.ϵ_decay
            agent.ϵ = max(agent.ϵ_min, eps)
            agent.score += 1.0
            push!(agent.steps_per_episode, 
                agent.steps_per_episode[end] + (steps_per_episode - agent.steps_per_episode[end])/episode)
            episode += 1.0
            steps_per_episode = 0
            reset!(agent.env)
        end
        step += 1.0 
        steps_per_episode += 1.0
    end
    if !isnothing(animated) 
        write(animated * ".txt", str)
    end
end

agent_SARSA = Agent(env,:SARSA);

run_learning!(agent_SARSA, 2500)
@info "agent score: $(agent_SARSA.score)"

After running the code I receive the following error, but I don't understand why.

KeyError: key ([0 0 … 0 0; 1 1 … 1 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 1 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 1 … 1 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 0 … 0 1; 0 0 … 0 0;;; 0 0 … 0 0; 1 1 … 1 1; 0 0 … 0 0], 1) not found

Stacktrace:
 [1] getindex(h::Dict{Any, Any}, key::Tuple{BitArray{3}, Int64})
   @ Base .\dict.jl:498
 [2] learn!(agent::Agent, S::Tuple{BitArray{3}, Int64}, A::Int64, r::Float32, S′::Tuple{BitArray{3}, Int64}, A′::Int64)
   @ Main .\In[44]:10
 [3] run_learning!(agent::Agent, steps::Int64; burning::Bool, animated::Nothing)
   @ Main .\In[45]:31
 [4] run_learning!(agent::Agent, steps::Int64)
   @ Main .\In[45]:1
 [5] top-level scope
   @ In[51]:1
 [6] eval
   @ .\boot.jl:368 [inlined]
 [7] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base .\loading.jl:1428

I have tried manipulating the dictionary, but without any success.
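
While re-reading learn!, I now suspect the direct cause: the guarded value Q_S′ is computed but never used, because Δ reads agent.Q[S′][A′] directly, and that lookup throws a KeyError the first time S′ is a state that has no entry in Q yet (which looks like exactly the getindex call in the stack trace). There may also be a second issue: the trace update only touches the current (S, A) pair, whereas tabular SARSA(λ) propagates the TD error to every stored pair. A minimal guarded sketch of learn! along those lines (untested, assuming the same Dict layout as above) would be:

function learn!(agent, S, A, r, S′, A′)
    n = length(action_space(agent.env))
    # lazily create the Q/E rows for a state seen for the first time
    haskey(agent.Q, S) || (agent.Q[S] = zeros(n); agent.E[S] = zeros(n))
    # guarded successor value: 0.0 if S′ has no entry yet
    Q_S′ = haskey(agent.Q, S′) ? agent.Q[S′][A′] : 0.0
    Δ = r + agent.β * Q_S′ - agent.Q[S][A]  # TD error
    agent.E[S][A] += 1                      # accumulating trace
    # propagate the TD error to every stored pair, then decay all traces
    for s in keys(agent.Q)
        agent.Q[s] .+= agent.α * Δ .* agent.E[s]
        agent.E[s] .*= agent.β * agent.λ
    end
end

Is the KeyError really coming from that unguarded lookup, or is there something else I am missing?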
