Why does jnp.einsum produce a different result from manual looping?

Let's say I want to compute inner products along the last dimension of two matrices, i.e. the dot product of every row of a with every row of b:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), shape=(64,16), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), shape=(64,16), dtype=jnp.float32)

I can do it with jnp.einsum:

inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)

or manually call jnp.dot in a loop:

inner_prod2 = jnp.zeros((64,64))
for i1 in range(64):
  for i2 in range(64):
    inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))
print(jnp.amax(inner_prod1 - inner_prod2)) # 0.03830552
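As an aside, the Python double loop above dispatches 4096 separate operations. A minimal sketch (assuming the same shapes for a and b) of two loop-free ways to get the same pairwise inner products: the einsum contracts the shared last axis, so it is equivalent to a @ b.T, and jax.vmap can vectorize jnp.dot over both row indices. These formulations may lower to the same underlying matmul, so they are not a controlled test of the loop's numerics.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), shape=(64, 16), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), shape=(64, 16), dtype=jnp.float32)

# 'i d, j d -> i j' contracts the shared last axis d: the same as a @ b.T
inner_einsum = jnp.einsum('i d, j d -> i j', a, b)
inner_matmul = a @ b.T

# Vectorized version of the double loop: the inner vmap maps over rows of b
# (with a row of a held fixed), the outer vmap maps over rows of a (b held fixed).
pairwise_dot = jax.vmap(jax.vmap(jnp.dot, in_axes=(None, 0)), in_axes=(0, None))
inner_vmap = pairwise_dot(a, b)

print(inner_vmap.shape)  # (64, 64)
```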

This is quite a large difference between the two, even though they are mathematically equivalent. What gives?

Accepted answer (jakevdp):

Every floating-point operation can introduce rounding error, and those errors accumulate. So when you express the same computation in two different ways (for example, summing the same terms in a different order), you should in general expect the results not to be bitwise identical.
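A minimal illustration of that point, using plain Python floats (float64) so the output is deterministic: floating-point addition is not associative, so merely regrouping a sum, which is what a different implementation of the same contraction typically does, can change the last bits of the result.

```python
# Floating-point addition is not associative: grouping changes the rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left)           # 0.6000000000000001
print(right)          # 0.6
print(left == right)  # False
```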

The magnitude of the difference you're seeing is larger than is typical for float32 precision, which makes me think you're probably running your code on TPU, where float32 matrix multiplications are performed at lower precision (using bfloat16 passes) by default. You can adjust this using the default_matmul_precision configuration, for example like this:

with jax.default_matmul_precision('float32'):
  inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)
  inner_prod2 = jnp.zeros((64,64))
  for i1 in range(64):
    for i2 in range(64):
      inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))

If you do the computation this way, I suspect you'll see a much smaller difference, more typical of float32 computations: on the order of 1e-6.
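Besides the context manager, precision can also be requested per call: jnp.einsum and jnp.dot accept a precision argument, and jax.lax.Precision.HIGHEST asks for full float32 accuracy on backends that would otherwise use a reduced-precision matmul. A minimal sketch, assuming the same a and b as in the question (the loop is replaced by an equivalent vmap purely for speed); jax.default_backend() reports which backend ('cpu', 'gpu' or 'tpu') is in use. The exact size of the remaining difference is backend-dependent.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), shape=(64, 16), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), shape=(64, 16), dtype=jnp.float32)

print(jax.default_backend())  # e.g. 'cpu', 'gpu' or 'tpu'

hi = jax.lax.Precision.HIGHEST

# Request full-precision matmuls for these calls only.
inner_prod1 = jnp.einsum('i d, j d -> i j', a, b, precision=hi)

# Same pairwise dot products as the original loop, vectorized with vmap.
def row_dot(x, y):
    return jnp.dot(x, y, precision=hi)

inner_prod2 = jax.vmap(jax.vmap(row_dot, in_axes=(None, 0)), in_axes=(0, None))(a, b)

max_abs_diff = jnp.max(jnp.abs(inner_prod1 - inner_prod2))
print(max_abs_diff)

# Alternative: change the default for the whole program instead.
# jax.config.update('jax_default_matmul_precision', 'float32')
```

The per-call form keeps the higher-precision (and on TPU, slower) matmul confined to the computations that need it, whereas the context manager or global config changes every matmul in scope.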