While playing around with Halide, I noticed that completely different pseudocode is produced for the same pipeline depending on whether I use JIT or an ahead-of-time generated function. It looks like I'm missing something, so I'd very much appreciate any hint. Here is what I did:
A simple 'dilate' pipeline is defined as:
// Builds a 3x3 dilate pipeline over a random 1280x1024 8-bit image, runs it
// through Halide's JIT, and then dumps the loop nest and the lowered
// statement (as HTML) for inspection.
//
// Returns 0 on completion.
int jit_main ()
{
    Target target = get_jit_target_from_environment ();
    const int width = 1280, height = 1024;

    // Fill the input image with pseudo-random bytes.
    Buffer<uint8_t> input (width, height);
    for (int y = 0; y < height; y++)
        for (int x = 0; x < width; x++)
            input (x, y) = rand () & 0xff;

    Var x ("x_1"), y ("y_1");

    // Wrap the input so the x-1 / x+1 / y-1 / y+1 taps below stay in bounds.
    Func clamped ("clamped_1");
    clamped = BoundaryConditions::repeat_edge (input);

    // Horizontal pass: maximum over the three neighbours in x.
    Func max_x ("max_x_1");
    max_x (x, y) = max (clamped (x - 1, y), clamped (x, y), clamped (x + 1, y));

    // Vertical pass: maximum over the three neighbouring rows of max_x.
    Func dilate ("dilate_1");
    dilate (x, y) = max (max_x (x, y - 1), max_x (x, y), max_x (x, y + 1));

    // tick() is defined elsewhere; presumably tick (nullptr) starts a timer
    // and tick (label) reports the elapsed time under that label -- confirm.
    // NOTE(review): the first realize() call presumably also includes JIT
    // compilation time -- confirm before comparing against AOT timings.
    tick (nullptr);
    Buffer<uint8_t> out = dilate.realize (width, height, target);
    tick ("inline");

    dilate.print_loop_nest ();
    dilate.compile_to_lowered_stmt ("dilate_1_.html", {}, HTML);

    // The function is declared to return int; flowing off the end without a
    // return statement is undefined behaviour for any function except main.
    return 0;
}
The resulting pseudocode looks as follows (fragment):
produce dilate_1 {
let t125 = ((dilate_1.min.1 * dilate_1.stride.1) + dilate_1.min.0)
for (dilate_1.s0.y_1, dilate_1.min.1, dilate_1.extent.1) {
let t128 = max(min(dilate_1.s0.y_1, 1024), 1)
let t126 = max(min(dilate_1.s0.y_1, 1023), 0)
let t127 = max(min(dilate_1.s0.y_1, 1022), -1)
let t129 = ((dilate_1.s0.y_1 * dilate_1.stride.1) - t125)
for (dilate_1.s0.x_1, dilate_1.min.0, dilate_1.extent.0) {
dilate_1[(dilate_1.s0.x_1 + t129)] = max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t126 * 1280)) + 1)], max(b0[(max(min(dilate_1.s0.x_1, 1279), 0) + (t126 * 1280))], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t126 * 1280)) + -1)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t127 * 1280)) + 1279)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t127 * 1280)) + 1280)], max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t127 * 1280)) + 1281)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t128 * 1280)) + -1281)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t128 * 1280)) + -1280)], b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t128 * 1280)) + -1279)]))))))))
}
}
}
Then I defined a generator:
// Halide generator for a 3x3 dilation: a horizontal 3-tap maximum followed by
// a vertical 3-tap maximum, reading the input through a repeat_edge boundary
// condition.
class Dilate0Generator : public Halide::Generator<Dilate0Generator> {
public:
    // 2-dimensional 8-bit source image.
    Input<Buffer<uint8_t>> input_0{"input_0", 2};
    // 2-dimensional 8-bit dilated result.
    Output<Buffer<uint8_t>> dilate_0{"dilate_0", 2};

    // Pure variables indexing the pipeline.
    Var x{"x_0"};
    Var y{"y_0"};

    // Defines the pipeline and prints its loop nest.
    void generate() {
        // Input wrapped so that the +/-1 taps below never read out of bounds.
        Func edge_clamped("clamped_0");
        edge_clamped = BoundaryConditions::repeat_edge(input_0);

        // Horizontal pass: maximum of the left, centre and right pixels.
        Func row_max("max_x_0");
        row_max(x, y) = max(edge_clamped(x - 1, y),
                            edge_clamped(x, y),
                            edge_clamped(x + 1, y));

        // Vertical pass: maximum of the rows above, at and below y.
        dilate_0(x, y) = max(row_max(x, y - 1),
                             row_max(x, y),
                             row_max(x, y + 1));

        dilate_0.print_loop_nest();
    }
};

