py_environment 'time_step' doesn't match 'time_step_spec' - but I can't spot the difference

I'm trying to create a custom tf-agents environment for trading. When I validate it by calling utils.validate_py_environment(environment, episodes=1), I get a ValueError saying that the given `time_step` does not match the expected `time_step_spec`.
I've been trying to spot the difference for a while now but I can't seem to find it. Am I missing something?

Observation Spec

self._observation_spec = {
    # market_history x [o, h, l, c, sma_5, ema_13, volume, trades, rsi, macd_latest, macd_signal]
    'visible_market_data': array_spec.ArraySpec(
        shape=(self._market_history, 11), dtype=np.float32, name='visible_market_data'),
    'current_trade': {
        'trade_type': array_spec.BoundedArraySpec(  # 0 = no position, 1 = long, 2 = short
            shape=(), dtype=np.int32, minimum=0, maximum=2, name='trade_type'),
        'open_intervals': array_spec.ArraySpec(
            shape=(), dtype=np.int32, name='open_intervals'),
    },
    'action_mask': array_spec.BoundedArraySpec(  # [do nothing, buy, sell, close]
        shape=(4,), dtype=np.int32, minimum=0, maximum=1, name='action_mask'),
}

How I Maintain My Observation

self._observation = {
    "visible_market_data": self._data[0],
    "current_trade": {
        "trade_type": np.array(0, dtype=np.int32),
        "open_intervals": np.array(0, dtype=np.int32),
    },
    "action_mask": np.array([1, 1, 1, 0], dtype=np.int32),
}

The Error Message

ValueError: Given `time_step`: TimeStep(
{'discount': array(1., dtype=float32),
 'observation': {'action_mask': array([1, 1, 1, 0]),
                 'current_trade': {'open_intervals': array(0),
                                   'trade_type': array(0)},
                 'visible_market_data': array([[0.35, 0.42, 0.33, 0.38, 0.41, 0.53, 0.34, 0.27, 0.31, 0.43, 0.5 ],
       [0.38, 0.39, 0.29, 0.31, 0.39, 0.52, 0.42, 0.36, 0.28, 0.42, 0.47],
       [0.31, 0.33, 0.2 , 0.2 , 0.34, 0.5 , 0.62, 0.52, 0.2 , 0.37, 0.44],
       [0.2 , 0.35, 0.2 , 0.35, 0.32, 0.48, 0.32, 0.35, 0.2 , 0.36, 0.4 ],
       [0.35, 0.41, 0.32, 0.39, 0.33, 0.48, 0.33, 0.32, 0.28, 0.29, 0.36],
       [0.37, 0.42, 0.31, 0.34, 0.32, 0.47, 0.29, 0.3 , 0.21, 0.23, 0.3 ],
       [0.34, 0.48, 0.31, 0.4 , 0.34, 0.46, 0.46, 0.46, 0.29, 0.22, 0.25],
       [0.42, 0.53, 0.41, 0.48, 0.39, 0.46, 0.51, 0.29, 0.43, 0.2 , 0.22],
       [0.48, 0.53, 0.42, 0.44, 0.41, 0.47, 0.35, 0.29, 0.39, 0.22, 0.23],
       [0.44, 0.44, 0.38, 0.44, 0.42, 0.45, 0.3 , 0.27, 0.43, 0.25, 0.24],
       [0.44, 0.44, 0.39, 0.39, 0.43, 0.46, 0.4 , 0.21, 0.37, 0.3 , 0.26],
       [0.39, 0.4 , 0.29, 0.33, 0.42, 0.44, 0.39, 0.37, 0.33, 0.29, 0.25],
       [0.33, 0.5 , 0.3 , 0.5 , 0.42, 0.44, 0.46, 0.33, 0.44, 0.32, 0.24],
       [0.49, 0.51, 0.43, 0.5 , 0.43, 0.45, 0.25, 0.28, 0.48, 0.31, 0.25],
       [0.5 , 0.51, 0.46, 0.49, 0.44, 0.45, 0.2 , 0.22, 0.48, 0.3 , 0.25],
       [0.5 , 0.55, 0.44, 0.45, 0.45, 0.46, 0.34, 0.3 , 0.5 , 0.36, 0.27],
       [0.46, 0.55, 0.43, 0.5 , 0.49, 0.44, 0.49, 0.25, 0.62, 0.37, 0.29],
       [0.48, 0.5 , 0.34, 0.37, 0.46, 0.46, 0.37, 0.42, 0.43, 0.43, 0.34],
       [0.37, 0.45, 0.34, 0.37, 0.44, 0.43, 0.45, 0.3 , 0.39, 0.43, 0.37],
       [0.37, 0.38, 0.31, 0.33, 0.41, 0.44, 0.38, 0.29, 0.4 , 0.5 , 0.41],
       [0.33, 0.35, 0.27, 0.31, 0.38, 0.41, 0.4 , 0.3 , 0.33, 0.47, 0.44],
       [0.27, 0.43, 0.27, 0.38, 0.36, 0.42, 0.3 , 0.3 , 0.32, 0.53, 0.47],
       [0.42, 0.42, 0.33, 0.38, 0.36, 0.41, 0.24, 0.22, 0.35, 0.5 , 0.5 ],
       [0.4 , 0.53, 0.38, 0.43, 0.37, 0.41, 0.31, 0.28, 0.39, 0.53, 0.53],
       [0.5 , 0.5 , 0.43, 0.44, 0.39, 0.41, 0.32, 0.22, 0.45, 0.56, 0.59],
       [0.44, 0.49, 0.42, 0.45, 0.42, 0.41, 0.31, 0.25, 0.52, 0.6 , 0.62],
       [0.49, 0.5 , 0.42, 0.44, 0.43, 0.41, 0.26, 0.26, 0.33, 0.63, 0.65],
       [0.44, 0.47, 0.4 , 0.4 , 0.43, 0.42, 0.21, 0.22, 0.29, 0.66, 0.68],
       [0.41, 0.42, 0.39, 0.4 , 0.43, 0.41, 0.2 , 0.2 , 0.29, 0.7 , 0.7 ],
       [0.4 , 0.45, 0.38, 0.4 , 0.42, 0.41, 0.24, 0.21, 0.34, 0.68, 0.68],
       [0.4 , 0.54, 0.4 , 0.54, 0.44, 0.41, 0.39, 0.3 , 0.46, 0.69, 0.68],
       [0.52, 0.77, 0.52, 0.67, 0.48, 0.43, 0.8 , 0.8 , 0.72, 0.68, 0.68],
       [0.67, 0.74, 0.65, 0.73, 0.55, 0.44, 0.44, 0.54, 0.75, 0.68, 0.69],
       [0.72, 0.73, 0.68, 0.68, 0.6 , 0.47, 0.35, 0.38, 0.74, 0.74, 0.72],
       [0.71, 0.8 , 0.68, 0.78, 0.68, 0.47, 0.5 , 0.43, 0.8 , 0.77, 0.71],
       [0.78, 0.8 , 0.69, 0.73, 0.72, 0.51, 0.37, 0.45, 0.71, 0.8 , 0.7 ]])},
 'reward': array(0., dtype=float32),
 'step_type': array(0)}) 

does not match expected `time_step_spec`: 

TimeStep(
{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
 'observation': {'action_mask': BoundedArraySpec(shape=(4,), dtype=dtype('int32'), name='action_mask', minimum=0, maximum=1),
                 'current_trade': {'open_intervals': ArraySpec(shape=(), dtype=dtype('int32'), name='open_intervals'),
                                   'trade_type': BoundedArraySpec(shape=(), dtype=dtype('int32'), name='trade_type', minimum=0, maximum=2)},
                 'visible_market_data': ArraySpec(shape=(36, 11), dtype=dtype('float32'), name='visible_market_data')},
 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
 'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')})

1 Answer

Answer by Top Snek (best answer)

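I found my issue: I had to re-initialize the visible_market_data NumPy array with dtype=np.float32 explicitly. NumPy creates floating-point arrays as float64 by default, and the spec declares float32. The error message hints at this: the visible_market_data array is printed without a dtype, while discount and reward are printed with dtype=float32.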
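
For anyone hitting the same error, here is a minimal sketch of the fix and of a way to locate this kind of mismatch. It uses NumPy only and does not call tf-agents. The array `market_window` is a hypothetical stand-in for `self._data[0]`, and `find_dtype_mismatches` is an illustrative helper of my own, not part of any library.

```python
import numpy as np

# Hypothetical stand-in for self._data[0]: NumPy builds float arrays as float64 by default.
market_window = np.array([[0.35, 0.42, 0.33], [0.38, 0.39, 0.29]])
print(market_window.dtype)  # float64

# Cast explicitly so the dtype matches the float32 declared in the observation spec.
market_window32 = market_window.astype(np.float32)
print(market_window32.dtype)  # float32


def find_dtype_mismatches(observation, expected_dtypes, path=""):
    """Recursively compare array dtypes in a nested dict against the expected dtypes.

    Illustrative helper (not a tf-agents API). Returns a list of
    (key path, actual dtype, expected dtype) tuples.
    """
    mismatches = []
    for key, expected in expected_dtypes.items():
        here = f"{path}/{key}" if path else key
        value = observation[key]
        if isinstance(expected, dict):
            # Nested observation: descend into the sub-dictionary.
            mismatches.extend(find_dtype_mismatches(value, expected, here))
        else:
            actual = np.asarray(value).dtype
            if actual != np.dtype(expected):
                mismatches.append((here, str(actual), str(np.dtype(expected))))
    return mismatches


# Expected dtypes, mirroring the observation spec from the question.
expected = {
    "visible_market_data": np.float32,
    "current_trade": {"trade_type": np.int32, "open_intervals": np.int32},
    "action_mask": np.int32,
}

observation = {
    "visible_market_data": market_window,  # still float64, so this should be reported
    "current_trade": {
        "trade_type": np.array(0, dtype=np.int32),
        "open_intervals": np.array(0, dtype=np.int32),
    },
    "action_mask": np.array([1, 1, 1, 0], dtype=np.int32),
}

print(find_dtype_mismatches(observation, expected))
```

In the real environment, the equivalent change would be to cast the data once when it is loaded (for example with `.astype(np.float32)`), so every observation slice already has the dtype the spec expects.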