I've written this very simple Rust function:
fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }
    total
}
I've written a basic benchmark that invokes the method with an ordered array and a shuffled one:
use criterion::{criterion_group, criterion_main, Criterion};
use rand::{seq::SliceRandom, thread_rng};

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;
    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE / 2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE / 2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
I'm surprised that the two benchmarks have almost exactly the same runtime, while a similar benchmark in Java shows a distinct difference between the two, presumably due to branch prediction failure in the shuffled case.
I've seen mention of conditional move instructions, but if I otool -tv the executable (I'm running on a Mac), I don't see any in the iterate method output.
Can anyone shed light on why there's no perceptible performance difference between the ordered and the unordered cases in Rust?
Summary: LLVM was able to remove/hide the branch by using either the cmov instruction or a really clever combination of SIMD instructions.
I used Godbolt to view the full assembly (with -C opt-level=3). I will explain the important parts of the assembly below.
The function starts by differentiating between 3 different "states", two of which correspond to the labels LBB0_4 and LBB0_5.
So let's take a look at the two different kinds of algorithms!
Standard sequential algorithm
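As a rough Rust-level picture of the loop this section describes, here is a minimal sketch. It is written for illustration and is not the compiler's actual output; the name sum_abs_select is made up here, and wrapping arithmetic is an assumption chosen so the sketch never panics on overflow. Each element is negated, and a select then keeps either the original or the negation. Whether a given compiler version turns the `if` expression into a cmov is not guaranteed.

```rust
/// Sums absolute values with a "negate, then select" pattern.
/// Hypothetical helper for illustration; not part of the original code.
fn sum_abs_select(nums: &[i32]) -> i32 {
    let mut total: i32 = 0;
    for &x in nums {
        // Compute the negation unconditionally (wrapping, so i32::MIN cannot panic).
        let negated = x.wrapping_neg();
        // Select the original if the negation came out negative, i.e. x was positive.
        let abs = if negated < 0 { x } else { negated };
        total = total.wrapping_add(abs);
    }
    total
}
```

Calling sum_abs_select(&[3, -4, 0, -1, 7]) yields 15, the same result the branching version produces for that input.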
Remember that rsi (esi) and rax (eax) were set to 0 and that rdx is the base pointer to the data.
This is a simple loop iterating over all elements of nums. In the loop's body there is a little trick though: from the original element ecx, a negated value is stored in edi. By using cmovl, edi is overwritten with the original value if that original value is positive. That means that edi will always turn out non-negative (i.e. contain the absolute value of the original element). Then it is added to eax (which is returned in the end).
So your if branch was hidden in the cmov instruction. As you can see in this benchmark, the time required to execute a cmov instruction is independent of the probability of the condition. It's a pretty amazing instruction!
SIMD algorithm
The SIMD version consists of quite a few instructions that I won't fully paste here. The main loop handles 16 integers at once!
They are loaded from memory into the registers xmm0, xmm1, xmm3 and xmm5. Each of those registers contains four 32 bit values, but to follow along more easily, just imagine each register contains exactly one value. All following instructions operate on each value of those SIMD registers individually, so that mental model is fine! My explanation below will also sound as if xmm registers only contained a single value.
The main trick is now in the following instructions (which handle xmm5):
The arithmetic right shift (applied to a copy of xmm5 held in xmm6) fills the "empty high-order bits" (the ones "shifted in" on the left) with the value of the sign bit. By shifting by 31, we end up with only the sign bit in every position! So any non-negative number will turn into 32 zeroes and any negative number will turn into 32 ones. So xmm6 is now either 000...000 (if xmm5 is non-negative) or 111...111 (if xmm5 is negative).
Next this artificial xmm6 is added to xmm5. If xmm5 was non-negative, xmm6 is 0, so adding it won't change xmm5. If xmm5 was negative, however, we add 111...111 which is equivalent to subtracting 1. Finally, we xor xmm5 with xmm6. Again, if xmm5 was non-negative in the beginning, we xor with 000...000 which does not have an effect. If xmm5 was negative in the beginning we xor with 111...111, meaning we flip all the bits. So for both cases: if the element was non-negative, we changed nothing (the add and xor didn't have any effect); if the element was negative, we subtracted 1 and flipped all the bits, which is exactly a two's complement negation.
So with these 4 instructions we calculated the absolute value of xmm5! Here again, there is no branch because of this bit-fiddling trick. And remember that xmm5 actually contains 4 integers, so it's quite speedy!
This absolute value is now added to an accumulator and the same is done with the three other xmm registers that contain values from the slice. (We won't discuss the remaining code in detail.)
SIMD with AVX2
If we allow LLVM to emit AVX2 instructions (via -C target-feature=+avx2), it can even use the pabsd instruction instead of the four "hacky" instructions.
It loads the values directly from memory, calculates the absolute value and stores it in ymm2 in one instruction! And remember that ymm registers are twice as large as xmm registers (fitting eight 32 bit values)!
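To make the shift/add/xor trick from the SIMD section concrete, here is a minimal scalar sketch in Rust. It only illustrates the arithmetic and is not what LLVM emits; the names abs_shift_add_xor and abs_lanes are made up for this example, and a plain [i32; 4] array stands in for the four lanes of one xmm register.

```rust
/// Branch-free absolute value using the shift/add/xor trick.
/// Hypothetical helper for illustration only.
fn abs_shift_add_xor(x: i32) -> i32 {
    // `>>` on i32 is an arithmetic shift, so the mask is 0 for
    // non-negative x and -1 (all ones) for negative x.
    let mask = x >> 31;
    // Adding the mask subtracts 1 from negative values; the xor then
    // flips all bits. Together that is a two's complement negation.
    x.wrapping_add(mask) ^ mask
}

/// Applies the trick to four values, standing in for the four 32 bit
/// lanes of one xmm register.
fn abs_lanes(lanes: [i32; 4]) -> [i32; 4] {
    lanes.map(abs_shift_add_xor)
}
```

For example, for -5 the mask is -1, the add gives -6, and flipping all bits of -6 gives 5. As with any two's complement negation, i32::MIN maps to itself because its absolute value does not fit in an i32.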