Scheduling a common loop across discrete pipeline Funcs

241 views Asked by At

I have a number of Halide pipelines (lowercase p) which all read the same input image and produce unique outputs. Some share output dimensions; others do not. Every pipeline reads each pixel in the source image exactly once. Which output images are needed may vary at runtime based on user input.

I'm using a Pipeline to compute all of these outputs into a Realization. Is there any way to schedule these disparate Funcs to achieve a single outer loop in the Pipeline?

It appears I can create a wrapper Func that packs these Funcs into a Tuple, but this requires that they all have the same output dimensions.

Am I missing any other options?

Edit: added sample code

//Buffer<> input = Buffer<uint8_t>::make_interleaved(width, height, 4);
//fill buffer with image data

Var x("x"), y("y"), c("c");

Func rgb("rgb");
rgb(x,y,c) = ConciseCasts::u8_sat(input(x,y,c));

// Define a one-dimensional reduction domain over x
RDom r(0, input.width());

Func hist1("hist1");
Func hist2("hist2");

// Histogram buckets start as zero.
hist1(x,y) = 0;
hist2(x,y,c) = 0;

// Make a histogram for every scanline of input
hist1(rgb(r, y, 0), y ) += 1;
hist2(rgb(r, y, c), y, c) += 1;

Func clamp1("clamp1");
clamp1(x,y) = ConciseCasts::u8_sat(hist1(x,y));

Func clamp2("clamp2");
clamp2(x,y,c) = ConciseCasts::u8_sat(hist2(x,y,c));

//use clamp1 as a wrapper
hist1.compute_at(clamp1, y);

//schedule hist2 the same way (but unroll c)
hist2.compute_at(clamp2, y);

clamp2.bound(c,0,3).reorder(c, x, y).unroll(c);

hist2.bound(c,0,3).reorder(c, x, y).unroll(c);
hist2.update(0).reorder(c, r, y).unroll(c);

.bound(x, 0, 256)
.bound(y, 0, input.height());

.bound(x, 0, 256)
.bound(y, 0, input.height());

Pipeline pipe = Pipeline({clamp1, clamp2});

Looking at the lowered statement, I see:

produce clamp1 {
    for (clamp1.s0.y, 0, 2160) {
      allocate hist1[int32 * 256 * 1]
      produce hist1 {
        for (hist1.s0.x, 0, 256) {
          hist1[hist1.s0.x] = 0
        for (hist1.s1.r4$x, 0, 4096) {
          hist1[int32(b0[((hist1.s1.r4$x*4) + (clamp1.s0.y*16384))])] = (hist1[int32(b0[((hist1.s1.r4$x*4) + (clamp1.s0.y*16384))])] + 1)
      for (clamp1.s0.x, 0, 256) {
        clamp1[((clamp1.s0.x + (clamp1.s0.y*clamp1.stride.1)) - (clamp1.min.0 + (clamp1.min.1*clamp1.stride.1)))] = uint8(max(min(hist1[clamp1.s0.x], 255), 0))
      free hist1

  produce clamp2 {
    for (clamp2.s0.y, 0, 2160) {
      allocate hist2[int32 * 256 * 1 * 3]
      produce hist2 {
        for (hist2.s0.x, 0, 256) {
          hist2[hist2.s0.x] = 0
          hist2[(hist2.s0.x + 256)] = 0
          hist2[(hist2.s0.x + 512)] = 0
        for (hist2.s1.r4$x, 0, 4096) {
          hist2[int32(b0[((hist2.s1.r4$x*4) + (clamp2.s0.y*16384))])] = (hist2[int32(b0[((hist2.s1.r4$x*4) + (clamp2.s0.y*16384))])] + 1)
          hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 1)]) + 256)] = (hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 1)]) + 256)] + 1)
          hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 2)]) + 512)] = (hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 2)]) + 512)] + 1)
      for (clamp2.s0.x, 0, 256) {
        clamp2[((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[clamp2.s0.x], 255), 0))
        clamp2[(((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) + clamp2.stride.2) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[(clamp2.s0.x + 256)], 255), 0))
        clamp2[(((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) + (clamp2.stride.2*2)) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[(clamp2.s0.x + 512)], 255), 0))
      free hist2

What I'm hoping to achieve is a lowered statement that looks closer to the following (I assembled it by hand by cutting and pasting from the output above):

produce clamps {
    for (clamp1.s0.y, 0, 2160) {
      allocate hist1[int32 * 256 * 1]
      allocate hist2[int32 * 256 * 1 * 3]
      produce hists {
        for (hist1.s0.x, 0, 256) {
          hist1[hist1.s0.x] = 0
          hist2[hist2.s0.x] = 0
          hist2[(hist2.s0.x + 256)] = 0
          hist2[(hist2.s0.x + 512)] = 0
        for (hist1.s1.r4$x, 0, 4096) {
          hist1[int32(b0[((hist1.s1.r4$x*4) + (clamp1.s0.y*16384))])] = (hist1[int32(b0[((hist1.s1.r4$x*4) + (clamp1.s0.y*16384))])] + 1)
          hist2[int32(b0[((hist2.s1.r4$x*4) + (clamp2.s0.y*16384))])] = (hist2[int32(b0[((hist2.s1.r4$x*4) + (clamp2.s0.y*16384))])] + 1)
          hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 1)]) + 256)] = (hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 1)]) + 256)] + 1)
          hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 2)]) + 512)] = (hist2[(int32(b0[(((hist2.s1.r4$x*4) + (clamp2.s0.y*16384)) + 2)]) + 512)] + 1)
      for (clamp1.s0.x, 0, 256) {
        clamp1[((clamp1.s0.x + (clamp1.s0.y*clamp1.stride.1)) - (clamp1.min.0 + (clamp1.min.1*clamp1.stride.1)))] = uint8(max(min(hist1[clamp1.s0.x], 255), 0))
         clamp2[((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[clamp2.s0.x], 255), 0))
        clamp2[(((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) + clamp2.stride.2) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[(clamp2.s0.x + 256)], 255), 0))
        clamp2[(((clamp2.s0.x + (clamp2.s0.y*clamp2.stride.1)) + (clamp2.stride.2*2)) - ((clamp2.min.0 + (clamp2.min.1*clamp2.stride.1)) + (clamp2.min.2*clamp2.stride.2)))] = uint8(max(min(hist2[(clamp2.s0.x + 512)], 255), 0))
      free hist1
      free hist2

However, if I try to add

// Intended to make clamp2 share clamp1's y loop, giving the single outer loop
// sketched in the hand-assembled lowered statement above.
clamp2.compute_with(clamp1, y); 

I get the following error when JIT-compiling:

Internal error at /Halide/src/ScheduleFunctions.cpp:2228
Condition failed: injector.found_store_level && injector.found_compute_level

There is 1 answer

Zalman Stern On

This might be another use case for compute_with, which has not been merged yet. You can try out the compute_with_directive branch to see whether it meets your needs. Hopefully it will be merged soon.