Flax much slower than pure Jax for neural nentworks?

591 views Asked by At

for a project I am trying to code up a very simple MLP example, but I noticed that the implementation in flax is about 20 times slower than the pure jax implementation. What am I doing wrong here?

import time
import jax.numpy as np
from jax import random, jit, vmap, jacfwd
from jax.nn import sigmoid, softplus
import jax
from flax import linen as nn  
import numpy as np
from typing import Sequence

def MLP(layers):
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W = random.normal(k1, (d_in, d_out))
            b = random.normal(k2, (d_out,))
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params

    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = sigmoid(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply


class FlaxNet(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x_in):
        x = nn.Dense(self.features[0], use_bias=False)(x_in)
        x = sigmoid(x)

        for feat in self.features[1:-1]:
            x = nn.Dense(feat, use_bias=False)(x)
            x = sigmoid(x)
        x = nn.Dense(self.features[-1], use_bias=False)(x)

        return x


rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
D = np.pi

layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))

inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
_ = net_apply(params, inputs)

inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
t1 = time.time()
outputs = net_apply(params, inputs)
print('TIME JAX ', time.time()-t1)

#############################################################################

model = FlaxNet(features=[64, 64, 64, 32, 4])
params = model.init(rng, inputs)

_ = model.apply(params, inputs)
t1 = time.time()
outputs = model.apply(params, inputs)
print('TIME FLAX ', time.time()-t1)

Which produces the output:

TIME JAX  0.0033071041107177734
TIME FLAX  0.08791708946228027
1

There are 1 answers

0
Mani Shemi On

You'll just need to omit additions lines :)

import time
import jax.numpy as jnp
from jax import random
from jax.nn import sigmoid
import jax
from flax import linen as nn
from typing import Sequence


def MLP(layers):
    def init(rng_key):
       def init_layer(key, d_in, d_out):
          k1, k2 = random.split(key)
          W = random.normal(k1, (d_in, d_out))
          b = random.normal(k2, (d_out,))
          return W, b

       key, *keys = random.split(rng_key, len(layers))
       params = list(map(init_layer, keys, layers[:-1], layers[1:]))
       return params

    def apply(params, inputs):
        for W, b in params[:-1]:
          outputs = jnp.dot(inputs, W) + b
          inputs = sigmoid(outputs)
        W, b = params[-1]
        outputs = jnp.dot(inputs, W) + b
        return outputs
     

    return init, apply

class FlaxNet(nn.Module):
     features: Sequence[int]

     @nn.compact
     def __call__(self, x_in):
        x = nn.Dense(self.features[0], use_bias=False)(x_in)
        x = sigmoid(x)

        for feat in self.features[1:-1]:
           x = nn.Dense(feat, use_bias=False)(x)
           x = sigmoid(x)
        x = nn.Dense(self.features[-1], use_bias=False)(x)

        return x

D = jnp.pi
layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))

inputs = jax.random.uniform(random.PRNGKey(1), minval=-D, maxval=D, 
shape= (128, 1))

t1 = time.time()
outputs = net_apply(params, inputs)
print('TIME JAX ', time.time() - t1)

model = FlaxNet(features=[64, 64, 64, 32, 4])
params = model.init(random.PRNGKey(0), inputs)

t1 = time.time()
_ = model.apply(params, inputs)
print('TIME FLAX ', time.time() - t1)

New times:

TIME JAX  0.854097843170166
TIME FLAX  0.04825115203857422