How to use a Triton server "ensemble model" with 1:N input/output to create patches from a large image?


I am trying to feed a very large image into Triton server. I need to divide the input image into patches and feed the patches one by one into a TensorFlow model. The image size varies, so the number of patches N differs from call to call.

I think a Triton ensemble model that calls the following steps would do the job:

  1. A python model (pre-process) to create the patches
  2. The segmentation model
  3. Finally another python model (post-process) to merge the output patches into a big output mask

However, this would require a config.pbtxt file with 1:N and N:1 relations: the ensemble scheduler would need to call the second step N times and the third step once with the aggregated output.

Is this possible, or do I need to use some other technique?
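To make the 1:N and N:1 shapes concrete, the three steps can be sketched in plain NumPy, outside Triton. This illustrates only the data flow, not an ensemble configuration. It assumes non-overlapping patches, an image whose height and width are already multiples of the patch size, and a hypothetical segment_batch function standing in for the TensorFlow model. Carrying all N patches as the leading (batch) dimension of a single [N, P, P, C] array is one way to keep a variable N inside one tensor instead of N separate calls.

```python
import numpy as np

def preprocess(image, patch):
    """Step 1 (1:N): split an H x W x C image into N non-overlapping patches.

    Assumes H and W are multiples of `patch`. Returns (patches, grid), where
    patches has shape [N, patch, patch, C] and grid = (rows, cols).
    """
    h, w, c = image.shape
    rows, cols = h // patch, w // patch
    patches = (image.reshape(rows, patch, cols, patch, c)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(rows * cols, patch, patch, c))
    return patches, (rows, cols)

def segment_batch(patches):
    """Step 2: hypothetical stand-in for the segmentation model.

    Thresholds the per-pixel mean intensity into a 1-channel mask.
    """
    return (patches.mean(axis=-1, keepdims=True) > 127).astype(np.uint8)

def postprocess(masks, grid):
    """Step 3 (N:1): merge [N, P, P, C] patches back into one [H, W, C] array."""
    rows, cols = grid
    _, p, _, c = masks.shape
    return (masks.reshape(rows, cols, p, p, c)
                 .transpose(0, 2, 1, 3, 4)
                 .reshape(rows * p, cols * p, c))

image = np.random.default_rng(0).integers(0, 256, size=(64, 96, 3), dtype=np.uint8)
patches, grid = preprocess(image, patch=32)  # N = 2 * 3 = 6 patches
mask = postprocess(segment_batch(patches), grid)
print(patches.shape, mask.shape)  # (6, 32, 32, 3) (64, 96, 1)
```

Because the stand-in model works per pixel, the merged mask equals what it would produce on the whole image at once, which is a convenient sanity check for the split/merge logic.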

Answer by Innat

Disclaimer

The answer below isn't an actual solution to the question above; I misunderstood the query. I'm leaving it here in case future readers find it useful.


Input

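import cv2
import matplotlib.pyplot as plt
import numpy as np

# cv2.imread returns pixels in BGR order; convert to RGB for display and inference
input_img = cv2.cvtColor(cv2.imread('/content/2.jpeg'), cv2.COLOR_BGR2RGB)
print(input_img.shape)  # (719, 640, 3)
plt.imshow(input_img)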

Slice and Stitch

The following functionality is adapted from the idealo/image-super-resolution repository (see the reference in the code). More details and discussion can be found here. Compared with the original code, the necessary functions are brought together in a single class (ImageSliceRejoin).

# ref: https://github.com/idealo/image-super-resolution
import numpy as np

class ImageSliceRejoin:
    def pad_patch(self, image_patch, padding_size, channel_last=True):
        """ Pads image_patch with padding_size edge values. """
        if channel_last:
            return np.pad(
                image_patch,
                ((padding_size, padding_size), 
                (padding_size, padding_size), (0, 0)),
                'edge',
            )
        else:
            return np.pad(
                image_patch,
                ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
                'edge',
            )

    # function to split the image into patches        
    def split_image_into_overlapping_patches(self, image_array, patch_size, padding_size=2):
        """ Splits the image into partially overlapping patches.
        The patches overlap by padding_size pixels.
        Pads the image twice:
            - first to have a size multiple of the patch size,
            - then to have equal padding at the borders.
        Args:
            image_array: numpy array of the input image.
            patch_size: size of the patches from the original image (without padding).
            padding_size: size of the overlapping area.
        """
        xmax, ymax, _ = image_array.shape
        x_remainder = xmax % patch_size
        y_remainder = ymax % patch_size
        
        # the modulo avoids extending by a full patch_size when the remainder is already 0
        x_extend = (patch_size - x_remainder) % patch_size
        y_extend = (patch_size - y_remainder) % patch_size
        
        # make sure the image is divisible into regular patches
        extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
        
        # add padding around the image to simplify computations
        padded_image = self.pad_patch(extended_image, padding_size, channel_last=True)
        
        xmax, ymax, _ = padded_image.shape
        patches = []
        
        x_lefts = range(padding_size, xmax - padding_size, patch_size)
        y_tops = range(padding_size, ymax - padding_size, patch_size)
        
        for x in x_lefts:
            for y in y_tops:
                x_left = x - padding_size
                y_top = y - padding_size
                x_right = x + patch_size + padding_size
                y_bottom = y + patch_size + padding_size
                patch = padded_image[x_left:x_right, y_top:y_bottom, :]
                patches.append(patch)
        
        return np.array(patches), padded_image.shape

    # join the patches back together
    def stich_together(self, patches, padded_image_shape, target_shape, padding_size=4):
        """ Reconstruct the image from overlapping patches.
        After scaling, shapes and padding should be scaled too.
        Args:
            patches: patches obtained with split_image_into_overlapping_patches
            padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
            target_shape: shape of the final image
            padding_size: size of the overlapping area.
        """
        xmax, ymax, _ = padded_image_shape

        # unpad patches (skipped when padding_size is 0, since a [0:-0] slice is empty)
        if padding_size > 0:
            patches = patches[:, padding_size:-padding_size, padding_size:-padding_size, :]

        patch_size = patches.shape[1]
        n_patches_per_row = ymax // patch_size
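        # allocate the output with the patches' channel count (3 for RGB, e.g. 1 for a mask)
        complete_image = np.zeros((xmax, ymax, patches.shape[-1]))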

        row = -1
        col = 0
        for i in range(len(patches)):
            if i % n_patches_per_row == 0:
                row += 1
                col = 0
            complete_image[
            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
            ] = patches[i]
            col += 1
        return complete_image[0: target_shape[0], 0: target_shape[1], :]

Initiate Slicing

import numpy as np 

isr = ImageSliceRejoin()
padding_size = 1

patches, p_shape = isr.split_image_into_overlapping_patches(
    input_img, 
    patch_size=220, 
    padding_size=padding_size
)

patches.shape, p_shape, input_img.shape
# ((12, 222, 222, 3), (882, 662, 3), (719, 640, 3))
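These shapes follow from the padding arithmetic in split_image_into_overlapping_patches: each side is first extended to the next multiple of patch_size, then padded by padding_size on both ends. The helper below (padded_grid is a hypothetical name, not part of the original code) repeats just that arithmetic, which is also one way to know the variable patch count N for a given input size before slicing.

```python
def padded_grid(height, width, patch_size, padding_size):
    """Mirror the shape arithmetic of split_image_into_overlapping_patches.

    Returns (rows, cols, padded_shape); the number of patches is rows * cols.
    """
    # extend each side to the next multiple of patch_size (0 if already divisible)
    h_ext = (patch_size - height % patch_size) % patch_size
    w_ext = (patch_size - width % patch_size) % patch_size
    rows = (height + h_ext) // patch_size
    cols = (width + w_ext) // patch_size
    # then add padding_size on both ends of each side
    padded = (height + h_ext + 2 * padding_size, width + w_ext + 2 * padding_size)
    return rows, cols, padded

print(padded_grid(719, 640, patch_size=220, padding_size=1))  # (4, 3, (882, 662))
```

For the 719 x 640 input this gives 4 x 3 = 12 patches and a padded shape of (882, 662), matching the output above.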

Verify

# lay the patches out on a roughly square grid (plt.subplot needs integer sizes)
n = int(np.ceil(np.sqrt(patches.shape[0])))
plt.figure(figsize=(20, 20))

for i in range(patches.shape[0]):
    plt.subplot(n, n, i + 1)
    plt.imshow(patches[i].astype("uint8"))
    plt.axis("off")


Inference

For demonstration, I'm using the RDN model from the ISR (Image Super-Resolution) package as a stand-in for the segmentation model.

# import model
from ISR.models import RDN
model = RDN(weights='psnr-small')

# number of patches passed to the model per call;
# here, batch_size < len(patches)
batch_size = 2

for i in range(0, len(patches), batch_size):
    # get the next batch of patches, scaled to [0, 1] as ISR's own
    # predict() does before calling the Keras model
    batch = patches[i: i + batch_size] / 255.0

    # run the Keras model, then map the output back to the [0, 255] range
    batch = np.clip(model.model.predict(batch), 0, 1) * 255

    # collect the output patches
    if i == 0:
        collect = batch
    else:
        collect = np.append(collect, batch, axis=0)

Now collect holds the model output for every patch, in the same order as patches.

patches.shape, collect.shape
# ((12, 222, 222, 3), (12, 444, 444, 3))
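Growing collect with np.append copies the accumulated array on every iteration. A sketch of the same batching loop that gathers the batch outputs in a list and concatenates once at the end is below; fake_upscale is a hypothetical stand-in (2x nearest-neighbour upsampling with np.repeat) for model.model.predict, used only so the example runs without ISR.

```python
import numpy as np

def batched_predict(predict_fn, patches, batch_size):
    """Run predict_fn on consecutive slices of `patches` and join the results."""
    outputs = []
    for i in range(0, len(patches), batch_size):
        # the last slice may hold fewer than batch_size patches
        outputs.append(predict_fn(patches[i: i + batch_size]))
    # one concatenation instead of one np.append per batch
    return np.concatenate(outputs, axis=0)

def fake_upscale(batch):
    """Hypothetical stand-in model: 2x nearest-neighbour upsampling of [B, H, W, C]."""
    return np.repeat(np.repeat(batch, 2, axis=1), 2, axis=2)

patches = np.zeros((5, 4, 4, 3))
collect = batched_predict(fake_upscale, patches, batch_size=2)
print(patches.shape, collect.shape)  # (5, 4, 4, 3) (5, 8, 8, 3)
```

With a real model, predict_fn would wrap the normalisation and model call from the loop above.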

Rejoin Patches

scale = 2  # RDN 'psnr-small' upscales by a factor of 2 (222 -> 444 above)
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
scaled_image_shape = tuple(np.multiply(input_img.shape[0:2], scale)) + (3,)

sr_img = isr.stich_together(
    collect,
    padded_image_shape=padded_size_scaled,
    target_shape=scaled_image_shape,
    padding_size=padding_size * scale,
)
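The row/column loop in stich_together places the unpadded patches in row-major order. The same reassembly can be written with a reshape and transpose. The sketch below is an alternative, not the ISR library's code, and it assumes the number of patch rows and columns is known (4 rows and 3 columns in the example above). The demo cuts a small image into overlapping patches the same way split_image_into_overlapping_patches does and checks that stitching recovers the original.

```python
import numpy as np

def stitch_vectorized(patches, rows, cols, padding_size, target_shape):
    """Reassemble row-major overlapping patches with reshape/transpose instead of a loop."""
    if padding_size > 0:
        # drop the overlapping border of each patch, as stich_together does
        patches = patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
    n, ph, pw, c = patches.shape
    assert n == rows * cols, "expected rows * cols patches"
    # (rows*cols, ph, pw, c) -> (rows, cols, ph, pw, c) -> (rows, ph, cols, pw, c)
    grid = patches.reshape(rows, cols, ph, pw, c).transpose(0, 2, 1, 3, 4)
    image = grid.reshape(rows * ph, cols * pw, c)
    # crop away the extension added to reach a multiple of the patch size
    return image[: target_shape[0], : target_shape[1], :]

# demo: patch_size=3, padding_size=1 on a 5 x 7 image
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(5, 7, 3))
p, pad = 3, 1
ext = np.pad(img, ((0, (-img.shape[0]) % p), (0, (-img.shape[1]) % p), (0, 0)), "edge")
padded = np.pad(ext, ((pad, pad), (pad, pad), (0, 0)), "edge")
rows, cols = ext.shape[0] // p, ext.shape[1] // p
patches = np.array([
    padded[r * p: (r + 1) * p + 2 * pad, c * p: (c + 1) * p + 2 * pad]
    for r in range(rows) for c in range(cols)
])
restored = stitch_vectorized(patches, rows, cols, pad, img.shape)
print(np.array_equal(restored, img))  # True
```

For upscaled model outputs, padding_size and target_shape would be multiplied by scale, as in the stich_together call above.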

Verify

print(input_img.shape, sr_img.shape)
# (719, 640, 3) (1438, 1280, 3)

fig, ax = plt.subplots(1,2)
fig.set_size_inches(18.5, 10.5)
ax[0].imshow(input_img)
ax[1].imshow(sr_img.astype('uint8'))
