Train multiple NNs in parallel with JAX


I would like to train n copies of the same neural network model in parallel in JAX, instead of using a loop.

We construct the model as:

class Model(flax.linen.Module):
    ........
    return out


# create a list that contains the NN model n times
models = [Model() for _ in range(n)]

models_list = []

# loop through all the NNs and collect their outputs
for f in models:
    models_list.append(f(x))  # x is the input of the NN

# stack the outputs of the list
y = jnp.stack(models_list, axis=1)

Instead of using the for loop, how can we parallelize the JAX model so that the n networks are optimized in parallel?

I tried using vmap to map over the elements of the models list while broadcasting the input, but I keep getting an error.

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn


class FFN(nn.Module):
    alpha: int = 1

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(features=self.alpha*x.shape[0])(x)
        y = nn.relu(y)
        return jnp.sum(y, axis=-1)


# Define random input tensor
key = random.PRNGKey(0)
batch_size = 16
input_shape = (32,)
x = random.normal(key, (batch_size,) + input_shape)

# Initialize model
model = FFN()

# Apply model to input
params = model.init(key, x)
output = model.apply(params, x)
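
As a side note (an annotation of mine, not part of the original post): params is a nested pytree of arrays, which is what makes the fix sketched at the end of this post possible. With the (16, 32) input and alpha=1 above, the single Dense layer holds a (32, 16) kernel and a (16,) bias:

# inspect the parameter pytree returned by model.init
# e.g. {'params': {'Dense_0': {'bias': (16,), 'kernel': (32, 16)}}}
print(jax.tree_util.tree_map(jnp.shape, params))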


# list of models

models = [FFN() for _ in range(3)]

# loop through the models list; note that every copy reuses the same
# params here, so all n outputs are identical
for net in models:
    net.apply(params, x)
    


# run the models in parallel (the attempt below fails)
def model_apply(model, inputs):
    return model.apply(params, inputs)

out = jax.vmap(model_apply, in_axes=(0, None))(models, x)

ERROR:

ValueError: vmap was requested to map its argument along axis 0, which implies that 
its rank should be at least 1, but is only 0 (its shape is ())
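
Why this fails: jax.vmap can only map over an axis of array (or pytree-of-array) arguments, but models is a plain Python list of FFN objects; roughly speaking, vmap sees each FFN as a rank-0 leaf with no axis 0 to map over, which is exactly what the error reports. A common fix (a sketch of mine, not a confirmed answer; it reuses the FFN class and x defined above, and the variable names are illustrative) is to keep a single module and vmap over n stacked parameter pytrees instead:

import jax
from jax import random

n = 3
model = FFN()

# one PRNG key per copy -> n independent parameter pytrees; vmap
# stacks every leaf along a new leading axis of length n
keys = random.split(random.PRNGKey(1), n)
params_stack = jax.vmap(model.init, in_axes=(0, None))(keys, x)

# map apply over the stacked parameters; in_axes=(0, None) broadcasts
# the same input x to every parameter set
out = jax.vmap(model.apply, in_axes=(0, None))(params_stack, x)
print(out.shape)  # (n, batch_size) == (3, 16)

Because apply is shared, the same pattern extends to training: vmap-ing a single gradient-update step over params_stack optimizes all n networks in parallel.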
