Exporting NN parameters in lax.scan


I am training a DQN agent using JAX. The DQN framework uses lax.scan to iterate over episodes. To assess the agent's performance after training, I would like to store the parameters (weights and biases) of the neural network after every episode. However, my computer runs out of memory when I use larger networks (4 hidden layers with 256, 128, 64 and 32 nodes).

When I simply store the parameters in the output dictionary of lax.scan, I get the following error: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 655360000000 bytes.

Any ideas on how to bypass this issue?

Thank you in advance.

Answer by jakevdp:

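lax.scan stacks its per-iteration outputs into arrays with a leading axis of length equal to the number of iterations, so returning the parameters keeps a full copy of the network for every episode in memory at once. Rather than attempting to return this array of intermediate parameters, you could log them at runtime using jax.experimental.io_callback or jax.debug.callback.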
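A minimal sketch of the io_callback approach, assuming the parameters are a pytree stored as a list of (W, b) pairs. The layer sizes (8 inputs and 4 outputs around the question's hidden layers), `init_params`, `train_one_episode` (a dummy SGD step standing in for the real DQN update) and the output directory are all hypothetical. Each scan iteration hands the current parameters to a host-side function that writes them to one .npz file per episode, so nothing accumulates in the scan output.

```python
import os
import tempfile
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

# Hypothetical layer sizes: 8 inputs, the four hidden layers from the question, 4 actions.
LAYER_SIZES = [8, 256, 128, 64, 32, 4]
# Hypothetical output directory; one .npz file is written per episode.
SAVE_DIR = tempfile.mkdtemp()


def save_params(episode, params):
    """Host-side callback: write this episode's weights and biases to disk."""
    leaves = [np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(params)]
    np.savez(os.path.join(SAVE_DIR, f"params_{int(episode):06d}.npz"), *leaves)


def init_params(key, sizes):
    """Hypothetical MLP initialization: a list of (W, b) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def train_one_episode(params, key):
    """Placeholder for the real DQN update: one SGD step on a dummy loss."""
    x = jax.random.normal(key, (32, LAYER_SIZES[0]))

    def loss(p):
        h = x
        for W, b in p[:-1]:
            h = jax.nn.relu(h @ W + b)
        W, b = p[-1]
        return jnp.mean((h @ W + b) ** 2)

    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)


def scan_body(carry, episode):
    params, key = carry
    key, subkey = jax.random.split(key)
    params = train_one_episode(params, subkey)
    # Send the parameters to the host instead of stacking them in the scan output;
    # None means the callback returns nothing to the traced computation.
    io_callback(save_params, None, episode, params)
    return (params, key), None


@partial(jax.jit, static_argnums=2)
def train(params, key, num_episodes):
    (params, _), _ = lax.scan(scan_body, (params, key), jnp.arange(num_episodes))
    return params


final_params = train(init_params(jax.random.PRNGKey(0), LAYER_SIZES), jax.random.PRNGKey(1), 5)
jax.effects_barrier()  # wait until all pending host callbacks have run
```

The callback is unordered, so the files are keyed by the episode index rather than relying on call order. Writing to disk instead of appending to a Python list also keeps host RAM usage flat; after training, each file can be read back with np.load.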

This will likely incur some performance penalty, because each callback copies data from the device to the host and runs Python code, but if you only enable it for testing or debugging, that shouldn't be an issue.
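One way to keep the callback limited to debugging runs is a plain Python flag checked at trace time, so the production build contains no callback at all. The sketch below uses jax.debug.callback with a toy array in place of the network parameters; `make_train`, `record`, the `debug` switch and the 0.9 scaling update are hypothetical placeholders.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# Host-side log: episode index -> NumPy copy of the (toy) parameters.
param_log = {}


def record(episode, params):
    """Host-side callback; arguments arrive as concrete arrays."""
    param_log[int(episode)] = np.array(params)


def make_train(debug):
    """Build a jitted training loop; `debug` is a hypothetical Python-level switch."""

    def body(params, episode):
        params = 0.9 * params  # placeholder for the real per-episode update
        if debug:  # plain Python `if`: decided once, at trace time
            jax.debug.callback(record, episode, params)
        return params, None

    @jax.jit
    def train(params, episodes):
        params, _ = lax.scan(body, params, episodes)
        return params

    return train


train_debug = make_train(debug=True)
train_fast = make_train(debug=False)  # no callback is traced, so no host round trips
```

Because the switch is a Python `if` rather than a lax.cond, the disabled version compiles without the callback, so it pays none of the transfer cost.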