Discrete Difference on Sharded JAX Arrays


Is it possible to speed up discrete-difference calculations on JAX arrays that are sharded across CPU cores?

The following is an attempt to follow the JAX documentation on automatic parallelism. It uses ahead-of-time (AOT) compilation of what is essentially a single JAX NumPy call for differently sharded arrays. Several device meshes are used to test partitioning across and along the difference direction. The measured run times show no benefit from sharding this way: the (1, 8) layout matches the unsharded time, and the (8, 1) layout is about 20 times slower.

import os

# Simulate 8 CPU devices; this must be set before JAX is imported
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax as jx
import jax.numpy as jnp
import jax.experimental.mesh_utils as jxm
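import jax.sharding as jsh

# 'f8' (float64) arrays require 64-bit mode; without it JAX truncates them to float32
jx.config.update("jax_enable_x64", True)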

def calc_fd_kernel(x):
    # Calculate 1st-order fd along the first axis
    return jnp.diff(
        x, 1, axis=0, prepend=jnp.zeros((1, *x.shape[1:]))
    )

def make_fd(shape, shardings):
    # Compiled fd kernel factory
    return jx.jit(
        calc_fd_kernel,
        in_shardings=shardings,
        out_shardings=shardings,
    ).lower(
        jx.ShapeDtypeStruct(shape, jnp.dtype('f8'))
    ).compile()

# Create 2D array to partition
n = 2**12
shape = (n,n,)

x = jx.random.normal(jx.random.PRNGKey(0), shape, dtype='f8')

shardings_test = {
    (1, 1,) : jsh.PositionalSharding(jxm.create_device_mesh((1,), devices=jx.devices("cpu")[:1])).reshape(1, 1),
    (8, 1,) : jsh.PositionalSharding(jxm.create_device_mesh((8,), devices=jx.devices("cpu")[:8])).reshape(8, 1),
    (1, 8,) : jsh.PositionalSharding(jxm.create_device_mesh((8,), devices=jx.devices("cpu")[:8])).reshape(1, 8),
}

x_test = {
    mesh : jx.device_put(x, shardings)
    for mesh, shardings in shardings_test.items()
}

calc_fd_test = {
    mesh : make_fd(shape, shardings)
    for mesh, shardings in shardings_test.items()
}

# %timeit is IPython/Jupyter magic; run this loop in a notebook
for x_mesh, calc_fd_mesh in zip(x_test.values(), calc_fd_test.values()):
    %timeit calc_fd_mesh(x_mesh).block_until_ready()

# (1, 1)
# 48.9 ms ± 414 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# (8, 1)
# 977 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# (1, 8)
# 48.3 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
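Sharding along axis 0, the difference axis, means every shard except the first needs the last row of the preceding shard, so the (8, 1) layout cannot be computed shard by shard without communication. The sketch below makes that boundary ("halo") exchange explicit with `shard_map` and `jax.lax.ppermute` on a 1D mesh. It is an illustration under assumptions, not a verified fix: the axis name `"rows"` and the helpers `fd_local`, `fd_halo` and `fd_reference` are hypothetical, the array is a small 64x16 stand-in, and the import fallback assumes `shard_map` lives either at `jax.shard_map` (newer releases) or `jax.experimental.shard_map` (older ones).

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # public location in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

N_DEV = 8
# 1D mesh over 8 simulated CPU devices; "rows" is a hypothetical axis name
mesh = Mesh(np.array(jax.devices("cpu")[:N_DEV]), ("rows",))


def fd_reference(x):
    # Same computation as calc_fd_kernel in the question
    return jnp.diff(x, 1, axis=0, prepend=jnp.zeros((1, *x.shape[1:]), x.dtype))


def fd_local(x_blk):
    # Each device holds a contiguous block of rows. Send the block's last
    # row to the next device; device 0 receives nothing, so ppermute gives
    # it zeros, which matches prepend=0 for the first global row.
    halo = jax.lax.ppermute(
        x_blk[-1:],
        axis_name="rows",
        perm=[(i, i + 1) for i in range(N_DEV - 1)],
    )
    return jnp.diff(x_blk, 1, axis=0, prepend=halo)


fd_halo = jax.jit(
    shard_map(fd_local, mesh=mesh, in_specs=P("rows", None), out_specs=P("rows", None))
)

# Usage: the number of rows must divide evenly across the 8 devices
x = jax.random.normal(jax.random.PRNGKey(0), (64, 16))
x_rows = jax.device_put(x, NamedSharding(mesh, P("rows", None)))
y = fd_halo(x_rows)
```

Whether this beats the automatically partitioned (8, 1) version has not been measured here. The (1, 8) layout needs no exchange at all, because a difference along axis 0 never mixes columns, so each device can work on its own columns independently.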

Answer by Kavy Gupta:

In JAX, sharding is a mechanism to distribute computation across multiple devices (such as CPU cores exposed as host devices, GPUs, or TPUs) by splitting an array into smaller chunks called shards. A sharded array is created by placing data with `jax.device_put` together with a `jax.sharding.Sharding` object, as the question does. When you run jitted operations on sharded arrays, the compiler partitions the computation across the devices and inserts whatever communication that layout requires.
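As a sketch of the `device_put` route, the snippet below rebuilds the question's three layouts with `Mesh` and `NamedSharding` (swapped in for `PositionalSharding`, which recent JAX documentation no longer uses) and records the shape of the block each device holds. The mesh axis names `"rows"`/`"cols"` and the helper `make_sharding` are hypothetical, and the 64x64 array is a small stand-in for the question's 4096x4096 one.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices("cpu")


def make_sharding(rows, cols):
    # 2D device mesh of shape (rows, cols); array axis 0 is split over
    # mesh axis "rows" and array axis 1 over mesh axis "cols"
    mesh = Mesh(np.array(devices[: rows * cols]).reshape(rows, cols), ("rows", "cols"))
    return NamedSharding(mesh, P("rows", "cols"))


x = jnp.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)

shard_shapes = {}
for layout in [(1, 1), (8, 1), (1, 8)]:
    x_sharded = jax.device_put(x, make_sharding(*layout))
    # Each addressable shard is the block of x stored on one device
    shard_shapes[layout] = {s.data.shape for s in x_sharded.addressable_shards}
    print(layout, shard_shapes[layout])
```

This should report one (64, 64) block for (1, 1), eight (8, 64) row blocks for (8, 1), and eight (64, 8) column blocks for (1, 8). In the (8, 1) case a difference along axis 0 crosses shard boundaries; in the (1, 8) case every column needed by a device is already local.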