Optimizing a large if-else branch with binary search

1.1k views Asked by At

So there is an if-else branch in my program with about 30 if-else statements. This part runs more than 100 times per second, so I saw it as an opportunity to optimize, and made it do a binary search over a function pointer array (practically a balanced tree map) instead of doing linear if-else condition checks. But it ran at only about 70% of the previous speed.

I made a simple benchmark program to test the issue, and it gave a similar result: the if-else version runs faster, both with and without compiler optimizations.

I also counted the number of comparisons done, and as expected the binary search did about half as many comparisons as the simple if-else branch. But it still ran 20–30% slower.

I want to know where all my computing time is being wasted, and why the linear if-else runs faster than the logarithmic binary search.

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* Comparison counters, both starting at zero: ifElseCount is bumped once per
   equality test in ifElseSearch, binaryCount once per comparator call made by
   bsearch (see searchCompare). main() prints both at the end. */
long long ifElseCount = 0, binaryCount = 0;

/*
 * Linear-scan baseline for the benchmark: compares i against 0, 1, ..., 9 in
 * order, one `if` per candidate, and returns the first match.
 *
 * Side effect: increments the global ifElseCount once before every
 * comparison, so the total number of comparisons can be compared with the
 * bsearch-based version.
 *
 * Returns i when 0 <= i <= 9; otherwise -1 after all 10 comparisons.
 */
int ifElseSearch(int i) {
    ++ifElseCount;
    if (i == 0) {
        return 0;
    }
    ++ifElseCount;
    if (i == 1) {
        return 1;
    }
    ++ifElseCount;
    if (i == 2) {
        return 2;
    }
    ++ifElseCount;
    if (i == 3) {
        return 3;
    }
    ++ifElseCount;
    if (i == 4) {
        return 4;
    }
    ++ifElseCount;
    if (i == 5) {
        return 5;
    }
    ++ifElseCount;
    if (i == 6) {
        return 6;
    }
    ++ifElseCount;
    if (i == 7) {
        return 7;
    }
    ++ifElseCount;
    if (i == 8) {
        return 8;
    }
    ++ifElseCount;
    if (i == 9) {
        return 9;
    }
    /* No case matched. Previously control fell off the end of this non-void
       function, which is undefined behavior if the caller uses the result;
       return an explicit "not found" sentinel instead. */
    return -1;
}

/*
 * Constant getters, one per key 0..9. Each simply returns its own number.
 * They are stored in the getN function pointers of the zeroToNine table and
 * invoked from binarySearch once the matching entry has been found.
 */
int getZero(void)  { return 0; }
int getOne(void)   { return 1; }
int getTwo(void)   { return 2; }
int getThree(void) { return 3; }
int getFour(void)  { return 4; }
int getFive(void)  { return 5; }
int getSix(void)   { return 6; }
int getSeven(void) { return 7; }
int getEight(void) { return 8; }
int getNine(void)  { return 9; }

/* One entry of the lookup table: a key and the function that produces the
   value for that key. The zeroToNine table is an array of these; main() sorts
   it by n with sortCompare before binarySearch runs bsearch over it. */
struct pair {
    int n;              /* search key; the table is sorted on this field */
    int (*getN)(void);  /* called by binarySearch for the entry whose n matches */
};

/*
 * Lookup table used by binarySearch. The entries are initially listed out of
 * key order; main() prints them, sorts the array by n with qsort (bsearch
 * requires ascending order), and prints them again. Not const, because qsort
 * rearranges it in place.
 */
struct pair zeroToNine[10] = {
    { .n = 0, .getN = getZero  },
    { .n = 2, .getN = getTwo   },
    { .n = 4, .getN = getFour  },
    { .n = 6, .getN = getSix   },
    { .n = 8, .getN = getEight },
    { .n = 9, .getN = getNine  },
    { .n = 7, .getN = getSeven },
    { .n = 5, .getN = getFive  },
    { .n = 3, .getN = getThree },
    { .n = 1, .getN = getOne   },
};

int sortCompare(const void *p, const void *p2) {
    if (((struct pair *)p)->n < ((struct pair *)p2)->n) {
        return -1;
    }
    if (((struct pair *)p)->n > ((struct pair *)p2)->n) {
        return 1;
    }
    return 0;
}

int searchCompare(const void *pKey, const void *pElem) {
    ++binaryCount;
    if (*(int *)pKey < ((struct pair *)pElem)->n) {
        return -1;
    }
    if (*(int *)pKey > ((struct pair *)pElem)->n) {
        return 1;
    }
    return 0;
}

int binarySearch(int key) {
    return ((struct pair *)bsearch(&key, zeroToNine, 10, sizeof(struct pair), searchCompare))->getN();
}

