DqnAgent.train(): Loss is inf or nan


Problem:

I am trying to train a DQN agent with TF-Agents by following this tutorial, using my own custom PyEnvironment. My training loop is:

for i in range(num_iterations):

  # Collect a few steps and save to the replay buffer.
  time_step, _ = collect_driver.run(time_step)
  # print(time_step)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  # print(i)
  # print(experience)
  train_loss = agent.train(experience)  # <-- error on this line
  train_loss = train_loss.loss
  # print(train_loss)
  # break

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
    train_checkpointer.save(global_step)

I always get the following error:

Loss is inf or nan : Tensor had NaN values
[[{{node CheckNumerics}}]] [Op:__inference_train_3344]

after about two iterations of the training loop.

What I've tried:

I looked into these similar questions:

https://github.com/tensorflow/agents/issues/589

Tensorflow NaN bug?

and I tried most of the methods mentioned there, but none of them worked. I'm sure that the output of my custom environment contains no NaN or inf values.
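For reference, this is roughly the kind of check I mean when I say the environment output is clean. It is a minimal plain-Python sketch, not TF-Agents code: `find_non_finite` is a hypothetical helper name, and it assumes the observations and rewards can be converted to nested lists, tuples, or dicts of numbers (anything with a `.tolist()` method, such as a NumPy array, is converted first).

```python
import math


def find_non_finite(value, path="root"):
    """Return the paths of all NaN/inf leaves in a nested structure of numbers."""
    # Array-like objects (e.g. NumPy arrays) expose .tolist(); convert them first.
    if hasattr(value, "tolist"):
        value = value.tolist()
    if isinstance(value, dict):
        items = [(f"{path}.{key}", child) for key, child in value.items()]
    elif isinstance(value, (list, tuple)):
        items = [(f"{path}[{i}]", child) for i, child in enumerate(value)]
    else:
        # Leaf: a single number.
        return [] if math.isfinite(float(value)) else [path]
    bad = []
    for child_path, child in items:
        bad.extend(find_non_finite(child, child_path))
    return bad


# Hypothetical sample standing in for one batch of environment output.
sample = {
    "observation": [[0.1, 0.2], [float("nan"), 0.4]],
    "reward": (1.0, float("inf")),
}
print(find_non_finite(sample))
# → ['root.observation[1][0]', 'root.reward[1]']
```

An empty list means every value in the structure is finite. Note that this only covers NaN and inf; very large but finite values would pass this check.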

I tried lowering the learning rate significantly, to around 1e-12.
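One possible reason the learning rate makes no difference: if the loss or gradient is already non-finite, scaling it by a small number cannot make it finite again. A plain-Python illustration of that arithmetic (the `sgd_step` helper is hypothetical and is not how TF-Agents applies updates):

```python
import math


def sgd_step(weight, grad, learning_rate):
    """One plain gradient-descent update on a single scalar weight."""
    return weight - learning_rate * grad


for lr in (1e-3, 1e-12):
    after_nan = sgd_step(0.5, float("nan"), lr)
    after_inf = sgd_step(0.5, float("inf"), lr)
    # Both results stay non-finite no matter how small lr is.
    print(lr, math.isnan(after_nan), math.isinf(after_inf))
```

If this is what is happening, the non-finite value would be produced before the optimizer step, so the learning rate would not be the place to look.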

I tried to print the loss values computed inside the TF-Agents library code, but they are hidden behind Tensors, so I could not learn much from them. All I found is that the error is raised by tf.debugging.check_numerics() inside _train() in DqnAgent, and that the value being checked is a Tensor computed by tf_agents.utils.common.aggregate_losses() inside _loss().
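My understanding of why the aggregated value is hard to interpret, sketched in plain Python: assuming the aggregation behaves like a mean over per-example losses (I have not confirmed the exact reduction), a single NaN in the batch makes the whole result NaN, so the per-example values are what I would need to see. The names `mean_loss` and `non_finite_indices` are hypothetical helpers, not library functions.

```python
import math


def mean_loss(per_example_losses):
    """Mean over a batch of scalar losses; any NaN element makes the result NaN."""
    return sum(per_example_losses) / len(per_example_losses)


def non_finite_indices(per_example_losses):
    """Indices of the batch elements whose loss is NaN or inf."""
    return [i for i, v in enumerate(per_example_losses) if not math.isfinite(v)]


# Hypothetical per-example losses for a batch of four transitions.
losses = [0.12, 0.30, float("nan"), 0.07]
print(math.isnan(mean_loss(losses)))  # → True
print(non_finite_indices(losses))     # → [2]
```

Knowing which batch element is responsible would let me look at the corresponding observation, action, and reward in the sampled experience.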

My code is nearly the same as the tutorial mentioned above, except I've changed the environment used and changed the fc_layer_params to (512, 256, 128, 64, 32, 16).

What could I be missing here?
