I am trying to convert some old JAX code (from 2021) to the current versions of JAX and Flax. The code has the following lines, which I need to convert:
optimizer = flax.optim.Adam(config.lr_init).create(variables)
init_step = state.optimizer.state.step + 1
state = flax.jax_utils.replicate(state)
(_, stats), grad = (jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target))
new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
new_state = state.replace(optimizer=new_optimizer)
This question may seem shallow, as if I did not try to write the code myself, research on Google, or follow the documentation, but there are very few resources available about converting code from the old flax.optim API to Optax, and even the JAX documentation examples do not use this kind of train-state handling. As far as I remember the old API, optimizer.target held the parameters, optimizer.state.step was the step counter, and apply_gradient returned a new optimizer with updated parameters. I just want to know how I can use the step count for Adam and translate the above code for use with Optax. Any help or feedback is appreciated.
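For the step count specifically, my current guess (from reading the Optax source, so please correct me if this is wrong) is that optax.adam is built as a chain, so its optimizer state is a tuple whose first entry is a ScaleByAdamState with a count field. This is the small sketch I have been experimenting with; everything apart from the standard optax.adam / init calls is my own assumption:

import jax
import optax

params = jax.random.normal(jax.random.PRNGKey(0), (10,))
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# optax.adam is a chain, so opt_state is a tuple; the first entry
# appears to be a ScaleByAdamState carrying the step counter in `count`
init_step = opt_state[0].count + 1

With that partial understanding, here is my attempt at a minimal conversion: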
import flax
import flax.jax_utils
import jax
import jax.numpy as jnp
import optax

@flax.struct.dataclass
class TrainState:
    optimizer: optax.GradientTransformation
    params: flax.core.frozen_dict.FrozenDict
    opt_state: optax.OptState

params = jax.random.normal(jax.random.PRNGKey(42), (10,))
optimizer = optax.adam(learning_rate=1e-3)
opt_state_init = optimizer.init(params)
state = TrainState(optimizer=optimizer, params=params, opt_state=opt_state_init)
init_step = state.optimizer.state.step + 1
state = flax.jax_utils.replicate(state)
# dummy loss function; returns an aux dict since has_aux=True is used below
def loss_fn(x):
    return jnp.sum(x**2), {}
(_, stats), grad = (jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target))
new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
new_state = state.replace(optimizer=new_optimizer)
I was expecting the code to run fine with these minimal changes, but I get different types of errors in replicate, init_step, and apply_gradient.
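For completeness, the closest I have gotten to something that runs end to end is the sketch below, which drops my TrainState wrapper and replicate entirely and applies the Optax update manually (optimizer.update and optax.apply_updates are from the Optax docs; treating this as the replacement for the old apply_gradient is my own assumption):

import jax
import jax.numpy as jnp
import optax

def loss_fn(x):
    # dummy loss returning an aux dict so has_aux=True works
    return jnp.sum(x**2), {}

params = jax.random.normal(jax.random.PRNGKey(42), (10,))
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# gradients are taken with respect to params directly,
# since there is no optimizer.target in Optax
(loss, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)

# manual Optax update in place of the old apply_gradient call
updates, opt_state = optimizer.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)

What I still cannot work out is how to package this into a replicated train state and where the per-step learning rate that I used to pass to apply_gradient fits in, which is what the code above is failing at.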