Optax equivalent of flax's optim.Adam.apply_gradient()

I am trying to convert old code (from 2021) written with JAX/Flax to the current API. The code has the following lines, which I need to convert:

# Old flax.optim API (since removed from Flax): the optimizer holds both the params and the step count.
optimizer = flax.optim.Adam(config.lr_init).create(variables)
init_step = state.optimizer.state.step + 1   # step counter lives in optimizer.state
state = flax.jax_utils.replicate(state)
(_, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)   # params live in optimizer.target
new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
new_state = state.replace(optimizer=new_optimizer)

This question may seem shallow, as if I did not try to write the code myself, search on Google, or follow the documentation, but there are very few resources about migrating code from flax.optim to Optax. Even the JAX documentation examples do not use train_state handling. I just want to know how I can use the step count for Adam and translate the above code to optax. Any help or feedback is appreciated. Here is my attempt:

import jax
import jax.numpy as jnp
import flax
import flax.jax_utils
import optax


@flax.struct.dataclass
class TrainState:
    optimizer: optax.GradientTransformation
    params: flax.core.frozen_dict.FrozenDict
    opt_state: optax.OptState


params = jax.random.normal(jax.random.PRNGKey(42), (10,))
optimizer = optax.adam(learning_rate=1e-3)
opt_state_init = optimizer.init(params)
state = TrainState(optimizer=optimizer, params=params, opt_state=opt_state_init)
init_step = state.optimizer.state.step + 1   # errors: the optax optimizer has no .state.step
state = flax.jax_utils.replicate(state)      # errors as well

# dummy loss function
def loss_fn(x):
    return jnp.sum(x**2), {}                 # aux dict so that has_aux=True works

(_, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)   # errors: no apply_gradient in optax
new_state = state.replace(optimizer=new_optimizer)

I was expecting the code to run with these minimal changes, but I get different types of errors in replicate, init_step, and apply_gradient.
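
From reading the optax documentation, my current guess is that the old apply_gradient call is replaced by optimizer.update() followed by optax.apply_updates(), with the step count tracked explicitly in the train state instead of read from optimizer.state.step. The sketch below is what I have pieced together; the step field and the train_step function are my own naming, not from the original code, and I am not sure this is the recommended pattern:

import jax
import jax.numpy as jnp
import flax
import flax.jax_utils
import optax


@flax.struct.dataclass
class TrainState:
    step: int                      # explicit step counter (my own addition)
    params: jnp.ndarray
    opt_state: optax.OptState


# Keep the GradientTransformation outside the state: it is a pair of pure
# functions, so it should not be stored in (or replicated with) the pytree.
optimizer = optax.adam(learning_rate=1e-3)

params = jax.random.normal(jax.random.PRNGKey(42), (10,))
state = TrainState(step=0, params=params, opt_state=optimizer.init(params))
init_step = state.step + 1


def loss_fn(x):
    return jnp.sum(x**2), {}       # dummy loss returning an aux dict


def train_step(state):
    (_, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # optax replacement for apply_gradient: update() then apply_updates()
    updates, new_opt_state = optimizer.update(grad, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    new_state = state.replace(step=state.step + 1,
                              params=new_params,
                              opt_state=new_opt_state)
    return new_state, stats


state = flax.jax_utils.replicate(state)   # now the state contains only arrays/scalars

I believe Adam's own counter also lives somewhere inside opt_state, but keeping an explicit step field in the state seemed simpler. In the real training loop I would then jax.pmap(train_step) over the replicated state, and for the per-step learning_rate that apply_gradient used to take I assume the replacement is passing a schedule to optax.adam (or using optax.inject_hyperparams), but I could not confirm either of those.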

There are 0 answers