How can I apply a function to each channel of a three-channel image numpy array without using a for-loop?

Suppose I have a three-channel image whose NumPy array has shape (H, W, 3), and I have written a function that operates on a 2D NumPy array. How can I apply the function to each channel of the image without a for-loop, which is slow?

I'm looking for something similar to apply in pandas. I've tried np.apply_over_axes and np.apply_along_axis, but neither gives the correct result.

In the code below, I want to add Gaussian noise separately to each channel of the image.

import numpy as np

# train_data is my dataset; train_data[0][0] is the first image
test_img = np.array(train_data[0][0])
test_img.shape
# (198, 180, 3)

def add_gaussian_noise(img_, amplitude=1.0, mean=0.0, variance=1.0):
    img_ = img_.astype('float32')
    # print(img_.shape)

    # assertion failed
    # assert img_.shape == (test_img.shape[0], test_img.shape[1])
    # np.random.normal takes the standard deviation as its scale, so convert the variance
    img_ = img_ + amplitude * np.random.normal(mean, np.sqrt(variance), img_.shape)
    img_ = np.clip(img_, 0.0, 255.0)
    return img_

test_img = np.apply_along_axis(func1d=add_gaussian_noise, axis=2, arr=test_img)

print(test_img)

I expected img_.shape == (test_img.shape[0], test_img.shape[1]) inside the function, but it turns out that img_.shape == (3,).

In this case the optimization won't make a big difference, but I'm looking for a more general solution, so that in other cases (e.g. applying the same function to hundreds of grayscale pictures) the code runs more efficiently than a for-loop.

There is 1 answer

David On

What you're looking for is broadcasting, which lets you run an operation on multiple slices of an array at once. Without more information it's impossible to say whether this will work in your specific case, but it is the usual way to do this sort of thing in NumPy.
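For the noise example in the question, a minimal sketch of this idea might look like the following. It assumes the function uses only elementwise operations, so it can take the whole (H, W, 3) array in one call; the random test image and the seeded `rng` generator are stand-ins, not the asker's data. Note also that np.apply_along_axis hands the function 1-D slices taken along the given axis, which is why the question's function saw shape (3,) with axis=2.

```python
import numpy as np

# Hypothetical seeded generator, used only to make the sketch reproducible.
rng = np.random.default_rng(0)

def add_gaussian_noise(img, amplitude=1.0, mean=0.0, std=1.0):
    """Add independent Gaussian noise to every element of img, whatever its shape."""
    img = img.astype(np.float32)
    # One noise sample per element, so each channel gets its own noise.
    noise = rng.normal(mean, std, size=img.shape)
    return np.clip(img + amplitude * noise, 0.0, 255.0)

# Stand-in for the asker's image: random pixel values with the same shape.
test_img = rng.integers(0, 256, size=(198, 180, 3), dtype=np.uint8)

# A single call covers all three channels; no loop over channels is needed.
noisy = add_gaussian_noise(test_img)
```

Because the noise array has the same shape as the image, every pixel in every channel receives an independent sample, which matches what a per-channel loop would produce.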

Many NumPy functions support broadcasting (you'll see references to this in the documentation). Such functions typically operate on the trailing axes and broadcast over the leading ones, so you may need to move your channel axis to the front of the array, for example with np.moveaxis or np.transpose.
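As an illustration of moving the channel axis, here is a small sketch. The 2-D FFT is my own choice of example (it is not from the question): np.fft.fft2 transforms the last two axes by default, so it processes every leading slice independently. The random image is a stand-in.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((198, 180, 3))        # stand-in (H, W, C) image

chan_first = np.moveaxis(img, -1, 0)   # (C, H, W): channels become the leading axis
spectra = np.fft.fft2(chan_first)      # 2-D FFT over the last two axes, once per channel
result = np.moveaxis(spectra, 0, -1)   # back to (H, W, C)
```

The same pattern should apply to any function that works on the last two axes and treats leading axes as a batch, including a stack of hundreds of grayscale images with shape (N, H, W).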

Edit: As mentioned in the comments, for a small number of repetitions, looping in Python is a perfectly reasonable approach. But I assume you're looking for a more general case.
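For completeness, a sketch of the plain looping approach for a function that really does need a 2D input. The helper name `apply_per_channel` and the mean-subtraction operation are hypothetical, chosen only for illustration.

```python
import numpy as np

def apply_per_channel(func2d, img):
    """Apply func2d to each (H, W) channel of an (H, W, C) array and restack."""
    channels = [func2d(img[..., c]) for c in range(img.shape[-1])]
    return np.stack(channels, axis=-1)

img = np.arange(24, dtype=np.float64).reshape(2, 4, 3)

# Hypothetical per-channel operation: subtract each channel's own mean.
out = apply_per_channel(lambda ch: ch - ch.mean(), img)
```

With only three channels the loop runs three times, so its overhead is negligible next to the work done inside the function.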