Halide JIT vs Generator Differences

264 views Asked by At

While playing around with Halide, I noticed that completely different pseudocode is produced for the same pipeline depending on whether I use the JIT or an ahead-of-time generator. It looks like I'm missing something, so I'd really appreciate any hint. Here is what I did:

A simple 'dilate' pipeline is defined as:

// Builds a 3x3 "dilate" (max filter) pipeline over a random 1280x1024 8-bit
// image, JIT-compiles and runs it, then prints the loop nest and writes the
// lowered statement to "dilate_1_.html".
//
// Because `input` is a concrete Buffer captured directly by the pipeline, its
// size is fixed when the JIT compiles, so the lowered code contains literal
// bounds (e.g. 1279, 1023) rather than input min/extent symbols.
//
// Returns 0 on completion.
int jit_main ()
{
    Target target = get_jit_target_from_environment ();
    const int width = 1280, height = 1024;
    Buffer <uint8_t> input (width, height);

    // Fill the input with pseudo-random bytes (rand() is not seeded here,
    // so the contents are the same on every run).
    for (int y = 0; y < height; y++)
        for (int x = 0; x < width; x++)
            input (x, y) = static_cast<uint8_t> (rand () & 0xff);

    Var x ("x_1"), y ("y_1");

    // Clamp out-of-range reads to the nearest edge pixel.
    Func clamped ("clamped_1");
    clamped = BoundaryConditions::repeat_edge (input);

    // Horizontal pass: max over the 1x3 neighbourhood.
    Func max_x ("max_x_1");
    max_x (x, y) = max (clamped (x - 1, y), clamped (x, y), clamped (x + 1, y));

    // Vertical pass over the horizontal result: max over the 3x3 neighbourhood.
    Func dilate ("dilate_1");
    dilate (x, y) = max (max_x (x, y - 1), max_x (x, y), max_x (x, y + 1));

    // tick() is defined elsewhere; presumably it starts a timer on nullptr and
    // reports elapsed time with the given label — confirm against its definition.
    tick (nullptr);
    Buffer<uint8_t> out = dilate.realize (width, height, target);
    tick ("inline");

    dilate.print_loop_nest ();
    dilate.compile_to_lowered_stmt ("dilate_1_.html", {}, HTML);

    // Flowing off the end of a non-main int function is undefined behaviour.
    return 0;
}

The resulting pseudocode looks as follows (fragment):

    produce dilate_1 {
        let t125 = ((dilate_1.min.1 * dilate_1.stride.1) + dilate_1.min.0)
        for (dilate_1.s0.y_1, dilate_1.min.1, dilate_1.extent.1) {
            let t128 = max(min(dilate_1.s0.y_1, 1024), 1)
            let t126 = max(min(dilate_1.s0.y_1, 1023), 0)
            let t127 = max(min(dilate_1.s0.y_1, 1022), -1)
            let t129 = ((dilate_1.s0.y_1 * dilate_1.stride.1) - t125)
            for (dilate_1.s0.x_1, dilate_1.min.0, dilate_1.extent.0) {
                dilate_1[(dilate_1.s0.x_1 + t129)] = max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t126 * 1280)) + 1)], max(b0[(max(min(dilate_1.s0.x_1, 1279), 0) + (t126 * 1280))], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t126 * 1280)) + -1)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t127 * 1280)) + 1279)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t127 * 1280)) + 1280)], max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t127 * 1280)) + 1281)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t128 * 1280)) + -1281)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t128 * 1280)) + -1280)], b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t128 * 1280)) + -1279)]))))))))
            }
        }
    }

Then I defined a generator:

// Ahead-of-time generator for a 3x3 "dilate" (max filter) over a 2-D 8-bit
// buffer. The input's bounds are only known at run time, so the lowered code
// is expressed in terms of input_0.min/extent rather than literal sizes.
class Dilate0Generator : public Halide::Generator <Dilate0Generator>
{
public:
    // NOTE: declaration order of Input/Output determines the generated
    // function's parameter order, so it must not be rearranged.
    Input<Buffer<uint8_t>>  input_0 {"input_0", 2};
    Output<Buffer<uint8_t>> dilate_0 {"dilate_0", 2};
    Var                     x {"x_0"};
    Var                     y {"y_0"};

    void generate ()
    {
        // Edge-replicating view of the input: out-of-range coordinates read
        // the nearest in-range pixel.
        Func edge {"clamped_0"};
        edge = BoundaryConditions::repeat_edge (input_0);

        // Horizontal 1x3 maximum.
        Func row_max {"max_x_0"};
        row_max (x, y) = max (edge (x - 1, y),
                              edge (x, y),
                              edge (x + 1, y));

        // Vertical 3x1 maximum of the row maxima, giving the full 3x3 result.
        dilate_0 (x, y) = max (row_max (x, y - 1),
                               row_max (x, y),
                               row_max (x, y + 1));

        dilate_0.print_loop_nest ();
    }
};
HALIDE_REGISTER_GENERATOR (Dilate0Generator, dilate_0)

