Speed up flood-fill algorithm to determine inundated areas in a raster image

I have a raster image that contains elevations (a DEM). I also have some points with associated elevations (something like Point(x, y, h), where h is an elevation). Starting from the location of a given point, I'm trying to find all cells connected to it in 8 directions whose elevation is under the elevation of that point (i.e., under h).

As an example, let's say we have the image M and a Point(2, 2, 2.1). Then I would like to get the result shown in matrix MS.

M = np.array([
    [2, 6, 6, 6, 6],
    [6, 2, 6, 6, 6],
    [6, 6, 2, 6, 6],
    [6, 6, 6, 6, 6],
    [6, 6, 6, 6, 6]
])

MS = np.array([
    [True, False, False, False, False],
    [False, True, False, False, False],
    [False, False, True, False, False],
    [False, False, False, False, False],
    [False, False, False, False, False]
], dtype=bool)

For this I implemented a modified version of the flood-fill algorithm:

def flood_fill(image, x, y, threshold):
    """
    Calculates inundated cells. A cell is inundated if it is connected
    (in 8 directions) to the starting cell and its elevation does not
    exceed the threshold.

    Parameters
    ----------
    image : numpy.ndarray
        array of elevations
    x : int
        starting coordinate x in pixel coordinates (first array index)
    y : int
        starting coordinate y in pixel coordinates (second array index)
    threshold : float
        threshold elevation.

    Returns
    -------
    filled : numpy.ndarray
        array with boolean values where True means the cell is flooded

    """
    
    filled = np.zeros_like(image, dtype=bool)
    toFill = []
    toFill.append((x,y))
    
    while not len(toFill) == 0:
        (x,y) = toFill.pop()
        
        if x < 0 or y < 0 or x >= image.shape[0] or y >= image.shape[1]:
            continue
        
        if filled[x, y] or image[x, y] > threshold:
            continue
        
        filled[x, y] = 1
        toFill.append((x - 1, y))
        toFill.append((x + 1, y))
        toFill.append((x, y - 1))
        toFill.append((x, y + 1))
        toFill.append((x-1, y - 1))
        toFill.append((x-1, y + 1))
        toFill.append((x+1, y - 1))
        toFill.append((x+1, y + 1))
        
    return filled

While it works, it is kind of slow for my purposes. I also tried a recursive version, but it blows up when the image is bigger than a few thousand pixels (stack overflow error). Is there a way to make it faster? Or do you know a Python library that implements something like this?

As an example, for a 5000×3000 image the method takes 15 seconds, which is too much because I have several points.

I uploaded the raster to Google Drive since it is heavy: RASTER_DEM

There are 3 answers

2
Simon Lundberg On BEST ANSWER

Just use the flood function from skimage.

This takes around 0.15 seconds per starting point for a 5000x3000 image on my machine, plus reading and writing the file.

from skimage.segmentation import flood
from skimage.io import imread, imsave

# Load your image; doesn't really matter how as long as it's an ndarray.
# Note that this assumes an rgb image.
im = imread("path")

# Some starting point from which to flood
x, y = 351, 137

# Get the elevation by just grabbing the first color channel at `x, y`
elevation = im[y, x, 0]

# Create a boolean ndarray that's True where the image is equal or less than
# the starting point, and False everywhere else. We only look at the first
# color channel because the image is grayscale, hence the "[:,:,0]".
is_lower = im[:,:,0] <= elevation

# Flood fill using skimage.segmentation.flood
mask = flood(is_lower, (y, x))

# Now we can modify all the pixels in the image based
# on the floodfill mask however we want using the mask
im[mask] = elevation

# Save the image
imsave("path", im)
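
For a single-band DEM where the water level h comes with the point rather than from the raster, the same idea might look like the sketch below. The helper name `flooded_mask` and the (row, col) seed order are assumptions made for illustration, not part of skimage. The mask is cast to uint8 as a precaution, and `connectivity=2` is passed explicitly to request 8-direction neighbors in 2D.

```python
import numpy as np
from skimage.segmentation import flood


def flooded_mask(dem, row, col, h):
    """Cells 8-connected to (row, col) whose elevation is <= h.

    Hypothetical helper for illustration; not an skimage function.
    """
    below = dem <= h
    if not below[row, col]:
        # The seed cell itself is above the water level: nothing floods.
        return np.zeros(dem.shape, dtype=bool)
    # Flood over the binary "below h" image; in 2D, connectivity=2
    # includes the diagonal neighbors.
    return flood(below.astype(np.uint8), (row, col), connectivity=2)


# The example from the question: Point(2, 2, 2.1) on matrix M.
M = np.array([
    [2, 6, 6, 6, 6],
    [6, 2, 6, 6, 6],
    [6, 6, 2, 6, 6],
    [6, 6, 6, 6, 6],
    [6, 6, 6, 6, 6],
])
MS = flooded_mask(M, 2, 2, 2.1)
print(MS)
```

Thresholding first and then flooding the binary image keeps the comparison vectorized in NumPy, so the per-pixel work happens in compiled code.
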
2
Musabbir Arrafi On

Here's a solution that divides the raster into patches, runs your modified flood fill on each patch, and stitches the results back together at the end. It takes around 24-30 seconds on my PC.

import numpy as np
import concurrent.futures
import rasterio
import matplotlib.pyplot as plt
import time

def process_patch(image_patch, threshold, patch_idx, patch_size, image_size):
    # Assuming flood fill starts at (0, 0) in each patch
    filled_patch = flood_fill(image_patch, 0, 0, threshold)  
    # Calculate the position of the patch in the full image
    patch_x, patch_y = patch_idx
    x_start = patch_x * patch_size
    y_start = patch_y * patch_size
    x_end = min(x_start + patch_size, image_size[0])
    y_end = min(y_start + patch_size, image_size[1])

    return filled_patch[:x_end - x_start, :y_end - y_start], patch_idx

def flood_fill(image, x, y, threshold):
    """
    Calculates inundated cells. A cell is inundated if it is connected
    (in 8 directions) to the starting cell and its elevation does not
    exceed the threshold.

    Parameters
    ----------
    image : numpy.ndarray
        array of elevations
    x : int
        starting coordinate x in pixel coordinates (first array index)
    y : int
        starting coordinate y in pixel coordinates (second array index)
    threshold : float
        threshold elevation.

    Returns
    -------
    filled : numpy.ndarray
        array with boolean values where True means the cell is flooded

    """
    
    filled = np.zeros_like(image, dtype=bool)
    toFill = []
    toFill.append((x,y))
    
    while not len(toFill) == 0:
        (x,y) = toFill.pop()
        
        if x < 0 or y < 0 or x >= image.shape[0] or y >= image.shape[1]:
            continue
        
        if filled[x, y] or image[x, y] > threshold:
            continue
        
        filled[x, y] = 1
        toFill.append((x - 1, y))
        toFill.append((x + 1, y))
        toFill.append((x, y - 1))
        toFill.append((x, y + 1))
        toFill.append((x-1, y - 1))
        toFill.append((x-1, y + 1))
        toFill.append((x+1, y - 1))
        toFill.append((x+1, y + 1))
        
    return filled

# Define image parameters
image_path = 'clipped_dem_romboutswervepolder.tif'

# choose your desired threshold
threshold_value = 1.05

# Adjust based on memory constraints
patch_size = 1024  

# Read the dem
with rasterio.open(image_path) as dataset:
    tif_array = dataset.read(1)
    image_size = tif_array.shape
    print("loaded raster-image size: ", image_size)
print("Unique values in input array:", np.unique(tif_array))

# Create an output array
output_array = np.zeros_like(tif_array, dtype=bool)


start = time.time()
# Process patches using ProcessPoolExecutor
with concurrent.futures.ProcessPoolExecutor() as executor:
    futures = []

    num_patches_x = image_size[0] // patch_size + 1
    num_patches_y = image_size[1] // patch_size + 1

    for px in range(num_patches_x):
        for py in range(num_patches_y):
            patch = tif_array[px * patch_size:(px + 1) * patch_size, py * patch_size:(py + 1) * patch_size]
            futures.append(executor.submit(process_patch, patch, threshold_value, (px, py), patch_size, image_size))

    # Retrieve results and place them in the output array
    for future in concurrent.futures.as_completed(futures):
        filled_patch, patch_idx = future.result()
        patch_x, patch_y = patch_idx
        x_start = patch_x * patch_size
        y_start = patch_y * patch_size
        x_end = min(x_start + patch_size, image_size[0])
        y_end = min(y_start + patch_size, image_size[1])
        output_array[x_start:x_end, y_start:y_end] = filled_patch
        
print("time taken: ", time.time() - start)

# Visualize the result
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

axs[0].imshow(tif_array, cmap='gray')
axs[0].set_title('Input Data')

axs[1].imshow(output_array, cmap='gray')
axs[1].set_title('Output Data')
plt.show()

Output:

loaded raster-image size:  (5604, 3996)
Unique values in input array: [-999.           1.09         1.0965841 ...   10.15        10.16
   10.18     ]
time taken:  24.95163631439209

Set the threshold value to suit your DEM to get a more appropriate output, and set the patch size according to your memory constraints.

0
Matt Timmermans On

Your flood fill implementation is not bad. You could make it a little faster by careful optimization (or a lot faster by using a different language)...

But your biggest improvement can be had by combining the flood fills for each point into one:

  1. Process the source points in order of decreasing h
  2. Don't reset the filled array after flooding from each point.

This takes advantage of the fact that you don't need to flood a point if it's already been flooded with a greater or equal threshold value. Each pixel will be processed at most once across all of your source points.
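
The two steps above could be sketched roughly as follows. This is a minimal illustration, not a tuned implementation: the function name `flood_many` and the `(row, col, h)` tuple layout are assumptions, and the result is the union of all flooded areas rather than one mask per point.

```python
import numpy as np

# Offsets of the 8 neighbors of a cell, as (row, col) deltas.
NEIGHBORS = [(-1, -1), (-1, 0), (-1, 1),
             (0, -1),           (0, 1),
             (1, -1),  (1, 0),  (1, 1)]


def flood_many(dem, points):
    """Flood from several (row, col, h) points into one shared mask.

    Hypothetical helper for illustration.
    """
    rows, cols = dem.shape
    filled = np.zeros(dem.shape, dtype=bool)
    # Step 1: process the source points in order of decreasing h.
    for row, col, h in sorted(points, key=lambda p: p[2], reverse=True):
        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if r < 0 or c < 0 or r >= rows or c >= cols:
                continue
            # Step 2: `filled` is never reset, so a cell already flooded
            # at a greater or equal level is skipped here.
            if filled[r, c] or dem[r, c] > h:
                continue
            filled[r, c] = True
            for dr, dc in NEIGHBORS:
                stack.append((r + dr, c + dc))
    return filled


# Two separate basins, each with its own source point.
dem = np.array([
    [1, 9, 3],
    [9, 9, 9],
    [5, 9, 2],
])
union = flood_many(dem, [(0, 0, 1.5), (2, 2, 2.5)])
print(union)
```

Skipping already-filled cells is safe because any cell reachable at a lower level from a flooded cell was also reachable at the earlier, higher level.
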