Halide hangs during Normalized Cross Correlation

I'm trying to implement normalized cross correlation in Halide.

The code below builds, and Halide JIT compilation doesn't throw any errors. However, Halide seems to hang after JIT compilation. No matter how many trace_* calls I put on different Funcs, only one trace ever prints (on Func output):

Begin realization normxcorr.0(0, 2028, 0, 2028)
Produce normxcorr.0(0, 2028, 0, 2028)

Any advice at all would be helpful.

This algorithm is meant to be equivalent to CV_TM_CCOEFF_NORMED in OpenCV, and normxcorr2 in MATLAB:

void normxcorr( Halide::ImageParam input,
                Halide::ImageParam kernel,
                Halide::Param<pixel_t> kernel_mean,
                Halide::Param<pixel_t> kernel_var,
                Halide::Func& output )
{
    Halide::Var x, y;
    Halide::RDom rk( kernel );

    // reduction domain for cumulative sums
    Halide::RDom ri( 1, input.width() - kernel.width() - 1, 
                     1, input.height() - kernel.height() - 1 );

    Halide::Func input_32( "input32" ),
             bounded_input( "bounded_input"),
             kernel_32( "kernel32" ),
             knorm( "knorm" ),
             conv( "conv" ),
             normxcorr( "normxcorr_internal" ),
             sq_sum_x( "sq_sum_x" ),
             sq_sum_x_local( "sq_sum_x_local" ),
             sq_sum_y( "sq_sum_y" ),
             sq_sum_y_local( "sq_sum_y_local" ),
             sum_x( "sum_x" ),
             sum_x_local( "sum_x_local" ),
             sum_y( "sum_y" ),
             sum_y_local( "sum_y_local" ),
             win_var( "win_var" ),
             win_mean( "win_mean" );

    Halide::Expr ksize = kernel.width() * kernel.height();

    // accessing outside the input image always returns 0
    bounded_input( x, y ) = Halide::BoundaryConditions::constant_exterior( input, 0 )( x, y );

    // cast to 32-bit to make room for multiplication
    input_32( x, y ) = Halide::cast<int32_t>( bounded_input( x, y ) );
    kernel_32( x, y ) = Halide::cast<int32_t>( kernel( x, y ) );

    // cumulative sum along each row
    sum_x( x, y ) = input_32( x, y );
    sum_x( ri.x, ri.y ) += sum_x( ri.x - 1, ri.y );

    // sum of 1 x W strips
    // (W is the width of the kernel)
    sum_x_local( x, y ) = sum_x( x + kernel.width() - 1, y );
    sum_x_local( x, y ) -= sum_x( x - 1, y );

    // cumulative sums of the 1 x W strips along each column
    sum_y( x, y ) = sum_x_local( x, y );
    sum_y( ri.x, ri.y ) += sum_y( ri.x, ri.y - 1);

    // sums up H strips (as above) to get the sum of an H x W rectangle
    // (H is the height of the kernel)
    sum_y_local( x, y ) = sum_y( x, y + kernel.height() - 1 );
    sum_y_local( x, y ) -= sum_y( x, y - 1 );

    // same as above, just with squared image values
    sq_sum_x( x, y ) = input_32( x, y ) * input_32( x, y );
    sq_sum_x( ri.x, ri.y ) += sq_sum_x( ri.x - 1, ri.y );

    sq_sum_x_local( x, y ) = sq_sum_x( x + kernel.width() - 1, y );
    sq_sum_x_local( x, y ) -= sq_sum_x( x - 1, y );

    sq_sum_y( x, y ) = sq_sum_x_local( x, y );
    sq_sum_y( ri.x, ri.y ) += sq_sum_y( ri.x, ri.y - 1);

    sq_sum_y_local( x, y ) = sq_sum_y( x, y + kernel.height() - 1 );
    sq_sum_y_local( x, y ) -= sq_sum_y( x, y - 1 );

    // the mean value of each window
    win_mean( x, y ) = sum_y_local( x, y ) / ksize;

    // the variance of each window
    win_var( x, y ) =  sq_sum_y_local( x, y ) / ksize;
    win_var( x, y) -= win_mean( x, y ) * win_mean( x, y );

    // partially normalize the kernel
    // (we'll divide by std. dev. at the end)
    knorm( x, y ) = kernel_32( x, y ) - kernel_mean;

    // convolve kernel and the input
    conv( x, y ) = Halide::sum( knorm( rk.x, rk.y ) * input_32( x + rk.x, y + rk.y ) );

    // calculate normxcorr, except scaled to 0 to 254 (for an 8-bit image)
    normxcorr( x, y ) = conv( x, y ) * 127 / Halide::sqrt( kernel_var * win_var( x, y ) ) + 127;

    // after scaling pixel values, it's safe to cast down to 8-bit
    output( x, y ) = Halide::cast<pixel_t>( normxcorr( x, y ) );
}

Answer by jrk (best answer):

I believe the issue here is simply that you haven't specified any schedule for any of your functions, so everything is being inlined, resulting in an enormous amount of redundant computation of the intermediate values. So in fact this is not technically hanging, but simply doing a huge amount of work for every pixel and so failing to complete in a reasonable amount of time.
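
To get a feel for how quickly that redundant work adds up, here is a toy model in plain C++ (not Halide, and not taken from the original post). The names `Result`, `window_sums_recomputed` and `window_sums_materialized` are hypothetical. Both functions compute the same 1-D window sums from a cumulative sum, but the first rebuilds the whole scan for every output pixel while the second builds it once, roughly what compute_root asks for.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model only: window sums plus a count of the additions spent on scans.
struct Result {
    std::vector<std::int64_t> out;  // sum of each window of width w
    std::uint64_t adds = 0;         // additions used to build cumulative sums
};

// The scan is rebuilt from scratch for every output pixel.
Result window_sums_recomputed(const std::vector<int>& in, std::size_t w) {
    assert(w >= 1 && w <= in.size());
    Result r;
    for (std::size_t x = 0; x + w <= in.size(); ++x) {
        std::vector<std::int64_t> scan(in.size() + 1, 0);
        for (std::size_t i = 0; i < in.size(); ++i) {
            scan[i + 1] = scan[i] + in[i];
            ++r.adds;
        }
        r.out.push_back(scan[x + w] - scan[x]);
    }
    return r;
}

// The scan is built once and then reused by every output pixel.
Result window_sums_materialized(const std::vector<int>& in, std::size_t w) {
    assert(w >= 1 && w <= in.size());
    Result r;
    std::vector<std::int64_t> scan(in.size() + 1, 0);
    for (std::size_t i = 0; i < in.size(); ++i) {
        scan[i + 1] = scan[i] + in[i];
        ++r.adds;
    }
    for (std::size_t x = 0; x + w <= in.size(); ++x) {
        r.out.push_back(scan[x + w] - scan[x]);
    }
    return r;
}
```

The results are identical, but the recomputed version spends work proportional to (number of outputs) x (scan length) on a single 1-D stage. Your pipeline chains several 2-D scans, so the redundant work in each stage multiplies with the next.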

To start out, try marking every function as compute_root (e.g., sum_x.compute_root();), most easily in a block at the end of the function. The pipeline should then proceed at a much more reasonable pace, print a trace for every function in turn (starting from the inputs) rather than just normxcorr.0, and run to completion.

In reality, many of your functions are just pointwise transformations of their inputs, so those functions could be left inline (not compute_root), which should speed things up further (especially once you start parallelizing and vectorizing some stages). At first glance, [sq_]sum_{x,y} should probably not be inlined, but everything else can probably be left inline. knorm and input_32 are a toss-up, depending on your target and the rest of your schedule.

I put up a quick runnable revision, with this trivial schedule added and some other small cleanup, here:

https://gist.github.com/d64823d754a732106a60

In my tests it runs on a 2K^2 input in under a second with nothing fancy.

As an aside, a minor tip: compiling your generator code with debug symbols (-g) should free you from giving name strings in all your Func declarations. This was an unfortunate wart in earlier implementations, but we're able to do a reasonable job setting these names directly from the C++ source symbol names now, as long as you compile with debug symbols enabled.