Haskell Optimization Failure

142 views Asked by At

I have the following Haskell function written in continuation-passing style:

import Data.Bits ((.|.), shiftR)

-- | Smallest power of two that is >= x, computed by "smearing" the highest
-- set bit of (x - 1) into the bits below it and then adding 1.
--
-- The five 'go' steps shift by m = 1, 2, 4, 8, 16, so the smear covers a
-- 32-bit window: the result is a power of two for 0 <= x <= 2^32.
-- NOTE(review): with a 64-bit Int, inputs above 2^32 would need a sixth
-- step (shift by 32). Negative x does not produce a power of two.
nextPowerOf2 :: Int -> Int
nextPowerOf2 0 = 1
nextPowerOf2 x = (go $ go $ go $ go $ go $ \y m -> y + 1) (x - 1) 1
    -- Continuation-passing step: OR y with y shifted right by m, then pass
    -- the updated y and the doubled shift amount to the next continuation.
    -- The innermost continuation ignores m and adds 1.
    where go k y m = k (y .|. (y `shiftR` m)) $ m * 2

I expected this to compile down to the obvious, well-optimized code, but instead I am getting Core equivalent to this:

-- Haskell rendering of the Core GHC produced; the assembly performs the same
-- operations. Every shift rebuilds the OR of all earlier terms instead of
-- reusing the previous step's result, which is the duplication in question.
-- The outer parentheses in the result matter: (+) binds tighter than (.|.),
-- so the OR must be grouped before adding 1.
nextPowerOf2 0 = 1
nextPowerOf2 x =
    let y1 = x - 1
        y2 = y1 `shiftR` 1
        y3 = (y1 .|. y2) `shiftR` 2
        y4 = (y1 .|. y2 .|. y3) `shiftR` 4
        y5 = (y1 .|. y2 .|. y3 .|. y4) `shiftR` 8
    in 1 + ((y1 .|. y2 .|. y3 .|. y4 .|. y5) .|. ((y1 .|. y2 .|. y3 .|. y4 .|. y5) `shiftR` 16))

even with optimization enabled. Can anyone explain why GHC duplicates so much work here, or how to rewrite the function so that GHC optimizes it as expected?

Edit:

The resulting assembly looks like this:

# GHC-generated x86-64 (AT&T syntax) for nextPowerOf2.
Main.nextPowerOf2_closure:
        .quad   Main.nextPowerOf2_info
.text
        .align 8
        .quad   0
        .quad   32
# Continuation pushed by Main.nextPowerOf2_info; runs with the (evaluated)
# argument in %rbx.
s25G_info:
_c28g:
        # Bump the heap pointer %r12 by 16 bytes for the boxed I# result;
        # if it passes the limit at 144(%r13), take the slow path _c28m.
        addq $16,%r12
        cmpq 144(%r13),%r12
        ja _c28m
        # Load the unboxed Int x; non-zero goes to the bit-smearing code.
        movq 7(%rbx),%rax
        testq %rax,%rax
        jne _c28p
        # x == 0: return a static closure (presumably the boxed constant 1)
        # and give back the unused 16-byte heap reservation.
        movl $lvl_r23w_closure+1,%ebx
        addq $8,%rbp
        addq $-16,%r12
        jmp *0(%rbp)
_c28m:
        # Heap check failed: record the 16-byte request at 192(%r13) and jump
        # through -16(%r13) (presumably the GC entry point).
        movq $16,192(%r13)
_c28k:
        jmp *-16(%r13)
