I'm trying to implement QLearning to simulated ants in Unity. Following Accord's Animat Example, I managed to implement the gist of the algorithm.
Now my Agent has 5 state inputs - three of them come from sensors that detect obstacles in front of it (RayCasts in Unity) and the remaining two are its X and Y positions on the map.
My problem is that qLearning.GetAction(currentState)
only takes an int as a parameter. How can I implement my algorithm using an array (or tensor) for the agent's current state?
Here is my code:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Accord.MachineLearning;
using System;
public class AntManager : MonoBehaviour {
    float direction = 0.01f;
    float rotation = 0;
    // learning settings
    int learningIterations = 100;      // iterations over which exploration/learning rates decay to 0
    private int currentIteration = 0;  // frames elapsed; drives the rate annealing below
    private double explorationRate = 0.5;
    private double learningRate = 0.5;
    private double moveReward = 0;
    private double wallReward = -1;
    private double goalReward = 1;
    private float lastDistance = 0;
    private RaycastHit hit;
    private int hitInteger = 0;        // 1 while the forward sensor sees something, else 0
    // Q-Learning algorithm.
    // 512 states = 2 (sensor bit) x 16 (x grid) x 16 (y grid) — see GetState().
    private QLearning qLearning = null;

    // Use this for initialization
    void Start () {
        qLearning = new QLearning(512, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
    }

    // Encode the agent's observable situation as a single integer state index.
    // Tabular QLearning requires one int per state, so the separate inputs are
    // packed into disjoint bit fields instead of being summed (summing collapses
    // distinct states, e.g. x=1,y=2 and x=2,y=1): bit 8 = sensor flag,
    // bits 4-7 = x, bits 0-3 = y. Positions are wrapped onto a 16x16 grid;
    // widen the masks (and the state count in Start) if the map is larger.
    private int GetState()
    {
        int x = ((int)Math.Round(transform.position.x, 0)) & 0xF;
        int y = ((int)Math.Round(transform.position.y, 0)) & 0xF;
        return (hitInteger << 8) | (x << 4) | y;
    }

    // Update is called once per frame: one Q-learning step per frame.
    void Update () {
        // current coordinates of the agent
        float agentCurrentX = transform.position.x;
        float agentCurrentY = transform.position.y;
        // exploration policy
        TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy;
        EpsilonGreedyExploration explorationPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy;
        // Anneal exploration and learning rates linearly towards 0 over
        // learningIterations steps (the previous code subtracted
        // learningIterations * rate, which produced a large negative value
        // on the very first frame).
        double progress = Math.Min(1.0, (double)currentIteration / learningIterations);
        explorationPolicy.Epsilon = explorationRate - progress * explorationRate;
        qLearning.LearningRate = learningRate - progress * learningRate;
        currentIteration++;
        // clear tabu list
        tabuPolicy.ResetTabuList();
        // The state must be observed BEFORE the move and again AFTER it so the
        // Q-update sees a real transition; reusing currentState as nextState
        // makes the agent learn nothing about the environment dynamics.
        int currentState = GetState();
        int action = qLearning.GetAction(currentState);
        // update agent's position and get its reward
        double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action);
        // agent's next state, re-read after the move
        int nextState = GetState();
        // do learning of the agent - update its Q-function
        qLearning.UpdateState(currentState, action, reward, nextState);
        // forbid the immediate reverse action for one step
        tabuPolicy.SetTabuAction((action + 2) % 4, 1);
    }

    // Rotate/translate the agent according to the chosen action, refresh the
    // forward sensor flag (hitInteger), and return the reward for the move.
    private double UpdateAgentPosition(ref float currentX, ref float currentY, int action)
    {
        // default reward is equal to moving reward
        double reward = moveReward;
        GameObject food = GameObject.FindGameObjectWithTag("Food");
        // Guard against the frame after Destroy(food): FindGameObjectWithTag
        // returns null and Vector3.Distance would throw.
        if (food != null)
        {
            float distance = Vector3.Distance(transform.position, food.transform.position);
            // shaping reward: small bonus for closing the distance to the food
            if (distance < lastDistance)
                reward = 0.2f;
            lastDistance = distance;
            Debug.Log(distance);
        }
        // NOTE(review): actions 0/3 and 1/2 currently steer identically; if four
        // distinct headings are intended, this switch needs four distinct updates.
        switch (action)
        {
            case 0: // go to north (up)
                rotation += -1f;
                break;
            case 1: // go to east (right)
                rotation += 1f;
                break;
            case 2: // go to south (down)
                rotation += 1f;
                break;
            case 3: // go to west (left)
                rotation += -1f;
                break;
        }
        transform.Rotate(0, rotation * Time.deltaTime, 0);
        transform.Translate(new Vector3(0, 0, 0.01f));
        // forward obstacle sensor
        Ray sensorForward = new Ray(transform.position, transform.forward);
        Debug.DrawRay(transform.position, transform.forward * 1);
        // Cleared first so a miss resets the sensor flag; the previous code
        // set hitInteger = 1 on food and then unconditionally overwrote it
        // with 0 on the next statement.
        hitInteger = 0;
        if (Physics.Raycast(sensorForward, out hit, 1))
        {
            if (hit.collider.tag == "Food")
            {
                Debug.Log("Sensor Found Food!");
                Destroy(food);
                reward = goalReward;
                hitInteger = 1;
            }
            else if (hit.collider.tag != "Terrain")
            {
                Debug.Log("Sensor Forward hit!");
                reward = wallReward;
                hitInteger = 1;
            }
        }
        return reward;
    }
}
The documentation supplies this as an example:
This appears to be bit-shifting two binary-valued integers into a single binary encoding of the state. Your code probably needs something like this:
However, you will first need to map your states into binary variables, so this code will only work with a 2x2 grid. Even though the example declares integers they are binary values: it would be meaningless to bit-shift a value of 2 or more.
A useful way to visualise the state is looking directly at the binary: