Capturing state as array in QLearning with Accord.net


I'm trying to implement QLearning for simulated ants in Unity. Following Accord's Animat example, I managed to implement the gist of the algorithm.

Now my agent has 5 state inputs: three of them come from sensors that detect obstacles in front of it (raycasts in Unity), and the remaining two are its X and Y position on the map.

My problem is that qLearning.GetAction(currentState) only takes an int as a parameter. How can I implement my algorithm using an array (or tensor) for the agent's current state?

Here is my code:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Accord.MachineLearning;
using System;

public class AntManager : MonoBehaviour {
    float direction = 0.01f;
    float rotation = 0;

    // learning settings
    int learningIterations = 100;
    private double explorationRate = 0.5;
    private double learningRate = 0.5;

    private double moveReward = 0;
    private double wallReward = -1;
    private double goalReward = 1;

    private float lastDistance = 0;

    private RaycastHit hit;
    private int hitInteger = 0;

    // Q-Learning algorithm
    private QLearning qLearning = null;


    // Use this for initialization
    void Start () {
        qLearning = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
    }

    // Update is called once per frame
    void Update () {        

        // current coordinates of the agent
        float agentCurrentX = transform.position.x;
        float agentCurrentY = transform.position.y;
        // exploration policy
        TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy;

        EpsilonGreedyExploration explorationPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy;

        // decay exploration and learning rates towards zero over learningIterations frames
        double decay = Math.Min((double)Time.frameCount / learningIterations, 1.0);
        explorationPolicy.Epsilon = explorationRate - decay * explorationRate;
        qLearning.LearningRate = learningRate - decay * learningRate;
        // clear tabu list
        tabuPolicy.ResetTabuList();

        // get agent's current state
        int currentState = ((int)Math.Round(transform.position.x, 0) + (int)Math.Round(transform.position.y, 0) + hitInteger);
        // get the action for this state
        int action = qLearning.GetAction(currentState);
        // update agent's current position and get his reward
        double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action);
        // get agent's next state (recomputed after the move)
        int nextState = ((int)Math.Round(transform.position.x, 0) + (int)Math.Round(transform.position.y, 0) + hitInteger);
        // do learning of the agent - update his Q-function
        qLearning.UpdateState(currentState, action, reward, nextState);

        // set tabu action
        tabuPolicy.SetTabuAction((action + 2) % 4, 1);


    }

    // Update agent position and return reward for the move
    private double UpdateAgentPosition(ref float currentX, ref float currentY, int action)
    {
        // default reward is equal to moving reward
        double reward = moveReward;
        GameObject food = GameObject.FindGameObjectWithTag("Food");

        float distance = Vector3.Distance(transform.position, food.transform.position);

        if (distance < lastDistance)
            reward = 0.2f;

        lastDistance = distance;

        Debug.Log(distance);

        switch (action)
        {
            case 0:         // go to north (up)
                rotation += -1f;
                break;
            case 1:         // go to east (right)
                rotation += 1f;
                break;
            case 2:         // go to south (down)
                rotation += 1f;
                break;
            case 3:         // go to west (left)
                rotation += -1f;
                break;
        }

        //transform.eulerAngles = new Vector3(10, rotation, 0);
        transform.Rotate(0, rotation * Time.deltaTime, 0);
        transform.Translate(new Vector3(0, 0, 0.01f));



        float newX = transform.localRotation.x;
        float newY = transform.localRotation.y;

        Ray sensorForward = new Ray(transform.position, transform.forward);
        Debug.DrawRay(transform.position, transform.forward * 1);

        if (Physics.Raycast(sensorForward, out hit, 1))
        {
            if (hit.collider.tag != "Terrain")
            {
                Debug.Log("Sensor Forward hit!");

                reward = wallReward;
            }
            if (hit.collider.tag == "Food")
            {
                Debug.Log("Sensor Found Food!");
                Destroy(food);
                reward = goalReward;
                hitInteger = 1;
            }
            hitInteger = 0;
        }

        return reward;
    }
}

1 Answer

Answered by Gracie

The documentation supplies this as an example:

c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7)

This bit-shifts eight two-valued (binary) integers into a single binary encoding of the state. Your code would need something like this:

int currentState = ((int)Math.Round(transform.position.x, 0) | ((int)Math.Round(transform.position.y, 0) << 1) | (hitInteger << 2));

However, you will first need to map your state variables to binary values, so this code will only work for a 2x2 grid. Even though the example declares integers, they hold binary values: it would be meaningless to bit-shift a value of 2 or more.
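You can see the limitation concretely: once a variable can reach 2 or more, its bits overlap the next variable's slot and two different states collide onto the same index. A minimal sketch (the variable names here are illustrative):

```csharp
using System;

class BitShiftCollision
{
    static void Main()
    {
        // x | (y << 1) is only a unique encoding when x and y are 0 or 1.
        int a = 2 | (1 << 1);   // x = 2, y = 1  ->  binary 10 | 10 = 10
        int b = 0 | (1 << 1);   // x = 0, y = 1  ->  binary 00 | 10 = 10
        Console.WriteLine(a == b);   // True: two distinct states map to 2
    }
}
```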

A useful way to visualise the state is to look directly at its binary representation:

Convert.ToString(1 | (0 << 1) | (1 << 2), 2)   // "101"
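For variables with more than two values, a mixed-radix encoding generalises the bit-shifting idea: multiply each variable by the product of the ranges of the variables before it. A sketch under assumed grid dimensions (the 16x8 map and Encode helper are illustrative, not from the question); the total state count is what you would pass as the first argument of the QLearning constructor:

```csharp
using System;

class StateEncoding
{
    // Encode (x, y, hit) into a single state index for a
    // gridWidth x gridHeight map with a binary hit flag.
    // The index falls in [0, gridWidth * gridHeight * 2).
    public static int Encode(int x, int y, int hit, int gridWidth, int gridHeight)
    {
        return x + gridWidth * (y + gridHeight * hit);
    }

    static void Main()
    {
        int states = 16 * 8 * 2;           // total states for the QLearning constructor
        int s = Encode(3, 5, 1, 16, 8);    // x = 3, y = 5, hit = 1
        Console.WriteLine(s);              // 3 + 16 * (5 + 8 * 1) = 211
        Console.WriteLine(states);         // 256
    }
}
```

Because every state maps to a distinct index, this works for any discrete grid, not just a 2x2 one.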