For a project, I am trying to code up a very simple MLP example, but I noticed that the Flax implementation is about 20 times slower than the pure JAX implementation. What am I doing wrong here?
import time
import jax.numpy as np
from jax import random, jit, vmap, jacfwd
from jax.nn import sigmoid, softplus
import jax
from flax import linen as nn
import numpy as np
from typing import Sequence
def MLP(layers):
    """Build the ``init`` and ``apply`` functions of a fully connected network.

    Args:
        layers: Sequence of layer widths, input width first and output width
            last. ``len(layers) - 1`` affine layers are created.

    Returns:
        A pair ``(init, apply)``:
            * ``init(rng_key)`` returns a list of ``(W, b)`` tuples, one per
              layer, with ``W`` of shape ``(d_in, d_out)`` and ``b`` of shape
              ``(d_out,)``.
            * ``apply(params, inputs)`` runs the forward pass, applying a
              sigmoid after every layer except the last.
    """

    def init(rng_key):
        """Draw standard-normal weights and biases for every layer."""

        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W = random.normal(k1, (d_in, d_out))
            b = random.normal(k2, (d_out,))
            return W, b

        # ``len(layers)`` subkeys are drawn; the first one is discarded and
        # the remaining ``len(layers) - 1`` are used, one per layer.
        _, *keys = random.split(rng_key, len(layers))
        return [
            init_layer(key, d_in, d_out)
            for key, d_in, d_out in zip(keys, layers[:-1], layers[1:])
        ]

    def apply(params, inputs):
        """Return the network output for ``inputs`` given ``params``."""
        # The matrix products use ``@`` instead of ``np.dot``: this file binds
        # ``np`` twice (``jax.numpy`` and then ``numpy``), so ``np.dot`` ends
        # up being NumPy's, which cannot be traced by ``jax.jit`` and
        # converts the JAX operands to host arrays. ``@`` dispatches on the
        # operand type, so the forward pass no longer depends on that alias.
        for W, b in params[:-1]:
            inputs = sigmoid(inputs @ W + b)
        W, b = params[-1]
        return inputs @ W + b

    return init, apply
class FlaxNet(nn.Module):
    """Bias-free MLP built from ``nn.Dense`` layers with sigmoid activations.

    ``features`` lists the output width of each Dense layer. Every layer
    except the final one is followed by a sigmoid. A single-entry
    ``features`` produces two Dense layers of that width, because the first
    and the last layer are both always created.
    """

    # Output width of each Dense layer, in order.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x_in):
        """Run the forward pass on ``x_in`` and return the last layer's output."""
        first_width = self.features[0]
        last_width = self.features[-1]
        # The first layer is always present; the middle widths follow it.
        activated_widths = [first_width, *self.features[1:-1]]

        hidden = x_in
        for width in activated_widths:
            dense = nn.Dense(width, use_bias=False)
            hidden = sigmoid(dense(hidden))

        # The output layer has no activation.
        return nn.Dense(last_width, use_bias=False)(hidden)
# Benchmark: time one forward pass of the hand-written MLP and of the Flax
# module on the same batch of inputs.
#
# JAX dispatches work asynchronously, so each measured call is wrapped in
# ``jax.block_until_ready``; otherwise the clock could be read before the
# result has actually been computed. ``time.perf_counter`` is used because it
# is the monotonic, high-resolution clock intended for measuring intervals.
#
# NOTE(review): neither apply function is wrapped in ``jax.jit`` here, so
# both timings include op-by-op Python dispatch overhead.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
D = np.pi
layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))
inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
# Warm-up call, finished before the timed region starts.
_ = jax.block_until_ready(net_apply(params, inputs))
# Same key and shape as above, so this reproduces the warm-up inputs.
inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
t1 = time.perf_counter()
outputs = jax.block_until_ready(net_apply(params, inputs))
print('TIME JAX ', time.perf_counter() - t1)
#############################################################################
model = FlaxNet(features=layers[1:])
# Use the dedicated init key instead of reusing the key that drew ``inputs``.
params = model.init(init_rng, inputs)
# Warm-up call, finished before the timed region starts.
_ = jax.block_until_ready(model.apply(params, inputs))
t1 = time.perf_counter()
outputs = jax.block_until_ready(model.apply(params, inputs))
print('TIME FLAX ', time.perf_counter() - t1)
Which produces the output:
TIME JAX 0.0033071041107177734
TIME FLAX 0.08791708946228027
You'll just need to omit the additional lines :)
New times: