I am using JAX for reinforcement learning with DQN. For a step in the environment, I am comparing two alternatives for generating the random keys (seeds). The two approaches lead to significantly different results. Why does this happen, and which approach is the proper use of random keys in JAX?
The first one follows the JAX documentation:
rng, step_rng = jax.random.split(rng)
next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)
The second one follows an example from purejaxrl:
rng, _rng = jax.random.split(rng)
_rng, step_rng = jax.random.split(_rng)
next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)
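For reference, here is a minimal, environment-free sketch of the two patterns using only jax.random (the variable names are mine and just mirror the snippets above). As far as I can tell, both produce valid keys, just different ones:

import jax

rng = jax.random.PRNGKey(0)

# Pattern 1: single split, as in the JAX documentation
rng_a, step_rng_a = jax.random.split(rng)

# Pattern 2: double split, as in the purejaxrl example
rng_b, _rng = jax.random.split(rng)
_rng, step_rng_b = jax.random.split(_rng)

# rng_a and rng_b are identical (split is deterministic), but the two
# step keys differ, so anything sampled from them differs as well
print(step_rng_a, step_rng_b)
print(jax.random.uniform(step_rng_a), jax.random.uniform(step_rng_b))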
EDIT: This is a short piece of code from my environment step function, which takes a step in a sub-environment.
state, env_state, _, _, rng = runner
q = self.agent_nn.apply(self.agent_params, state)
action = jnp.argmax(q)
rng, rng_step = jax.random.split(rng)
next_state, next_env_state, reward, terminated, info = self.env.step(rng_step, env_state, action, self.env_params)
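A simplified sketch of how I intend the key to be threaded back into the runner for the next call (the function name _step_subenv and the return packing are only for illustration; the rest of the step logic is omitted):

def _step_subenv(self, runner):
    # Unpack the key carried in the runner, consume one sub-key for env.step,
    # and carry the other half forward for the next call
    state, env_state, _, _, rng = runner
    q = self.agent_nn.apply(self.agent_params, state)
    action = jnp.argmax(q)
    rng, rng_step = jax.random.split(rng)
    next_state, next_env_state, reward, terminated, info = self.env.step(
        rng_step, env_state, action, self.env_params)
    return (next_state, next_env_state, reward, terminated, rng)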
SECOND EDIT: use of the random key within lax.scan:
import jax
import jax.numpy as jnp
from jax import jit, lax
from jax_tqdm import scan_tqdm

def train(rng):
    # Initialise the Q-network with its own sub-key
    rng, network_init_rng = jax.random.split(rng)
    network = q_network(env.action_space(env_params).n)
    init_x = jnp.zeros((1, config["STATE_SIZE"]))
    network_params = network.init(network_init_rng, init_x)
    # TrainState here is a custom train state that also stores target_params
    training = TrainState.create(apply_fn=network.apply,
                                 params=network_params,
                                 target_params=network_params,
                                 tx=tx)

    # Reset the environment with the purejaxrl-style double split
    rng, _rng = jax.random.split(rng)
    _rng, reset_rng = jax.random.split(_rng)
    state, env_state = env.reset(reset_rng, env_params)

    @jit
    @scan_tqdm(config["TOTAL_STEPS"])
    def _run_step(runner, i_step):
        training, env_state, state, rng, buffer_state, i_episode = runner

        # Epsilon-greedy action selection: one sub-key for the random action,
        # one for the exploration threshold
        rng, *_rng = jax.random.split(rng, 3)
        random_q_rng, random_number_rng = _rng
        q_state = network.apply(training.params, state)
        random_number = jax.random.uniform(random_number_rng, minval=0, maxval=1, shape=(1,))
        exploitation = jnp.greater(random_number, config["EPS"])
        # random_action is defined in code not shown here
        action = jnp.where(exploitation, jnp.argmax(q_state, 1), random_action)

        # Environment step with the purejaxrl-style double split
        rng, _rng = jax.random.split(rng)
        _rng, step_rng = jax.random.split(_rng)
        next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)

        # Replay-buffer update, learning step, etc. omitted; the carried rng goes
        # back into the runner, and the scan body returns (carry, output)
        runner = (training, next_env_state, next_state, rng, buffer_state, i_episode)
        return runner, info

    rng, _rng = jax.random.split(rng)
    runner = (training, env_state, state, _rng, buffer_state, 0)
    runner, metrics = lax.scan(_run_step, runner, jnp.arange(config["TOTAL_STEPS"]), config["TOTAL_STEPS"])
    return {"runner": runner}