I am training a DQN agent using JAX. The DQN framework is based on lax.scan for iterating over episodes. In order to assess the performance of the agent after training, I would like to store the parameters (weights & biases) of the neural network per episode. However, it appears that my computer's memory is not sufficient when I use larger networks (4 hidden layers with 256, 128, 64, 32 nodes).
When I try to simply store the parameters in the output dictionary of lax.scan, I get the following error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 655360000000 bytes.
Any ideas on how to bypass this issue?
Thank you in advance.
 
                        
Rather than attempting to return the array of intermediate parameters, you could log them at runtime using jax.experimental.io_callback or jax.debug.callback. This would likely incur some performance penalties, but if you add it for testing/debugging purposes only, that shouldn't be an issue.
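As a rough sketch of the callback approach (not your actual training loop): the update rule, the parameter shapes, and the names episode_step, save_to_host and saved_params below are all placeholders I made up for illustration. The idea is that the scan body hands each episode's parameters to a host-side function, and scan itself only stacks a small scalar per episode.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# Host-side storage (hypothetical); in practice you might write to disk here instead.
saved_params = []

def save_to_host(episode, params):
    # Runs on the host, outside the compiled program.
    # Unordered callbacks may arrive out of order, so keep the episode index.
    host_copy = jax.tree_util.tree_map(np.asarray, params)
    saved_params.append((int(episode), host_copy))

def episode_step(params, episode):
    # Stand-in for the real DQN update (assumption: any pytree -> pytree update).
    new_params = jax.tree_util.tree_map(lambda p: p * 0.9, params)
    # Ship the parameters to the host instead of returning them from scan.
    jax.debug.callback(save_to_host, episode, new_params)
    # Only a small per-episode scalar is stacked by scan.
    return new_params, jnp.sum(new_params["w"])

init_params = {"w": jnp.ones((4, 3)), "b": jnp.zeros((3,))}
final_params, metrics = lax.scan(episode_step, init_params, jnp.arange(5))

# Wait until all pending callbacks have run before reading saved_params.
jax.effects_barrier()
```

With this layout the device only ever holds one copy of the parameters; the per-episode history lives in host memory (or on disk, if save_to_host writes files). If the callbacks must run strictly in episode order, jax.debug.callback accepts ordered=True, at some additional cost.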