What is the best method for handling memory when compiling an accumulation in JAX, such as `jax.lax.scan`, where a full buffer is excessive?
The following is a geometric progression example. The temptation is to recognise that the accumulation depends only on an input size and to implement it accordingly:
```python
import jax
import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size, x0, a):
    scan_fun = lambda carry, i: (a * carry,) * 2
    xn, x = lax.scan(scan_fun, x0, None, length=size - 1)
    return jnp.concatenate((x0[None], x))

jax.config.update("jax_enable_x64", True)
size = jnp.array(2**26, dtype='u8')
x0, a = jnp.array([1.0, 1.0 + 1.0e-08], dtype='f8')
jax.jit(calc_gp_size)(size, x0, a)
```
However, attempting to use `jax.jit` will predictably fail with a `ConcretizationTypeError`, because the scan length depends on the traced value of `size`.
The correct way is to pass an argument where the buffer already exists, so the scan length is fixed by the array's shape:
```python
def calc_gp_array(array, x0, a):
    scan_fun = lambda carry, i: (a * carry,) * 2
    xn, x = lax.scan(scan_fun, x0, array)
    return jnp.concatenate((x0[None], x))

array = jnp.arange(1, 2**26, dtype='u8')
x0, a = jnp.array([1.0, 1.0 + 1.0e-08], dtype='f8')
jax.jit(calc_gp_array)(array, x0, a)
```
My concern is that a lot of allocated memory is not being utilised (or is it?). Is there a more memory-efficient approach to this example, or is the allocated memory being used somehow?
EDIT: Incorporating the comments of @jakevdp, treating the function as main (a single call, including compilation and excluding caching), profiling gave the following results:
```python
%memit jax.jit(calc_gp_size, static_argnums=0)(size, x0, a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB
%memit jax.jit(calc_gp_array)(jnp.arange(1, size, dtype='u8'), x0, a).block_until_ready()
# peak memory: 7850.83 MiB, increment: 1240.22 MiB
%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
# peak memory: 8150.05 MiB, increment: 1539.70 MiB
```
More granular results would require line-profiling the jitted code (I'm not sure how this could be done).
Sequentially initialising the array and then calling `jax.jit` appears to save memory:
```python
%memit array = jnp.arange(1, size, dtype='u8'); jax.jit(calc_gp_array)(array, x0, a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB
%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB
```
The first version will work if you mark the `size` argument as static and pass a hashable value (e.g. a plain Python `int` rather than a JAX array):
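A minimal sketch of this, reusing `calc_gp_size` from the question with a small illustrative `size` (the original used `2**26`):

```python
import jax
import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size, x0, a):
    scan_fun = lambda carry, i: (a * carry,) * 2
    xn, x = lax.scan(scan_fun, x0, None, length=size - 1)
    return jnp.concatenate((x0[None], x))

# size is a plain Python int, which is hashable, so it can be static:
x0, a = jnp.array([1.0, 2.0])
gp = jax.jit(calc_gp_size, static_argnums=0)(8, x0, a)
```

With `static_argnums=0`, the scan length is a compile-time constant, so no dummy input array needs to be allocated; note that each distinct `size` triggers a fresh compilation.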
I think this may be slightly more memory efficient than pre-allocating the array as in your second example, though it would be worth benchmarking if that's important.
Also, if you're doing this sort of operation on GPU, you'll likely find built-in accumulations like `jnp.cumprod` to be much more performant. I believe this is more or less equivalent to your scan-based function:
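A sketch of the `jnp.cumprod` form, mirroring the expression used in the question's edit (the wrapper name `calc_gp_cumprod` is illustrative, not from the original answer):

```python
import jax.numpy as jnp

def calc_gp_cumprod(size, x0, a):
    # Fill with the common ratio, overwrite the first element with x0,
    # then take the running product along the array.
    return jnp.cumprod(jnp.full(size, a).at[0].set(x0))

x0, a = jnp.array([1.0, 2.0])
gp = calc_gp_cumprod(8, x0, a)
```

This produces the same geometric progression as the scan-based versions, and maps onto a single parallel prefix-product kernel rather than a sequential loop.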