Let's say I want to compute an inner product along the last dimension of two matrices:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), shape=(64, 16), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), shape=(64, 16), dtype=jnp.float32)
I can do it with jnp.einsum:
inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)
or manually call jnp.dot in a loop:
inner_prod2 = jnp.zeros((64,64))
for i1 in range(64):
    for i2 in range(64):
        inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))
print(jnp.amax(inner_prod1 - inner_prod2)) # 0.03830552
This is quite a large difference between the two, even though they are mathematically equivalent. What gives?
All operations in floating point accumulate rounding errors, so in general when you express the same operation in two different ways, you should expect the results to not be bitwise-equivalent.
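(As a quick illustration, a small pure-Python sketch: floating-point addition isn't associative, so the same sum grouped differently rounds differently:)

print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False: the two groupings round differently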
The magnitude of the difference you're seeing is larger than is typical for float32 precision; it makes me think you're probably running your code on TPU, where matrix multiplication is done at lower precision by default. You can adjust this using the default_matmul_precision configuration; for example, like this:
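(A sketch of one way to apply it, using the jax.default_matmul_precision context manager; the same setting can also be applied globally with jax.config.update('jax_default_matmul_precision', 'float32').)

with jax.default_matmul_precision('float32'):
    # Within this scope, matmul-like ops (einsum, dot, matmul) accumulate in full float32.
    inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)
    inner_prod2 = jnp.zeros((64, 64))
    for i1 in range(64):
        for i2 in range(64):
            inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))
    print(jnp.amax(inner_prod1 - inner_prod2))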
If you do the computation this way, I suspect you'll see a smaller difference more typical of float32 computations, on the order of 1E-6 or so.
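(As an aside, if you'd rather not change the setting scope-wide, here's a sketch of the per-call alternative; jnp.einsum and jnp.dot also accept a precision argument:)

inner_prod1 = jnp.einsum('i d, j d -> i j', a, b, precision=jax.lax.Precision.HIGHEST)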