HALIDE_REGISTER_GENERATOR (Dilate0Generator, dilate_0)
And its pseudocode is completely different (fragment):
produce dilate_0 {
let dilate_0.s0.y_0.prologue = min(max((input_0.min.1 + 1), dilate_0.min.1), (dilate_0.extent.1 + dilate_0.min.1))
let dilate_0.s0.y_0.epilogue$3 = min(max(max((input_0.min.1 + 1), dilate_0.min.1), ((input_0.extent.1 + input_0.min.1) + -1)), (dilate_0.extent.1 + dilate_0.min.1))
let t166 = (dilate_0.s0.y_0.prologue - dilate_0.min.1)
let t168 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
let t170 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
let t167 = (input_0.extent.1 + input_0.min.1)
let t169 = (input_0.extent.0 + input_0.min.0)
for (dilate_0.s0.y_0, dilate_0.min.1, t166) {
let t171 = ((max(min((t167 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t168)
let t173 = ((max((min((dilate_0.s0.y_0 + 2), t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
let t174 = ((max((min(dilate_0.s0.y_0, t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
let t175 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t170)
for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
dilate_0[(dilate_0.s0.x_0 + t175)] = (let t132 = max((min((dilate_0.s0.x_0 + 2), t169) + -1), input_0.min.0) in (let t133 = max(min((t169 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t134 = max((min(dilate_0.s0.x_0, t169) + -1), input_0.min.0) in max(input_0[(t132 + t171)], max(input_0[(t133 + t171)], max(input_0[(t134 + t171)], max(input_0[(t134 + t173)], max(input_0[(t133 + t173)], max(input_0[(t132 + t173)], max(input_0[(t134 + t174)], max(input_0[(t133 + t174)], input_0[(t132 + t174)])))))))))))
}
}
let t183 = (dilate_0.extent.0 + dilate_0.min.0)
let t184 = (input_0.extent.0 + input_0.min.0)
let t185 = max((input_0.min.0 + 1), dilate_0.min.0)
let t178 = min(max((t184 + -1), t185), t183)
let t177 = min(t183, t185)
let t176 = (dilate_0.s0.y_0.epilogue$3 - dilate_0.s0.y_0.prologue)
let t179 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
let t181 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
for (dilate_0.s0.y_0, dilate_0.s0.y_0.prologue, t176) {
let t189 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
let t190 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
let t187 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
let t191 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
let t186 = (t177 - dilate_0.min.0)
for (dilate_0.s0.x_0, dilate_0.min.0, t186) {
dilate_0[(dilate_0.s0.x_0 + t191)] = (let t140 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t141 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t142 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t140 + t187)], max(input_0[(t141 + t187)], max(input_0[(t142 + t187)], max(input_0[(t142 + t189)], max(input_0[(t141 + t189)], max(input_0[(t140 + t189)], max(input_0[(t142 + t190)], max(input_0[(t141 + t190)], input_0[(t140 + t190)])))))))))))
}
let t194 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
let t195 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
let t193 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
let t196 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
let t192 = (t178 - t177)
for (dilate_0.s0.x_0, t177, t192) {
dilate_0[(dilate_0.s0.x_0 + t196)] = max(input_0[((dilate_0.s0.x_0 + t193) + 1)], max(input_0[(dilate_0.s0.x_0 + t193)], max(input_0[((dilate_0.s0.x_0 + t193) + -1)], max(input_0[((dilate_0.s0.x_0 + t194) + -1)], max(input_0[(dilate_0.s0.x_0 + t194)], max(input_0[((dilate_0.s0.x_0 + t194) + 1)], max(input_0[((dilate_0.s0.x_0 + t195) + -1)], max(input_0[(dilate_0.s0.x_0 + t195)], input_0[((dilate_0.s0.x_0 + t195) + 1)]))))))))
}
let t200 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
let t201 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
let t198 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
let t202 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
let t197 = (t183 - t178)
for (dilate_0.s0.x_0, t178, t197) {
dilate_0[(dilate_0.s0.x_0 + t202)] = (let t152 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t153 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t154 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t152 + t198)], max(input_0[(t153 + t198)], max(input_0[(t154 + t198)], max(input_0[(t154 + t200)], max(input_0[(t153 + t200)], max(input_0[(t152 + t200)], max(input_0[(t154 + t201)], max(input_0[(t153 + t201)], input_0[(t152 + t201)])))))))))))
}
}
let t203 = ((dilate_0.extent.1 + dilate_0.min.1) - dilate_0.s0.y_0.epilogue$3)
let t205 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
let t207 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
let t204 = (input_0.extent.1 + input_0.min.1)
let t206 = (input_0.extent.0 + input_0.min.0)
for (dilate_0.s0.y_0, dilate_0.s0.y_0.epilogue$3, t203) {
let t208 = ((max(min((t204 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t205)
let t210 = ((max((min((dilate_0.s0.y_0 + 2), t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
let t211 = ((max((min(dilate_0.s0.y_0, t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
let t212 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t207)
for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
dilate_0[(dilate_0.s0.x_0 + t212)] = (let t161 = max((min((dilate_0.s0.x_0 + 2), t206) + -1), input_0.min.0) in (let t162 = max(min((t206 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t163 = max((min(dilate_0.s0.x_0, t206) + -1), input_0.min.0) in max(input_0[(t161 + t208)], max(input_0[(t162 + t208)], max(input_0[(t163 + t208)], max(input_0[(t163 + t210)], max(input_0[(t162 + t210)], max(input_0[(t161 + t210)], max(input_0[(t163 + t211)], max(input_0[(t162 + t211)], input_0[(t161 + t211)])))))))))))
}
}
}
The generated version runs an order of magnitude faster, which is not surprising given that its pseudocode looks a lot more optimized. It even runs faster than an existing example.
My noob question is: how come the JIT cannot create the same representation? Thanks a lot for any answer/idea/help/hint...
The difference between the two is that in the JIT case, the size of the input (and thus the location of the boundary condition) is known at compile-time.
However, the generated code should be similar. I think the fact that you don't get five separate cases in the JIT case is a bug in Halide. I have opened an issue on the Halide GitHub repo: https://github.com/halide/Halide/issues/5353
EDIT: Thanks for uncovering a bug! Fixed in https://github.com/halide/Halide/pull/5355