Optimizing Kadane's Algorithm for numpy

678 views Asked by At

The standard way to find the maximum subarray is Kadane's algorithm. If the input is a large NumPy array, is there anything faster than the native Python implementation?

import timeit

# Benchmark setup, passed to timeit as source text. timeit executes it
# outside the timed region, so building B is not part of the measurement.
setup = '''
import random
import numpy as np

def max_subarray(A):
    max_so_far = max_ending_here = 0
    for x in A:
        max_ending_here = max(0, max_ending_here + x)
        max_so_far      = max(max_so_far, max_ending_here)
    return max_so_far

B = np.random.randint(-100,100,size=100000)
'''

# Only run the (slow) benchmark when executed as a script or notebook cell,
# not when this file is merely imported.
if __name__ == '__main__':
    # repeat(5, 100): five repeats of 100 calls each. Every entry is the TOTAL
    # time for 100 calls, so the printed best-of-5 is ~100x the per-call cost.
    # A single parenthesized argument to print is valid on Python 2 and 3.
    print(min(timeit.Timer('max_subarray(B)', setup=setup).repeat(5, 100)))
1

There is 1 answer

5
Wolph On BEST ANSWER

A quick test with Cython in an IPython notebook. (Because of that I timed the calls manually instead of using timeit, which doesn't appear to work within the %%cython environment :)

Original version:

import numpy as np

# Benchmark input: 100,000 random integers drawn uniformly from [-100, 100).
B = np.random.randint(low=-100, high=100, size=100000)

def max_subarray(A):
    """Return the largest sum of any contiguous run in A (pure-Python Kadane).

    The empty run is allowed, so the result is never below 0; an all-negative
    or empty input yields 0.
    """
    best = running = 0
    for value in A:
        candidate = running + value
        # Restart the run as soon as its sum stops being positive.
        running = candidate if candidate > 0 else 0
        if running > best:
            best = running
    return best

import time

# Time 100 individual calls of the pure-Python version and report the
# fastest, slowest and mean per-call durations in seconds.
# timeit.default_timer is the stdlib's preferred benchmarking clock
# (time.perf_counter on Python 3), which has finer resolution than time.time.
from timeit import default_timer as timer

measurements = np.zeros(100, dtype='float')
for i in range(measurements.size):
    start = timer()
    max_subarray(B)
    measurements[i] = timer() - start

# A single parenthesized argument to print is valid on Python 2 and 3.
print('non-c: %s %s %s' % (measurements.min(), measurements.max(),
                           measurements.mean()))

Cython version:

%%cython

import numpy as np
cimport numpy as np

# Benchmark input: 100,000 random integers drawn uniformly from [-100, 100).
B = np.random.randint(low=-100, high=100, size=100000)

# Element type expected in the input array. np.int was only an alias of the
# builtin int (deprecated in NumPy 1.20, removed in 1.24, where accessing it
# raises AttributeError); np.int_ is the NumPy dtype that alias resolved to.
# np.int_t is cimported numpy's C typedef used for the typed variables.
# NOTE(review): confirm np.int_t still matches np.int_ on NumPy 2.x, where
# np.int_ became the pointer-sized integer.
DTYPE = np.int_
ctypedef np.int_t DTYPE_t

cdef DTYPE_t c_max_subarray(np.ndarray A) except? -1:
    """Return the maximum sum of any contiguous run of A (Kadane's algorithm).

    The empty run is allowed, so the result is never negative. A must be a
    one-dimensional array whose dtype equals DTYPE.

    `except? -1` lets exceptions (the assert below, or a buffer mismatch)
    propagate to the caller; without it, Cython versions before 3.0 print
    and swallow exceptions raised in cdef functions returning a C type.
    """
    # Type checking for safety.
    assert A.dtype == DTYPE

    # Index a typed memoryview rather than writing `for x in A`: iterating an
    # untyped ndarray goes through Python's iterator protocol and boxes every
    # element, discarding most of the benefit of compiling the loop.
    cdef DTYPE_t[:] view = A
    cdef Py_ssize_t i
    cdef DTYPE_t max_so_far = 0, max_ending_here = 0, x
    for i in range(view.shape[0]):
        x = view[i]
        max_ending_here = max(0, max_ending_here + x)
        max_so_far = max(max_so_far, max_ending_here)
    return max_so_far

import time

# Time 100 individual calls of the Cython version and report the fastest,
# slowest and mean per-call durations in seconds.
# timeit.default_timer is the stdlib's preferred benchmarking clock
# (time.perf_counter on Python 3), which has finer resolution than time.time.
from timeit import default_timer as timer

measurements = np.zeros(100, dtype='float')
for i in range(measurements.size):
    start = timer()
    c_max_subarray(B)
    measurements[i] = timer() - start

# The Python-2 print statement does not compile under Cython 3's default
# language_level ("3str"); a single parenthesized argument works under both.
print('Cython: %s %s %s' % (measurements.min(), measurements.max(),
                            measurements.mean()))

Results (seconds per call: min, max, mean):

  • Cython: 0.00420188903809 0.00658392906189 0.00474049091339
  • non-c: 0.0485298633575 0.0644249916077 0.0522959709167

Definitely a notable speedup (roughly 11x) without too much effort :)