/* A pair of timestamps taken with clock(). Per the C standard, clock()
   reports processor time used by the program, not wall-clock time. */
struct timer {
    clock_t start;
    clock_t end;
};

/* Record the current clock() value as the start of the measured interval. */
void startTimer(struct timer *timer) {
    const clock_t now = clock();
    timer->start = now;
}

/* Record the current clock() value as the end of the measured interval. */
void endTimer(struct timer *timer) {
    const clock_t now = clock();
    timer->end = now;
}

/* Length of the interval between startTimer and endTimer, in seconds. */
double getSecondsPassed(struct timer *timer) {
    const clock_t elapsed = timer->end - timer->start;
    return elapsed / (double)CLOCKS_PER_SEC;
}

/*
 * Benchmark driver. It does the following, in order:
 *
 *  1. Prints one rand() value.
 *  2. Prints the lookup table's keys before and after sorting them with
 *     qsort (bsearch requires ascending order).
 *  3. Times nTests calls of ifElseSearch, then of binarySearch, on random
 *     keys in [0, nKeys), printing the processor seconds for each loop.
 *  4. Prints the two comparison counters.
 *
 * Note: both timed loops also pay for rand(), so each printed time is
 * "lookup + rand()" rather than the lookup alone.
 */
int main(void) {
    /* Scoped constant instead of a #define that would leak past main(). */
    enum { nTests = 500000000 };
    /* Derive the key count from the table instead of repeating the literal 10. */
    const int nKeys = (int)(sizeof zeroToNine / sizeof zeroToNine[0]);
    struct timer timer;
    /* Every lookup result is stored here. The store is volatile, so the
       optimizer must perform it and cannot discard the lookups' results as
       unused, which would skew the timing. */
    volatile int sink = 0;
    int i;

    srand((unsigned)time(NULL));
    printf("%d\n\n", rand());
    for (i = 0; i < nKeys; ++i) {
        printf("%d ", zeroToNine[i].n);
    }
    printf("\n");
    qsort(zeroToNine, (size_t)nKeys, sizeof zeroToNine[0], sortCompare);
    for (i = 0; i < nKeys; ++i) {
        printf("%d ", zeroToNine[i].n);
    }
    printf("\n\n");

    startTimer(&timer);
    for (i = 0; i < nTests; ++i) {
        sink = ifElseSearch(rand() % nKeys);
    }
    endTimer(&timer);
    printf("%f\n", getSecondsPassed(&timer));

    startTimer(&timer);
    for (i = 0; i < nTests; ++i) {
        sink = binarySearch(rand() % nKeys);
    }
    endTimer(&timer);
    printf("%f\n", getSecondsPassed(&timer));
    printf("\n%lli %lli\n", ifElseCount, binaryCount);

    (void)sink;
    return EXIT_SUCCESS;
}

possible output:

78985494

0 2 4 6 8 9 7 5 3 1 
0 1 2 3 4 5 6 7 8 9 

12.218656
16.496393

2750030239 1449975849
2

There are 2 answers

1
mtijanic On BEST ANSWER

You should look at the generated instructions (gcc -S source.c) to see for sure, but generally it comes down to these three factors:

1) N is too small.

If you only have 8 different branches, you execute an average of 4 checks (assuming equally probable cases; otherwise it could be even faster).

If you make it a binary search, that is log2(8) = 3 checks, but these checks are much more complex, resulting in more code executed overall.

So, unless your N is in the hundreds, it probably doesn't make sense to do this. You could do some profiling to find the actual value for N.

2) Branch prediction is harder.

In the case of a linear search, each condition is true in only 1/N of the cases, so the branch predictor can assume the branch is not taken and mispredicts only once, at the matching case. For a binary search, each comparison goes either way with roughly equal probability, so it is essentially unpredictable and you likely flush the pipeline at many levels of the search. For small N (roughly N < 1024), these mispredictions cost more than the comparisons you save.

3) Pointers to functions are slow

When calling through a function pointer, you have to load the pointer from memory, the target function has to be brought into the instruction cache, and then you execute the call instruction, the function setup, and the return. You cannot inline functions called through a pointer, so that is several extra instructions, plus a memory access, plus moving things in and out of the cache. It adds up pretty quickly.


All in all, this only makes sense for a large N, and you should always profile before applying these optimizations.

3
gnasher729 On

Use a switch statement.

Compilers are clever. They will produce the most efficient code for your particular values: typically a jump table for dense values like 0–9, or even a binary search (with inline code) if that is deemed more efficient.

And as a huge benefit, the code is readable, and doesn't require you to make changes in half a dozen places to add a new case.

PS. Obviously your code is a good learning experience. Now you've learned, so don't do it again :-)