I am trying to build an RL model with deep Q-learning using RL4J in AnyLogic PLE as part of my thesis. Unfortunately, I am not very familiar with AnyLogic and DL4J, so I might be missing some obvious steps. Since I only have access to the PLE, I am trying to train the RL model inside AnyLogic without using custom experiments. I adapted the Vehicle Battery example so that the model is trained inside the project's Main agent.
The Main agent contains a state chart, which is my environment. The functions `getObservation()` (named `main_getObservation()` in the code below), `takeAction()` and `calcReward()` are implemented there as well. I also added a `resetStateChart()` function that resets the state chart once a final state is reached. All of these functions work fine when I call them manually. The last function in Main is `train()`, which contains all of the RL4J code.
Once I start training, the training itself runs and all functions get called properly. However, the state chart does not update. I trigger the transitions with the `receiveMessage()` function. When I call `takeAction()` manually, `receiveMessage()` returns true; when it is called during training, it returns false. According to the documentation, this means that no transition in the state chart was triggered by the message. I don't understand how this can happen, since it works fine when called manually.
Does anyone have an idea what could cause this?
This is my function to take actions:
/**
 * Applies one RL action to the environment.
 *
 * <p>Actions 0 and 1 cost {@code actionCost} power, action 2 costs {@code waitCost}; any other
 * code (e.g. 3) costs nothing here. The action code is then sent to {@code statechart} as an
 * {@code Integer} message (autoboxed), which is what the message transitions test for.
 * No message transition in this statechart triggers on 2, so action 2 only consumes power.
 *
 * <p>{@code receiveMessage} reports whether a transition consumed the message; the result is
 * intentionally not used because an action that is invalid in the current state is a legal
 * (no-op) move for the agent.
 *
 * @param action the action index chosen by the policy
 */
