Depending on a boolean flag, I want to either call or skip the following function, which operates on a Flax Linen module:
```python
def true_fn(module, carry, inputs):
    carry, output = flax.linen.scan(
        function_to_scan_over, variable_broadcast='params', split_rngs={"params": False}
    )(module, carry, inputs)
    return carry
```
I have tried to use `flax.linen.cond` as follows:

```python
carry = flax.linen.cond(pred, true_fn, false_fn, module, carry, inputs)
```

where `false_fn` is the identity function with respect to `carry`:

```python
def false_fn(module, carry, inputs):
    return carry
```
But when I do this, I get the following error:

```
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have same type structure
```
The output type structure is the same in both branches: each one returns `carry`. Based on the `flax.linen.cond` documentation, I assume the error occurs because `true_fn` creates variables that `false_fn` does not create (this is a problem for `flax.linen.cond` but not for `jax.lax.cond`). In case it matters, the module I pass to `true_fn` is always called later in my code anyway.
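To illustrate why the message talks about the output structure even though both functions visibly return `carry`: `jax.lax.cond` raises this `TypeError` whenever the two branches return pytrees with different structures. My understanding (an assumption, not verified against the Flax source) is that `flax.linen.cond` returns each branch's variable collections together with its output, so a variable created in only one branch would trigger this same mismatch. A minimal pure-JAX sketch, where the hypothetical `extra` dict stands in for those hidden variables:

```python
import jax
import jax.numpy as jnp


def true_fun(x):
    # Visible output plus a dict standing in for variables created here.
    return x, {"extra": jnp.zeros(())}


def false_fun(x):
    # Same visible output, but nothing "created" in this branch.
    return x, {}


def false_fun_matched(x):
    # Same visible output and the same hidden structure as true_fun.
    return x, {"extra": jnp.ones(())}


try:
    jax.lax.cond(True, true_fun, false_fun, jnp.ones(3))
    mismatch_error = None
except TypeError as err:
    # e.g. "true_fun and false_fun output must have same type structure, ..."
    mismatch_error = err

# When both branches return identical structures, lax.cond accepts them.
out, extra = jax.lax.cond(True, true_fun, false_fun_matched, jnp.ones(3))
print(type(mismatch_error).__name__)
```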
Any advice on what I should do here?
Edit: minimal working example (MWE) added:
```python
from flax import linen as nn
import jax


class MLP(nn.Module):
    dim: int

    def setup(self):
        self.dense = nn.Dense(self.dim)

    def __call__(self, x):
        return self.dense(x)


class Dummy(nn.Module):
    dim: int

    def setup(self):
        self.mlp = MLP(self.dim)

    def __call__(self, x):
        def true_fn(module, x):
            return module(x)

        def false_fn(module, x):
            return x

        y = nn.cond(True, true_fn, false_fn, self.mlp, x)
        return y + self.mlp(x)


dim_in = 3
dim_out = dim_in
dummy = Dummy(dim_in)
init_vars = dummy.init(x=jax.numpy.ones((dim_in,)), rngs={'params': jax.random.PRNGKey(0)})
dummy.apply(init_vars, x=jax.numpy.ones((dim_in,)))
```
I am using flax 0.7.2 and jax/jaxlib 0.4.13.