Implementing numpy.roll on a flattened array

I am trying to generalize my 2d Ising Model (with periodic boundary conditions) simulator to be Nd as a personal project.

For a quick recap of what that is, please refer to the Wikipedia page on the Ising model (LaTeX rendering is not supported on Stack Overflow). However, knowing the details is not strictly necessary for my question.

In any case, to make everything as general as possible, I figured it would be easiest to work with flattened arrays

import numpy as np

# create nd-spin config
Nx, Ny, Nz = 32, 32, 32  # size of each dimension of 3d lattice
spin_config = np.random.choice([-1, 1], (Nx, Ny, Nz)).astype(np.int8)
spin_iter = np.asarray(spin_config.strides) // spin_config.itemsize  # strides in elements (integers)

The nearest neighbors at each site would be, for site (i, j, k)

#  a[(i-1) % Nz, j, k],  a[(i+1) % Nz, j, k],
#  a[i, (j-1) % Ny, k],  a[i, (j+1) % Ny, k],
#  a[i, j, (k-1) % Nx],  a[i, j, (k+1) % Nx]

The problems with the above are

  • It is a lot of typing
  • Not very generalizable for an nd array

One way out of this is by employing numpy.roll

for j in range(spin_config.ndim):
    np.roll(spin_config, -1, axis=j)
    # do stuff
    np.roll(spin_config, 1, axis=j)
    # do stuff
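
As a minimal sketch of what "do stuff" could look like without numba, the loop below sums the 2*ndim periodic nearest neighbors of every site. The function name `neighbor_sum` and the small test lattice are illustrative assumptions, not part of the original simulator.

```python
import numpy as np

def neighbor_sum(spins):
    """Sum of the 2*ndim nearest neighbors of every site (periodic boundaries)."""
    total = np.zeros(spins.shape, dtype=np.int64)
    for axis in range(spins.ndim):
        total += np.roll(spins, -1, axis=axis)  # neighbor at +1 along this axis
        total += np.roll(spins, 1, axis=axis)   # neighbor at -1 along this axis
    return total

# hypothetical small lattice, just to exercise the function
rng = np.random.default_rng(0)
spins = rng.choice([-1, 1], size=(4, 5, 6)).astype(np.int8)
nn = neighbor_sum(spins)
```

This works for any number of dimensions, but it relies on the `axis` argument of np.roll.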

The problem with this is that, if I want to take advantage of the numba jit compiler, it does not support np.roll with an axis argument.

If I flatten my array, then the nearest neighbors can be constructed in the following way

# for a spin s at flat index niter in spin_config.ravel()
s_nn = np.zeros(2*len(spin_iter), dtype=np.intp)  # flat indices of the nearest neighbors (integer dtype for indexing)
for j, sval in enumerate(spin_iter):
    s_nn[2*j: 2*(j+1)] = [(niter-sval) % spin_config.size, (niter+sval) % spin_config.size]
spin_nn = spin_config.ravel()[s_nn]     # sub array containing nearest neighbors

Unfortunately I am not sure how to deal with edge cases where any of my site indices (i, j, k) are Nz-1, Ny-1, or Nx-1 respectively.
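
To make the edge case concrete, here is a small sketch on a 3x4 array whose values equal their own flat indices (an assumption chosen purely for readability). Stepping one stride along the last axis from the last column lands in the next row instead of wrapping within the same row.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)                 # a.ravel()[i] == i
strides = np.asarray(a.strides) // a.itemsize   # strides in elements: (4, 1)

niter = 7                                       # site (1, 3): last column of row 1
naive_right = (niter + strides[1]) % a.size     # lands on site (2, 0), the wrong row
true_right = np.roll(a, -1, axis=1).ravel()[niter]  # periodic neighbor, site (1, 0)
```

Taking the modulo with the total size only wraps correctly along the first (slowest) axis of a C-ordered array; every other axis needs its own wrap.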

Do you guys have a clever way of rolling along an axis in a flattened (raveled) numpy array?

There is 1 answer

firest

I think I have a decent answer. For starters, a flattened nd array has indices associated with the following index formula:

# index = (Ny*Nx)*nz + (Nx)*ny + nx
# spin_config.strides / spin_config.itemsize = (Ny*Nx, Nx, 1)
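
The index formula can be checked against NumPy's own helpers. This is a small sketch that assumes a C-ordered array whose shape is read as (Nz, Ny, Nx); the sizes and the site (nz, ny, nx) are arbitrary illustrative values.

```python
import numpy as np

shape = (3, 4, 5)                               # read as (Nz, Ny, Nx), an assumption
a = np.zeros(shape, dtype=np.int8)
strides = np.asarray(a.strides) // a.itemsize   # strides in elements: (Ny*Nx, Nx, 1)

nz, ny, nx = 2, 1, 3
index = int(np.dot(strides, (nz, ny, nx)))      # (Ny*Nx)*nz + Nx*ny + nx

# NumPy's conversions between multi-indices and flat indices agree with the formula
same_as_numpy = index == np.ravel_multi_index((nz, ny, nx), shape)
back = tuple(int(c) for c in np.unravel_index(index, shape))
```

np.ravel_multi_index and np.unravel_index perform the same conversion in both directions, which is handy for spot checks.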

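Issues at the boundary occur whenever nj = 0 (stepping down) or nj = Nj-1 (stepping up) for j = x, y, z. The coordinate nj can be recovered from the flat index as (index // stride_j) % Nj, so testing whether I am at a boundary is a simple comparison, and at a boundary the neighbor is found by wrapping (Nj-1) strides in the opposite direction. So here is a new implementation

spin_shape = spin_config.shape
s_nn = np.zeros(2*len(spin_iter), dtype=np.intp)  # integer dtype so it can be used as an index
for j, (sval, s_shape) in enumerate(zip(spin_iter, spin_shape)):
    coord = (niter // sval) % s_shape   # coordinate of niter along axis j
    # step by one stride, or wrap by (s_shape-1) strides at a boundary
    shift_down = niter - sval if coord > 0 else niter + sval*(s_shape-1)
    shift_up = niter + sval if coord < s_shape-1 else niter - sval*(s_shape-1)
    s_nn[2*j: 2*(j+1)] = [shift_down, shift_up]
spin_nn = spin_config.ravel()[s_nn]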

Please let me know if I made some kind of horrible mistake here.
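
One way to check for mistakes is to compare a flat-index neighbor lookup against np.roll on a small array. The sketch below is self-contained. The helper name `flat_neighbors` and the test shape are illustrative assumptions, and it recovers each coordinate as (niter // stride) % size before deciding whether to wrap.

```python
import numpy as np

def flat_neighbors(niter, shape, strides):
    """Flat indices of the 2*ndim periodic nearest neighbors of flat index niter."""
    out = np.empty(2 * len(shape), dtype=np.intp)
    for j in range(len(shape)):
        n, s = shape[j], strides[j]
        coord = (niter // s) % n                 # coordinate along axis j
        # step one stride, or wrap by (n - 1) strides at a boundary
        out[2 * j] = niter - s if coord > 0 else niter + s * (n - 1)
        out[2 * j + 1] = niter + s if coord < n - 1 else niter - s * (n - 1)
    return out

shape = (3, 4, 5)
a = np.arange(np.prod(shape)).reshape(shape)     # a.ravel()[i] == i
strides = np.asarray(a.strides) // a.itemsize    # strides in elements

# reference: np.roll(a, 1, axis=j) holds the "down" neighbor, -1 the "up" one
expected = np.stack(
    [np.roll(a, sh, axis=j).ravel() for j in range(a.ndim) for sh in (1, -1)],
    axis=1,
)
mismatches = sum(
    not np.array_equal(flat_neighbors(i, shape, strides), expected[i])
    for i in range(a.size)
)
```

If `mismatches` is zero, the flat lookup agrees with np.roll at every site, boundaries included. The helper uses only plain loops and integer arithmetic, so it may be a reasonable candidate for numba, though that is not tested here.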