public void takeAction( int action ) {
    if (action == 0 || action == 1) {
        power -= actionCost;
    } else if (action == 2) {
        power -= waitCost;
    }
    statechart.receiveMessage(action);
}
This is the code for training the RL model:
void train ( ) {
MDP<Encodable, Integer, DiscreteSpace> mdp = new MDP<Encodable, Integer, DiscreteSpace>() {
ObservationSpace<Encodable> observationSpace = new ArrayObservationSpace<>(new int[]{13});
DiscreteSpace actionSpace = new DiscreteSpace(4);
public ObservationSpace<Encodable> getObservationSpace() { return observationSpace;
}
public DiscreteSpace getActionSpace() { return actionSpace;
}
public Encodable getObservation() {
return new Encodable() {
double[] a = main_getObservation();
public double[] toArray() { return a;
}
};
}
public Encodable reset() {
resetStateChart();
return getObservation();
}
public void close(){}
public StepReply<Encodable> step(Integer action) {
int a = action.intValue();
takeAction(a);
double reward = calcReward();
return new StepReply(getObservation(), reward, isDone(), null);
}
public boolean isDone(){
return inState(win) || inState(dead);
}
public MDP<Encodable, Integer, DiscreteSpace> newInstance() { return null; }
};
try{
DataManager manager = new DataManager(true);
QLearning.QLConfiguration AL_QL =
new QLearning.QLConfiguration(
1,
10000,
100000,
100000,
128,
1000,
10,
1,
0.99,
1.0,
0.1f,
30000,
true
);
DQNFactoryStdDense.Configuration AL_NET =
DQNFactoryStdDense.Configuration.builder()
.l2(0).updater(new
RmsProp(0.001)).numHiddenNodes(300).numLayer(2).build();
QLearningDiscreteDense<Encodable> dql = new QLearningDiscreteDense(mdp, AL_NET, AL_QL, manager);
dql.train();
DQNPolicy<Encodable> pol = dql.getPolicy();
pol.save("Statechart_Policy2.zip");
mdp.close();
}catch(IOException e){
e.printStackTrace();
}
If you need any additional information, please let me know.
Looking forward to any hints, and thank you!
Edit:
This is my function to reset the state chart to its initial state:
/**
 * Sends the reset message (4) to the statechart and refills the battery to powerCapacity.
 *
 * <p>Message 4 triggers trans_op_res (from anywhere inside {@code operating}, including
 * {@code win}), trans_rech_re (from {@code recharging}) or trans_dead_res (from {@code dead}).
 * All of these lead to state {@code reset}.
 *
 * <p>NOTE(review): {@code reset} is left only through the condition transition trans_res_op,
 * not through any message. If this is called while the simulation engine cannot run (e.g. from
 * inside a blocking training loop), the statechart presumably stays in {@code reset} and all
 * subsequent action messages return false. Confirm, and consider changing the diagram so the
 * reset path reaches {@code operating} synchronously.
 */
void resetStateChart( ) {
    statechart.receiveMessage(4);
    power = powerCapacity;
}
This is my function to get observations:
/**
 * Builds the observation vector fed to the DQN.
 *
 * <p>Entries 0..n-1 hold one flag per {@code statechart_state} value, in declaration order:
 * 1 when {@code inState(state)} is true, else 0. The last entry is the ratio
 * {@code power / powerCapacity}.
 *
 * @return a new array of length {@code statechart_state.values().length + 1}
 */
public double[] main_getObservation( ) {
    statechart_state[] states = statechart_state.values();
    double[] observation = new double[states.length + 1];
    for (int i = 0; i < states.length; i++) {
        observation[i] = inState(states[i]) ? 1 : 0;
    }
    // Cast first so the ratio stays fractional even if power and powerCapacity are integers
    // (their declared types are not visible here).
    observation[states.length] = (double) power / powerCapacity;
    return observation;
}
Code of my state chart:
// Statecharts
// NOTE(review): the members from here to the end of this listing carry
// @AnyLogicInternalCodegenAPI and appear to be generated by AnyLogic from the
// statechart diagram; change the diagram rather than editing this code by hand.
// The agent's single statechart, used as the RL environment. The meaning of the
// (short)6 constructor argument is not visible here.
public Statechart<statechart_state> statechart = new Statechart<>( this, (short)6 );
// Returns the diagram name for this agent's statechart; other statecharts are
// delegated to the superclass.
@Override
@AnyLogicInternalCodegenAPI
public String getNameOf( Statechart _s ) {
if(_s == this.statechart) return "statechart";
return super.getNameOf( _s );
}
// Returns the numeric id of this agent's statechart (0, the only one here).
@Override
@AnyLogicInternalCodegenAPI
public int getIdOf( Statechart _s ) {
if(_s == this.statechart) return 0;
return super.getIdOf( _s );
}
// Statechart entry point: enters the composite state `operating` with
// _destination = true, which in enterState continues into inner state `state_1`.
@Override
@AnyLogicInternalCodegenAPI
public void executeActionOf( Statechart _s ) {
if( _s == this.statechart ) {
enterState( operating, true );
return;
}
super.executeActionOf( _s );
}
// States of all statecharts
// One constant per state or branch of `statechart`, in declaration order.
// main_getObservation() iterates values() in this order, so reordering or adding
// states changes the observation layout and size.
// Constants without a body are top-level states; the ones with a body override
// getContainerState() to return `operating`, i.e. they are inside that composite state.
public enum statechart_state implements IStatechartState<Main, statechart_state> {
operating,
recharging,
reset,
dead,
state_1() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
state_2() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
state_3() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
state_4() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
win() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
branch() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
branch1() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
},
branch2() {
@Override
@AnyLogicInternalCodegenAPI
public statechart_state getContainerState() {
return operating;
}
};
// Lazily filled caches for the three queries below: each is computed on first
// call and reused afterwards (no synchronization).
@AnyLogicInternalCodegenAPI
private Collection<statechart_state> _simpleStatesDeep_xjal;
@AnyLogicInternalCodegenAPI
private Set<statechart_state> _fullState_xjal;
@AnyLogicInternalCodegenAPI
private Set<statechart_state> _statesInside_xjal;
@Override
@AnyLogicInternalCodegenAPI
public Collection<statechart_state> getSimpleStatesDeep() {
Collection<statechart_state> result = _simpleStatesDeep_xjal;
if (result == null) {
_simpleStatesDeep_xjal = result = calculateAllSimpleStatesDeep();
}
return result;
}
@Override
public Set<statechart_state> getFullState() {
Set<statechart_state> result = _fullState_xjal;
if (result == null) {
_fullState_xjal = result = calculateFullState();
}
return result;
}
@Override
@AnyLogicInternalCodegenAPI
public Set<statechart_state> getStatesInside() {
Set<statechart_state> result = _statesInside_xjal;
if (result == null) {
_statesInside_xjal = result = calculateStatesInside();
}
return result;
}
// Resolves the statechart instance of the given agent.
@Override
@AnyLogicInternalCodegenAPI
public Statechart<statechart_state> getStatechart( Main _a ) {
return _a.statechart;
}
}
// Static aliases for the enum constants so code in this agent can name states
// unqualified, e.g. inState(win) / inState(dead) in the training MDP.
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state operating = statechart_state.operating;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state recharging = statechart_state.recharging;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state reset = statechart_state.reset;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state dead = statechart_state.dead;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state state_1 = statechart_state.state_1;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state state_2 = statechart_state.state_2;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state state_3 = statechart_state.state_3;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state state_4 = statechart_state.state_4;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state win = statechart_state.win;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state branch = statechart_state.branch;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state branch1 = statechart_state.branch1;
@AnyLogicCustomProposalPriority(type = AnyLogicCustomProposalPriority.Type.STATIC_ELEMENT)
public static final statechart_state branch2 = statechart_state.branch2;
// Entry logic for every state. Each case logs the entry and, for simple states, marks
// the state active, runs its entry action (if any) and starts its outgoing transitions.
// Composite `operating` starts the transitions shared by all inner states and, when
// _destination is true, enters its initial inner state `state_1`. Branch pseudostates
// immediately choose a target using randomTrue(slipChance).
@AnyLogicInternalCodegenAPI
private void enterState( statechart_state self, boolean _destination ) {
switch( self ) {
case operating:
logToDBEnterState(statechart, self);
// (Composite state)
// These can leave `operating` from any inner state (messages 3 and 4, power <= 0).
trans_op_rech.start();
trans_op_dead.start();
trans_op_res.start();
if ( _destination ) {
enterState( state_1, true );
}
return;
case recharging:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( recharging );
trans_rech_op.start();
trans_rech_rech.start();
trans_rech_re.start();
return;
case reset:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( reset );
// The only exit is the condition transition trans_res_op; no message transition
// leaves this state, so receiveMessage(...) cannot move the statechart out of it.
trans_res_op.start();
return;
case dead:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( dead );
trans_dead_res.start();
return;
case state_1:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( state_1 );
{
System.out.println("State1");
;}
trans_1_b.start();
trans_1_1.start();
return;
case state_2:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( state_2 );
{
System.out.println("State2");
;}
trans_2_b.start();
trans_2_1.start();
return;
case state_3:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( state_3 );
trans_3_4.start();
trans_3_2.start();
trans_3_2_not_wait.start();
return;
case state_4:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
statechart.setActiveState_xjal( state_4 );
trans_4_b.start();
trans_4_3.start();
return;
case win:
logToDBEnterState(statechart, self);
// (Simple state (not composite))
// No own outgoing transitions; only the `operating`-level ones (e.g. trans_op_res)
// can leave it.
statechart.setActiveState_xjal( win );
return;
case branch:
logToDBEnterState(statechart, self);
// (Branch)
// Slip back to state_1 with probability slipChance, otherwise go to state_2.
if (
randomTrue(slipChance)
) { // trans_slip_1
enterState( state_1, true );
return;
}
// trans_b_2 (default)
enterState( state_2, true );
return;
case branch1:
logToDBEnterState(statechart, self);
// (Branch)
// Slip back to state_1 with probability slipChance, otherwise go to state_3.
if (
randomTrue(slipChance)
) { // transi_slip_2
enterState( state_1, true );
return;
}
// trans_b_3 (default)
enterState( state_3, true );
return;
case branch2:
logToDBEnterState(statechart, self);
// (Branch)
// Slip back to state_1 with probability slipChance, otherwise reach win.
if (
randomTrue(slipChance)
) { // trans_slip_4
enterState( state_1, true );
return;
}
// trans_b_w (default)
enterState( win, true );
return;
default:
return;
}
}
// Exit logic for every state: logs the exit and the transition, then cancels the
// state's outgoing transitions. When _source is true, _t is the transition being
// taken from this state and is left uncancelled; otherwise all are cancelled.
// For composite `operating` as source, its active inner states are exited first.
// Branch pseudostates have no case and fall through to default.
@AnyLogicInternalCodegenAPI
private void exitState( statechart_state self, Transition _t, boolean _source ) {
switch( self ) {
case operating:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Composite state)
if ( _source ) exitInnerStates(self);
if ( !_source || _t != trans_op_rech ) trans_op_rech.cancel();
if ( !_source || _t != trans_op_dead ) trans_op_dead.cancel();
if ( !_source || _t != trans_op_res ) trans_op_res.cancel();
return;
case recharging:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_rech_op) trans_rech_op.cancel();
if ( !_source || _t != trans_rech_rech) trans_rech_rech.cancel();
if ( !_source || _t != trans_rech_re) trans_rech_re.cancel();
return;
case reset:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_res_op) trans_res_op.cancel();
return;
case dead:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_dead_res) trans_dead_res.cancel();
return;
case state_1:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_1_b) trans_1_b.cancel();
if ( !_source || _t != trans_1_1) trans_1_1.cancel();
return;
case state_2:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_2_b) trans_2_b.cancel();
if ( !_source || _t != trans_2_1) trans_2_1.cancel();
return;
case state_3:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_3_4) trans_3_4.cancel();
if ( !_source || _t != trans_3_2) trans_3_2.cancel();
if ( !_source || _t != trans_3_2_not_wait) trans_3_2_not_wait.cancel();
return;
case state_4:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
if ( !_source || _t != trans_4_b) trans_4_b.cancel();
if ( !_source || _t != trans_4_3) trans_4_3.cancel();
return;
case win:
logToDBExitState(statechart, self);
logToDB(statechart, _t, self);
// (Simple state (not composite))
// Nothing to cancel: win starts no transitions of its own.
return;
default:
return;
}
}
// Exits every state between the currently active simple state and _destination
// (exclusive), innermost first, cancelling each level's transitions (_source = false).
@AnyLogicInternalCodegenAPI
private void exitInnerStates( statechart_state _destination ) {
    for ( statechart_state level = statechart.getActiveSimpleState();
          level != _destination;
          level = level.getContainerState() ) {
        exitState( level, null, false );
    }
}
// Timeout transitions: trans_rech_rech is the recharging self-loop, trans_3_4 goes
// from state_3 to state_4. Their durations come from evaluateTimeoutOf.
public TransitionTimeout trans_rech_rech = new TransitionTimeout( this );
public TransitionTimeout trans_3_4 = new TransitionTimeout( this );
// Returns the diagram name of each timeout transition.
@Override
@AnyLogicInternalCodegenAPI
public String getNameOf( TransitionTimeout _t ) {
if ( _t == trans_rech_rech ) return "trans_rech_rech";
if ( _t == trans_3_4 ) return "trans_3_4";
return super.getNameOf( _t );
}
// Both timeout transitions belong to this agent's single statechart.
@Override
@AnyLogicInternalCodegenAPI
public Statechart getStatechartOf( TransitionTimeout _t ) {
    boolean ownedHere = _t == trans_rech_rech || _t == trans_3_4;
    if ( ownedHere ) {
        return statechart;
    }
    return super.getStatechartOf( _t );
}
// Fires a timeout transition. NOTE(review): presumably invoked by the AnyLogic engine
// once the model-time timeout from evaluateTimeoutOf has elapsed, i.e. only while the
// simulation clock advances; confirm.
@Override
@AnyLogicInternalCodegenAPI
public void executeActionOf( TransitionTimeout self ) {
if ( self == trans_rech_rech ) {
// Internal self-loop: add 25 power (capped at powerCapacity) and re-arm the
// timeout without exiting/entering `recharging`.
{
power = limitMax(power+25, powerCapacity)
;}
trans_rech_rech.start();
return;
}
if ( self == trans_3_4 ) {
exitState( state_3, self, true );
enterState( state_4, true );
return;
}
super.executeActionOf( self );
}
// Timeout durations, written in seconds and converted to model time units:
// 1 s for the recharging self-loop, 5 s for state_3 -> state_4.
@Override
@AnyLogicInternalCodegenAPI
public double evaluateTimeoutOf( TransitionTimeout _t ) {
    if ( _t == trans_rech_rech ) {
        return toModelTime( 1.0, SECOND );
    }
    if ( _t == trans_3_4 ) {
        return toModelTime( 5.0, SECOND );
    }
    return super.evaluateTimeoutOf( _t );
}
// Condition transitions: trans_res_op (reset -> operating, condition always true)
// and trans_op_dead (operating -> dead when power <= 0); see testConditionOf.
public TransitionCondition trans_res_op = new TransitionCondition( this );
public TransitionCondition trans_op_dead = new TransitionCondition( this );
// Returns the diagram name of each condition transition.
@Override
@AnyLogicInternalCodegenAPI
public String getNameOf( TransitionCondition _t ) {
if ( _t == trans_res_op ) return "trans_res_op";
if ( _t == trans_op_dead ) return "trans_op_dead";
return super.getNameOf( _t );
}
// Both condition transitions belong to this agent's single statechart.
@Override
@AnyLogicInternalCodegenAPI
public Statechart getStatechartOf( TransitionCondition _t ) {
    boolean ownedHere = _t == trans_res_op || _t == trans_op_dead;
    if ( ownedHere ) {
        return statechart;
    }
    return super.getStatechartOf( _t );
}
// Fires a condition transition. NOTE(review): presumably called by the AnyLogic engine
// after testConditionOf returns true during its own event processing, not synchronously
// from user code such as receiveMessage(); if the engine cannot run (e.g. a blocking
// training loop), the statechart would stay in `reset` and never enter `dead`. Confirm.
@Override
@AnyLogicInternalCodegenAPI
public void executeActionOf( TransitionCondition self ) {
if ( self == trans_res_op ) {
exitState( reset, self, true );
enterState( operating, true );
return;
}
if ( self == trans_op_dead ) {
exitState( operating, self, true );
enterState( dead, true );
return;
}
super.executeActionOf( self );
}
// Guard conditions: reset -> operating is unconditional; operating -> dead
// requires the battery to be empty (power <= 0).
@Override
@AnyLogicInternalCodegenAPI
public boolean testConditionOf( TransitionCondition _t ) {
    if ( _t == trans_res_op ) {
        return true;
    } else if ( _t == trans_op_dead ) {
        return power <= 0;
    } else {
        return super.testConditionOf( _t );
    }
}
// Message transitions. Each fires when receiveMessage() delivers an Integer equal to
// its trigger code (see testMessageOf); the source/target states are in executeActionOf.
public TransitionMessage trans_op_rech = new TransitionMessage( this );
public TransitionMessage trans_rech_op = new TransitionMessage( this );
public TransitionMessage trans_rech_re = new TransitionMessage( this );
public TransitionMessage trans_dead_res = new TransitionMessage( this );
public TransitionMessage trans_op_res = new TransitionMessage( this );
public TransitionMessage trans_1_b = new TransitionMessage( this );
public TransitionMessage trans_2_b = new TransitionMessage( this );
public TransitionMessage trans_4_b = new TransitionMessage( this );
public TransitionMessage trans_4_3 = new TransitionMessage( this );
public TransitionMessage trans_3_2 = new TransitionMessage( this );
public TransitionMessage trans_3_2_not_wait = new TransitionMessage( this );
public TransitionMessage trans_2_1 = new TransitionMessage( this );
public TransitionMessage trans_1_1 = new TransitionMessage( this );
// Returns the diagram name of each message transition.
@Override
@AnyLogicInternalCodegenAPI
public String getNameOf( TransitionMessage _t ) {
if ( _t == trans_op_rech ) return "trans_op_rech";
if ( _t == trans_rech_op ) return "trans_rech_op";
if ( _t == trans_rech_re ) return "trans_rech_re";
if ( _t == trans_dead_res ) return "trans_dead_res";
if ( _t == trans_op_res ) return "trans_op_res";
if ( _t == trans_1_b ) return "trans_1_b";
if ( _t == trans_2_b ) return "trans_2_b";
if ( _t == trans_4_b ) return "trans_4_b";
if ( _t == trans_4_3 ) return "trans_4_3";
if ( _t == trans_3_2 ) return "trans_3_2";
if ( _t == trans_3_2_not_wait ) return "trans_3_2_not_wait";
if ( _t == trans_2_1 ) return "trans_2_1";
if ( _t == trans_1_1 ) return "trans_1_1";
return super.getNameOf( _t );
}
// Every message transition declared in this agent belongs to its single statechart.
@Override
@AnyLogicInternalCodegenAPI
public Statechart getStatechartOf( TransitionMessage _t ) {
    boolean ownedHere =
            _t == trans_op_rech || _t == trans_rech_op || _t == trans_rech_re
            || _t == trans_dead_res || _t == trans_op_res || _t == trans_1_b
            || _t == trans_2_b || _t == trans_4_b || _t == trans_4_3
            || _t == trans_3_2 || _t == trans_3_2_not_wait || _t == trans_2_1
            || _t == trans_1_1;
    if ( ownedHere ) {
        return statechart;
    }
    return super.getStatechartOf( _t );
}
// Fires a message transition: exit the source state (keeping this transition
// uncancelled), run the transition action if any, then enter the target state.
@Override
@AnyLogicInternalCodegenAPI
public void executeActionOf( TransitionMessage self, Object _msg ) {
// operating -> recharging
if ( self == trans_op_rech ) {
exitState( operating, self, true );
enterState( recharging, true );
return;
}
// recharging -> operating (re-enters state_1)
if ( self == trans_rech_op ) {
exitState( recharging, self, true );
enterState( operating, true );
return;
}
// recharging -> reset
if ( self == trans_rech_re ) {
exitState( recharging, self, true );
enterState( reset, true );
return;
}
// dead -> reset
if ( self == trans_dead_res ) {
exitState( dead, self, true );
enterState( reset, true );
return;
}
// operating (any inner state) -> reset
if ( self == trans_op_res ) {
exitState( operating, self, true );
enterState( reset, true );
return;
}
// state_1 -> branch; the action only prints "Action" (the msg local is unused).
if ( self == trans_1_b ) {
exitState( state_1, self, true );
{
Integer msg = (Integer) _msg;
System.out.println("Action");
;}
enterState( branch, true );
return;
}
// state_2 -> branch1
if ( self == trans_2_b ) {
exitState( state_2, self, true );
enterState( branch1, true );
return;
}
// state_4 -> branch2
if ( self == trans_4_b ) {
exitState( state_4, self, true );
enterState( branch2, true );
return;
}
// state_4 -> state_3
if ( self == trans_4_3 ) {
exitState( state_4, self, true );
enterState( state_3, true );
return;
}
// state_3 -> state_2 (message 1)
if ( self == trans_3_2 ) {
exitState( state_3, self, true );
enterState( state_2, true );
return;
}
// state_3 -> state_2 (message 0)
if ( self == trans_3_2_not_wait ) {
exitState( state_3, self, true );
enterState( state_2, true );
return;
}
// state_2 -> state_1
if ( self == trans_2_1 ) {
exitState( state_2, self, true );
enterState( state_1, true );
return;
}
// state_1 -> state_1 (external self-loop: exit and re-enter)
if ( self == trans_1_1 ) {
exitState( state_1, self, true );
enterState( state_1, true );
return;
}
super.executeActionOf( self, _msg );
}
// Message triggers: each message transition accepts only an Integer message equal to its
// code. 3: trans_op_rech, trans_rech_op. 4: trans_rech_re, trans_dead_res, trans_op_res.
// 0: trans_1_b, trans_2_b, trans_4_b, trans_3_2_not_wait. 1: trans_4_3, trans_3_2,
// trans_2_1, trans_1_1. No transition here triggers on 2.
@Override
@AnyLogicInternalCodegenAPI
public boolean testMessageOf( TransitionMessage _t, Object _msg ) {
    Integer expectedCode = null;
    if ( _t == trans_op_rech || _t == trans_rech_op ) {
        expectedCode = 3;
    } else if ( _t == trans_rech_re || _t == trans_dead_res || _t == trans_op_res ) {
        expectedCode = 4;
    } else if ( _t == trans_1_b || _t == trans_2_b || _t == trans_4_b || _t == trans_3_2_not_wait ) {
        expectedCode = 0;
    } else if ( _t == trans_4_3 || _t == trans_3_2 || _t == trans_2_1 || _t == trans_1_1 ) {
        expectedCode = 1;
    }
    if ( expectedCode == null ) {
        return super.testMessageOf( _t, _msg );
    }
    // Non-Integer (or null) messages never match; Integer.equals compares values.
    return _msg instanceof Integer && _msg.equals( expectedCode );
}