_c28p:
        # rsi = y1 = x - 1
        leaq -1(%rax),%rsi
        # rdx = y2 = y1 >> 1
        movq %rsi,%rdx
        sarq $1,%rdx
        # rcx = y3 = (y1 | y2) >> 2
        movq %rsi,%rcx
        orq %rdx,%rcx
        sarq $2,%rcx
        # rbx = y4 = (y1 | y2 | y3) >> 4
        movq %rdx,%rax
        orq %rcx,%rax
        movq %rsi,%rbx
        orq %rax,%rbx
        sarq $4,%rbx
        # rax = y5 = (y1 | y2 | y3 | y4) >> 8; the OR chain is rebuilt from scratch
        movq %rcx,%rax
        orq %rbx,%rax
        movq %rdx,%rdi
        orq %rax,%rdi
        movq %rsi,%rax
        orq %rdi,%rax
        sarq $8,%rax
        # Write the I# constructor header into the reserved heap slot.
        movq $GHC.Types.I#_con_info,-8(%r12)
        # r8 = (y1 | y2 | y3 | y4 | y5) >> 16; the OR chain is rebuilt again
        movq %rbx,%rdi
        orq %rax,%rdi
        movq %rcx,%r8
        orq %rdi,%r8
        movq %rdx,%rdi
        orq %r8,%rdi
        movq %rsi,%r8
        orq %rdi,%r8
        sarq $16,%r8
        # rsi = y1 | y2 | y3 | y4 | y5 | r8, and the result is that value + 1
        orq %r8,%rax
        orq %rax,%rbx
        orq %rbx,%rcx
        orq %rcx,%rdx
        orq %rdx,%rsi
        leaq 1(%rsi),%rax
        # Store the payload and return a tagged (+1) pointer to the new I# box.
        movq %rax,0(%r12)
        leaq -7(%r12),%rbx
        addq $8,%rbp
        jmp *0(%rbp)
        .size s25G_info, .-s25G_info
.text
        .align 8
        .quad   4294967301
        .quad   0
        .quad   15
.globl Main.nextPowerOf2_info
.type Main.nextPowerOf2_info, @object
Main.nextPowerOf2_info:
_c28U:
        # Stack check for one word against %r15; on failure take _c28W.
        leaq -8(%rbp),%rax
        cmpq %r15,%rax
        jb _c28W
        # Push continuation s25G_info. If the argument pointer (%r14) already
        # carries a tag, jump straight to the continuation; otherwise enter
        # the closure to evaluate it first.
        movq %r14,%rbx
        movq $s25G_info,-8(%rbp)
        addq $-8,%rbp
        testq $7,%rbx
        jne s25G_info
        jmp *(%rbx)
_c28W:
        # Stack check failed: hand Main.nextPowerOf2_closure to the routine
        # at -8(%r13).
        movl $Main.nextPowerOf2_closure,%ebx
        jmp *-8(%r13)
        .size Main.nextPowerOf2_info, .-Main.nextPowerOf2_info
.section .data
        .align 8
.align 1

And the expected Core is something like this:

-- The Core the question expected, rendered as Haskell (shiftR, not the
-- monadic (>>)). Each step ORs a right-shifted copy of the previous step
-- into itself and feeds only the next step, so nothing is recomputed.
nextPowerOf2 0 = 1
nextPowerOf2 x =
    let y1 = x - 1
        y2 = y1 .|. (y1 `shiftR` 1)
        y3 = y2 .|. (y2 `shiftR` 2)
        y4 = y3 .|. (y3 `shiftR` 4)
        y5 = y4 .|. (y4 `shiftR` 8)
        y6 = y5 .|. (y5 `shiftR` 16)
    in 1 + y6
1

There is 1 answer

0
Noughtmare On

GHC 9.8.1 does optimize it to your desired Core:

-- Top-level boxed Int 1, shared as the result for input 0 (see the 0# branch).
nextPowerOf1 = I# 1#

-- Unbox the argument and return the shared constant for 0#. Otherwise compute
-- the result with one let per smear step: each y is ORed with its own
-- arithmetic right shift (uncheckedIShiftRA#) and feeds only the next step,
-- so no subexpression is recomputed.
nextPowerOf2
  = \ ds ->
      case ds of { I# ds1 ->
      case ds1 of ds2 {
        __DEFAULT ->
          I#
            (let { y = -# ds2 1# } in
             let { y1 = orI# y (uncheckedIShiftRA# y 1#) } in
             let { y2 = orI# y1 (uncheckedIShiftRA# y1 2#) } in
             let { y3 = orI# y2 (uncheckedIShiftRA# y2 4#) } in
             let { y4 = orI# y3 (uncheckedIShiftRA# y3 8#) } in
             +# (orI# y4 (uncheckedIShiftRA# y4 16#)) 1#);
        0# -> nextPowerOf1
      }
      }