And its pseudocode is completely different (fragment):

    produce dilate_0 {
        let dilate_0.s0.y_0.prologue = min(max((input_0.min.1 + 1), dilate_0.min.1), (dilate_0.extent.1 + dilate_0.min.1))
        let dilate_0.s0.y_0.epilogue$3 = min(max(max((input_0.min.1 + 1), dilate_0.min.1), ((input_0.extent.1 + input_0.min.1) + -1)), (dilate_0.extent.1 + dilate_0.min.1))
        let t166 = (dilate_0.s0.y_0.prologue - dilate_0.min.1)
        let t168 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t170 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        let t167 = (input_0.extent.1 + input_0.min.1)
        let t169 = (input_0.extent.0 + input_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.min.1, t166) {
            let t171 = ((max(min((t167 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t168)
            let t173 = ((max((min((dilate_0.s0.y_0 + 2), t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
            let t174 = ((max((min(dilate_0.s0.y_0, t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
            let t175 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t170)
            for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
                dilate_0[(dilate_0.s0.x_0 + t175)] = (let t132 = max((min((dilate_0.s0.x_0 + 2), t169) + -1), input_0.min.0) in (let t133 = max(min((t169 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t134 = max((min(dilate_0.s0.x_0, t169) + -1), input_0.min.0) in max(input_0[(t132 + t171)], max(input_0[(t133 + t171)], max(input_0[(t134 + t171)], max(input_0[(t134 + t173)], max(input_0[(t133 + t173)], max(input_0[(t132 + t173)], max(input_0[(t134 + t174)], max(input_0[(t133 + t174)], input_0[(t132 + t174)])))))))))))
            }
        }
        let t183 = (dilate_0.extent.0 + dilate_0.min.0)
        let t184 = (input_0.extent.0 + input_0.min.0)
        let t185 = max((input_0.min.0 + 1), dilate_0.min.0)
        let t178 = min(max((t184 + -1), t185), t183)
        let t177 = min(t183, t185)
        let t176 = (dilate_0.s0.y_0.epilogue$3 - dilate_0.s0.y_0.prologue)
        let t179 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t181 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.s0.y_0.prologue, t176) {
            let t189 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t190 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t187 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t191 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t186 = (t177 - dilate_0.min.0)
            for (dilate_0.s0.x_0, dilate_0.min.0, t186) {
                dilate_0[(dilate_0.s0.x_0 + t191)] = (let t140 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t141 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t142 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t140 + t187)], max(input_0[(t141 + t187)], max(input_0[(t142 + t187)], max(input_0[(t142 + t189)], max(input_0[(t141 + t189)], max(input_0[(t140 + t189)], max(input_0[(t142 + t190)], max(input_0[(t141 + t190)], input_0[(t140 + t190)])))))))))))
            }
            let t194 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t195 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t193 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t196 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t192 = (t178 - t177)
            for (dilate_0.s0.x_0, t177, t192) {
                dilate_0[(dilate_0.s0.x_0 + t196)] = max(input_0[((dilate_0.s0.x_0 + t193) + 1)], max(input_0[(dilate_0.s0.x_0 + t193)], max(input_0[((dilate_0.s0.x_0 + t193) + -1)], max(input_0[((dilate_0.s0.x_0 + t194) + -1)], max(input_0[(dilate_0.s0.x_0 + t194)], max(input_0[((dilate_0.s0.x_0 + t194) + 1)], max(input_0[((dilate_0.s0.x_0 + t195) + -1)], max(input_0[(dilate_0.s0.x_0 + t195)], input_0[((dilate_0.s0.x_0 + t195) + 1)]))))))))
            }
            let t200 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t201 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t198 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t202 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t197 = (t183 - t178)
            for (dilate_0.s0.x_0, t178, t197) {
                dilate_0[(dilate_0.s0.x_0 + t202)] = (let t152 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t153 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t154 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t152 + t198)], max(input_0[(t153 + t198)], max(input_0[(t154 + t198)], max(input_0[(t154 + t200)], max(input_0[(t153 + t200)], max(input_0[(t152 + t200)], max(input_0[(t154 + t201)], max(input_0[(t153 + t201)], input_0[(t152 + t201)])))))))))))
            }
        }
        let t203 = ((dilate_0.extent.1 + dilate_0.min.1) - dilate_0.s0.y_0.epilogue$3)
        let t205 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t207 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        let t204 = (input_0.extent.1 + input_0.min.1)
        let t206 = (input_0.extent.0 + input_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.s0.y_0.epilogue$3, t203) {
            let t208 = ((max(min((t204 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t205)
            let t210 = ((max((min((dilate_0.s0.y_0 + 2), t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
            let t211 = ((max((min(dilate_0.s0.y_0, t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
            let t212 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t207)
            for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
                dilate_0[(dilate_0.s0.x_0 + t212)] = (let t161 = max((min((dilate_0.s0.x_0 + 2), t206) + -1), input_0.min.0) in (let t162 = max(min((t206 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t163 = max((min(dilate_0.s0.x_0, t206) + -1), input_0.min.0) in max(input_0[(t161 + t208)], max(input_0[(t162 + t208)], max(input_0[(t163 + t208)], max(input_0[(t163 + t210)], max(input_0[(t162 + t210)], max(input_0[(t161 + t210)], max(input_0[(t163 + t211)], max(input_0[(t162 + t211)], input_0[(t161 + t211)])))))))))))
            }
        }
    }

The generated version runs an order of magnitude faster, which is not surprising given that its pseudocode looks much more optimized. It even runs faster than an existing example.

My beginner question: how come the JIT cannot produce the same representation? Thanks a lot for any answer, idea, or hint.

1

There is 1 answer

1
Andrew Adams On

The difference between the two is that in the JIT case, the size of the input (and thus the location of the boundary condition) is known at compile-time.

However, the generated code should be similar. I think the fact that you don't get five separate cases in the JIT case (the prologue and epilogue loops in y, plus the left-edge, interior, and right-edge loops in x) is a bug in Halide. I have opened an issue on the Halide GitHub repo: https://github.com/halide/Halide/issues/5353

EDIT: Thanks for uncovering a bug! Fixed in https://github.com/halide/Halide/pull/5355