Fast bignum square computation


To speed up my bignum divisions I need to speed up the operation y = x^2 for bigints, which are represented as dynamic arrays of unsigned DWORDs. To be clear:

DWORD x[n+1] = { LSW, ..., MSW };
  • where n+1 is the number of used DWORDs
  • so the value of x is x[0] + (x[1]<<32) + ... + (x[n]<<(32*n))

The question is: How do I compute y = x^2 as fast as possible without precision loss?

  • Using C++ with 32-bit integer arithmetic (with carry) at my disposal.

My current approach is to compute the multiplication y = x*x while avoiding duplicate multiplications.

For example:

x = x[0] + (x[1]<<32) + ... + (x[n]<<(32*n))

For simplicity, let me rewrite it:

x = x0 + x1 + x2 + ... + xn

where the index is the position inside the array (each xi already includes its <<(32*i) shift), so:

y = x*x
y = (x0 + x1 + x2 + ... + xn)*(x0 + x1 + x2 + ... + xn)
y = x0*(x0 + x1 + x2 + ... + xn) + x1*(x0 + x1 + x2 + ... + xn) + ... + xn*(x0 + x1 + x2 + ... + xn)

y0      = x0*x0
y1      = x1*x0 + x0*x1
y2      = x2*x0 + x1*x1 + x0*x2
y3      = x3*x0 + x2*x1 + x1*x2 + x0*x3
...
y(2n-2) = x(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-1) = x(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n  ) = x(n  )*x(n  )

After a closer look, it is clear that every cross product xi*xj with i != j appears twice, and only the squares xi*xi appear once. This means that the N*N multiplications (N = n+1 words) can be replaced by (N+1)*N/2 multiplications. P.S. 32bit*32bit = 64bit, so the result of every mul+add operation is handled as 64+1 bits.
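The symmetry argument above can be sketched in portable C++. This is a minimal sketch, not the author's code: it assumes LSW-first `std::vector<uint32_t>` storage (as in the equations, not the MSW-first `arbnum` layout) and uses a 64-bit intermediate instead of the `alu` carry helpers; `sqr_schoolbook` is a hypothetical name. The cross products are computed once, doubled with a one-bit shift, and then the diagonal squares are added, for n*(n+1)/2 multiplications in total.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// y = x^2 for an n-word number stored LSW first (x[0] least significant).
// Each cross product x[i]*x[j], i < j, is computed once, the partial sum is
// doubled with a 1-bit left shift, and then the squares x[i]^2 are added.
std::vector<uint32_t> sqr_schoolbook(const std::vector<uint32_t>& x)
{
    const size_t n = x.size();
    std::vector<uint32_t> y(2 * n, 0);
    if (n == 0) return y;

    // 1) cross terms: n*(n-1)/2 multiplications instead of n*n
    for (size_t i = 0; i < n; ++i) {
        uint64_t carry = 0;
        for (size_t j = i + 1; j < n; ++j) {
            // at most (2^32-1)^2 + 2*(2^32-1) = 2^64-1, so no 64-bit overflow
            const uint64_t t = (uint64_t)x[i] * x[j] + y[i + j] + carry;
            y[i + j] = (uint32_t)t;
            carry = t >> 32;
        }
        y[i + n] = (uint32_t)carry;
    }

    // 2) every cross term appears twice: shift the whole partial sum left by one bit
    uint32_t top = 0;
    for (size_t k = 0; k < 2 * n; ++k) {
        const uint32_t w = y[k];
        y[k] = (w << 1) | top;
        top = w >> 31;
    }

    // 3) add the diagonal terms x[i]^2 at word offset 2*i
    uint64_t carry = 0;
    for (size_t i = 0; i < n; ++i) {
        const uint64_t sq = (uint64_t)x[i] * x[i];
        uint64_t t = (uint64_t)y[2 * i] + (uint32_t)sq + carry;
        y[2 * i] = (uint32_t)t;
        t = (uint64_t)y[2 * i + 1] + (sq >> 32) + (t >> 32);
        y[2 * i + 1] = (uint32_t)t;
        carry = t >> 32;
    }
    return y;
}
```

The doubling step cannot overflow: the doubled cross sum is smaller than x^2, which always fits in 2n words.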

Is there a better way to compute this fast? All I found while searching were sqrt algorithms, not sqr...

Fast sqr

!!! Beware: all numbers in my code are stored MSW first, unlike the text above (which uses LSW first to keep the equations simple; otherwise the indexing would be a mess).

Current working fast sqr implementation

void arbnum::sqr(const arbnum &x)
{
    // O((N+1)*N/2)
    arbnum c;
    DWORD h, l;
    int N, nx, nc, i, i0, i1, k;
    c._alloc(x.siz + x.siz + 1);
    nx = x.siz - 1;
    nc = c.siz - 1;
    N = nx + nx;
    for (i = 0; i <= nc; i++)
        c.dat[i] = 0;
    // cross terms x[i0]*x[i1], i0 < i1: each product is computed only once
    for (i = 1; i < N; i++)
        for (i0 = 0; (i0 <= nx) && (i0 <= i); i0++)
        {
            i1 = i - i0;
            if (i0 >= i1)
                break;
            if (i1 > nx)
                continue;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            l = x.dat[nx-i1];
            if (!l)
                continue;
            alu.mul(h, l, h, l);              // h:l = h*l (64-bit result)
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k], l);
            k--;
            if (k >= 0)
                alu.adc(c.dat[k], c.dat[k], h);
            k--;
            for (; (alu.cy) && (k >= 0); k--) // propagate the carry
                alu.inc(c.dat[k]);
        }
    c.shl(1);                                 // double: every cross term appears twice
    // diagonal terms x[i0]*x[i0]
    for (i = 0; i <= N; i += 2)
    {
        i0 = i>>1;
        h = x.dat[nx-i0];
        if (!h)
            continue;
        alu.mul(h, l, h, h);
        k = nc - i;
        if (k >= 0)
            alu.add(c.dat[k], c.dat[k], l);
        k--;
        if (k >= 0)
            alu.adc(c.dat[k], c.dat[k], h);
        k--;
        for (; (alu.cy) && (k >= 0); k--)
            alu.inc(c.dat[k]);
    }
    c.bits = c.siz<<5;
    c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
    c.sig = +1;                               // a square is never negative
    *this = c;
}

Use of Karatsuba multiplication

(thanks to Calpis)

I implemented Karatsuba multiplication, but the results are massively slower than even the simple O(N^2) multiplication, probably because of the horrible recursion that I can't see any way to avoid. Its break-even point must be at really large numbers (more than hundreds of digits), and even then there are a lot of memory transfers. Is there a way to avoid the recursive calls (a non-recursive variant; almost all recursive algorithms can be rewritten that way)? Still, I will try to tweak things (avoid normalizations, etc.; it could also be some silly mistake in the code). Anyway, after specializing Karatsuba for the x*x case there is not much performance gain.

Optimized Karatsuba multiplication

Performance test for y = x^2, looped 1000 times, 0.9 < x < 1, ~32*98 bits:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication

x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]

x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

After the Karatsuba optimizations, the code is massively faster than before. Still, for smaller numbers it runs at slightly less than half the speed of my O(N^2) multiplication. For bigger numbers it is faster, by the ratio given by the complexities of both multiplications. The threshold for multiplication is around 32*98 bits and for sqr around 32*389 bits, so if the sum of the input bits crosses this threshold, Karatsuba multiplication is used to speed up the multiplication, and similarly for sqr.

BTW, optimizations included:

  • Minimized heap thrashing caused by too-big recursion arguments
  • Avoided any bignum arithmetic (+, -); the 32-bit ALU with carry is used instead
  • Skipped the 0*y, x*0 and 0*0 cases
  • Padded the x, y operand sizes to a power of two to avoid reallocation
  • Implemented modulo multiplication for z1 = (x0 + x1)*(y0 + y1) to minimize recursion
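For comparison, a Karatsuba-style squaring can use the identity 2*a*b = (a+b)^2 - a^2 - b^2, so all three recursive calls are squarings. The sketch below is an illustration under assumptions, not the arbnum code: it uses LSW-first `std::vector<uint32_t>`, and the names `sqr_karatsuba`, `sqr_basecase`, `add_at`, `sub_in` and the cutoff `kCutoff = 32` words are hypothetical. Falling back to the O(N^2) base case below a cutoff is a common way to limit the recursion overhead discussed above; the right cutoff has to be measured.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Words = std::vector<uint32_t>;  // LSW first: Words[0] is the least significant word

// Hypothetical cutoff in words; must be >= 3 so the (a+b) half always shrinks.
constexpr size_t kCutoff = 32;

// O(n^2) base case: plain schoolbook product x*x, 2n words.
static Words sqr_basecase(const Words& x)
{
    Words r(2 * x.size(), 0);
    for (size_t i = 0; i < x.size(); ++i) {
        uint64_t c = 0;
        for (size_t j = 0; j < x.size(); ++j) {
            const uint64_t t = (uint64_t)x[i] * x[j] + r[i + j] + c;
            r[i + j] = (uint32_t)t;
            c = t >> 32;
        }
        r[i + x.size()] = (uint32_t)c;
    }
    return r;
}

// r += a << (32*shift); requires shift + a.size() <= r.size()
static void add_at(Words& r, const Words& a, size_t shift)
{
    uint64_t c = 0;
    size_t k = shift;
    for (size_t i = 0; i < a.size(); ++i, ++k) {
        const uint64_t t = (uint64_t)r[k] + a[i] + c;
        r[k] = (uint32_t)t;
        c = t >> 32;
    }
    for (; c && k < r.size(); ++k) {  // propagate the final carry
        const uint64_t t = (uint64_t)r[k] + c;
        r[k] = (uint32_t)t;
        c = t >> 32;
    }
}

// r -= a; requires r >= a as numbers and r.size() >= a.size()
static void sub_in(Words& r, const Words& a)
{
    uint64_t borrow = 0;
    for (size_t i = 0; i < r.size(); ++i) {
        const uint64_t sub = (i < a.size() ? a[i] : 0u) + borrow;
        const uint64_t cur = r[i];
        r[i] = (uint32_t)(cur - sub);
        borrow = cur < sub ? 1 : 0;
    }
}

// x = a*B^m + b  ->  x^2 = a^2*B^(2m) + 2ab*B^m + b^2
Words sqr_karatsuba(const Words& x)
{
    const size_t n = x.size();
    if (n <= kCutoff) return sqr_basecase(x);
    const size_t m = n / 2;
    const Words lo(x.begin(), x.begin() + m);  // b: low m words
    const Words hi(x.begin() + m, x.end());    // a: high n-m words
    Words s = hi;
    s.push_back(0);                            // a+b needs one extra word
    add_at(s, lo, 0);

    const Words z0 = sqr_karatsuba(lo);        // b^2
    const Words z2 = sqr_karatsuba(hi);        // a^2
    Words z1 = sqr_karatsuba(s);               // (a+b)^2
    sub_in(z1, z0);
    sub_in(z1, z2);                            // z1 = 2ab

    Words r(2 * n + 2, 0);                     // slack for the top words of z1
    add_at(r, z0, 0);
    add_at(r, z1, m);
    add_at(r, z2, 2 * m);
    r.resize(2 * n);                           // x^2 < B^(2n): dropped words are zero
    return r;
}
```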

Schönhage-Strassen multiplication modified for sqr

I have tested using FFT and NTT to speed up the sqr computation. The results:

  1. FFT

Loses accuracy and therefore needs high-precision complex numbers. This slows things down considerably, so there is no speedup. The result is not exact (it can be wrongly rounded), so FFT is unusable (for now).

  2. NTT

NTT is a finite-field DFT, so no accuracy is lost. It needs modular arithmetic on unsigned integers: modpow, modmul, modadd and modsub.
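A minimal sketch of those four helpers, assuming a prime p < 2^31 so that sums fit in 32 bits and products in 64 bits. The modulus 998244353 = 119*2^23 + 1 (primitive root 3) is an assumed example, not the one the author uses. The argument order mirrors the `modmul(yy[i], yy[i], ntt.p)` call in the code further down, but these bodies are only an illustration, not the author's implementation.

```cpp
#include <cstdint>

// Modular helpers for a 32-bit NTT; assume 0 <= a, b < p and p < 2^31.
constexpr uint32_t P = 998244353u;  // 119 * 2^23 + 1, primitive root 3 (assumed example)

inline uint32_t modadd(uint32_t a, uint32_t b, uint32_t p)
{
    const uint32_t s = a + b;       // < 2p < 2^32, cannot wrap
    return s >= p ? s - p : s;
}

inline uint32_t modsub(uint32_t a, uint32_t b, uint32_t p)
{
    return a >= b ? a - b : a + (p - b);
}

inline uint32_t modmul(uint32_t a, uint32_t b, uint32_t p)
{
    return (uint32_t)((uint64_t)a * b % p);  // full 64-bit product, then reduce
}

// square-and-multiply exponentiation
inline uint32_t modpow(uint32_t a, uint64_t e, uint32_t p)
{
    uint32_t r = 1 % p;
    a %= p;
    for (; e; e >>= 1) {
        if (e & 1) r = modmul(r, a, p);
        a = modmul(a, a, p);
    }
    return r;
}
```

Because p is prime, modpow also provides inverses via Fermat's little theorem, e.g. the 1/n scaling of the inverse transform is modpow(n, P - 2, P).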

I use DWORDs (32-bit unsigned integers). The NTT input/output vector size is limited because of overflow issues!!! For 32-bit modular arithmetic, N is limited to (2^32)/(max(input[])^2), so the bigint must be split into smaller chunks. I use BYTEs, so the maximum size of bigint processed is:

    (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs
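The limit follows from the largest possible convolution coefficient. The NTT only returns each coefficient c_k modulo p, so every c_k has to be smaller than p to come back exactly. Assuming an N-point transform over base-B digits a_i (B = 2^8 for bytes) and a prime modulus p < 2^32:

$$
c_k=\sum_{i+j=k} a_i\,a_j \;\le\; N\,(B-1)^2 \;<\; p \;<\; 2^{32}
\quad\Longrightarrow\quad
N \;<\; \frac{2^{32}}{(B-1)^2} \;\approx\; \frac{2^{32}}{2^{16}} \;=\; 2^{16}
\qquad (B = 2^8)
$$

With 16-bit digits, (B-1)^2 is already almost 2^32, which leaves no usable N, so smaller digits or a larger (e.g. 64-bit) modulus are needed. The modulus must also have the form c*2^k + 1 with 2^k >= N, otherwise the N-th roots of unity the transform needs do not exist.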

Squaring uses only 1x NTT + 1x INTT instead of the 2x NTT + 1x INTT needed for multiplication, but NTT is too slow in my implementation and the threshold number size is too large for practical use (for both mul and sqr).
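The 1x NTT + 1x INTT idea can be sketched end to end. Assumptions, none of which come from the author's `fourier_NTT` class: the modulus 998244353 with primitive root 3, a textbook iterative radix-2 transform, LSB-first base-256 digits (the same byte splitting the arbnum code does), and the hypothetical name `sqr_ntt`. Because this p is only about 2^30, the coefficient bound limits the sketch to 16384-point transforms (operands of up to 8192 bytes), and it throws beyond that.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

constexpr uint32_t P = 998244353u;  // 119 * 2^23 + 1 (assumed NTT prime)
constexpr uint32_t G = 3;           // primitive root modulo P

static uint32_t mulmod(uint32_t a, uint32_t b) { return (uint32_t)((uint64_t)a * b % P); }

static uint32_t powmod(uint32_t a, uint32_t e)
{
    uint32_t r = 1;
    for (; e; e >>= 1) {
        if (e & 1) r = mulmod(r, a);
        a = mulmod(a, a);
    }
    return r;
}

// Textbook in-place radix-2 NTT; a.size() must be a power of two <= 2^23.
static void ntt(std::vector<uint32_t>& a, bool inverse)
{
    const size_t n = a.size();
    for (size_t i = 1, j = 0; i < n; ++i) {                 // bit-reversal permutation
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (size_t len = 2; len <= n; len <<= 1) {
        uint32_t w = powmod(G, (uint32_t)((P - 1) / len));  // primitive len-th root of unity
        if (inverse) w = powmod(w, P - 2);                   // inverse root for the INTT
        for (size_t i = 0; i < n; i += len) {
            uint32_t wk = 1;
            for (size_t k = 0; k < len / 2; ++k) {
                const uint32_t u = a[i + k];
                const uint32_t v = mulmod(a[i + k + len / 2], wk);
                a[i + k] = (u + v) % P;                      // u + v < 2P < 2^32
                a[i + k + len / 2] = (u + P - v) % P;
                wk = mulmod(wk, w);
            }
        }
    }
    if (inverse) {
        const uint32_t inv_n = powmod((uint32_t)n, P - 2);   // 1/n via Fermat
        for (auto& v : a) v = mulmod(v, inv_n);
    }
}

// Square an LSB-first base-256 number: 1x NTT, pointwise square, 1x INTT.
std::vector<uint8_t> sqr_ntt(const std::vector<uint8_t>& x)
{
    if (x.empty()) return {};
    size_t n = 1;
    while (n < 2 * x.size()) n <<= 1;
    // each coefficient is a sum of at most n/2 products <= 255^2 and must stay below P
    if ((uint64_t)(n / 2) * 255u * 255u >= P)
        throw std::length_error("sqr_ntt: operand too long for this modulus");
    std::vector<uint32_t> a(n, 0);
    for (size_t i = 0; i < x.size(); ++i) a[i] = x[i];
    ntt(a, false);
    for (auto& v : a) v = mulmod(v, v);                      // pointwise square = self-convolution
    ntt(a, true);
    std::vector<uint8_t> y(2 * x.size(), 0);
    uint64_t carry = 0;
    for (size_t i = 0; i < y.size(); ++i) {                  // carry propagation back to bytes
        carry += a[i];
        y[i] = (uint8_t)(carry & 0xFF);
        carry >>= 8;
    }
    return y;
}
```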

It is possible that this threshold is even above the overflow limit, in which case 64-bit modular arithmetic would be needed, which would slow things down even more. So for my purposes NTT is unusable too.

Some measurements:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simple mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

My implementation:

void arbnum::sqr_NTT(const arbnum &x)
{
    // O(N*log(N)*(log(log(N)))) - 1x NTT
    // Schönhage-Strassen sqr
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
    int i, j, k, n;
    int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
    i = x.siz;
    for (n = 1; n < i; n<<=1)
        ;
    if (n + n > 0x3000) {
        _error(_arbnum_error_TooBigNumber);
        zero();
        return;
    }
    n <<= 3;
    DWORD *xx, *yy, q, qq;
    xx = new DWORD[n+n];
    #ifdef _mmap_h
    if (xx)
        mmap_new(xx, (n+n) << 2);
    #endif
    if (xx==NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        zero();
        return;
    }
    yy = xx + n;

    // Zero padding (and split DWORDs to BYTEs)
    for (i--, k=0; i >= 0; i--)
    {
        q = x.dat[i];
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++;
    }
    for (;k<n;k++)
        xx[k] = 0;

    //NTT
    fourier_NTT ntt;

    ntt.NTT(yy,xx,n);    // init NTT for n

    // Convolution
    for (i=0; i<n; i++)
        yy[i] = modmul(yy[i], yy[i], ntt.p);

    //INTT
    ntt.INTT(xx, yy);

    // Sum: propagate carries between the byte coefficients (result stored MSB first in yy)
    q=0;
    for (i = 0; i<n; i++) {
        qq = xx[i];
        q += qq&0xFF;
        yy[n-i-1] = q&0xFF;
        q>>=8;
        qq>>=8;
        q+=qq;
    }

    // Merge BYTEs to DWORDs and copy them to the result
    _alloc(n>>2);
    for (i = 0, j = 0; i<siz; i++)
    {
        q  =(yy[j]<<24)&0xFF000000; j++;
        q |=(yy[j]<<16)&0x00FF0000; j++;
        q |=(yy[j]<< 8)&0x0000FF00; j++;
        q |=(yy[j]    )&0x000000FF; j++;
        dat[i] = q;
    }

    #ifdef _mmap_h
    if (xx)
        mmap_del(xx);
    #endif
    delete[] xx;                // xx was allocated with new[]
    bits = siz<<5;
    sig = s;
    exp = exp0 + (siz<<5) - 1;
    // _normalize();
}

Conclusion

For smaller numbers my fast sqr approach is the best option, and above the threshold Karatsuba multiplication is better. But I still think there should be something trivial that we have overlooked. Does anyone have other ideas?

NTT optimization

After massive optimizations (mostly of the NTT; see the Stack Overflow question Modular arithmetics and NTT (finite field DFT) optimizations), some values have changed:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

So NTT multiplication is now finally faster than Karatsuba above a threshold of about 1500*32 bits.

Some measurements and bug spotted

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simple mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

I found out that my Karatsuba (over/under)flows the LSB of each DWORD segment of the bignum. Once I have researched it, I will update the code...

Also, after further NTT optimizations the thresholds changed, so for NTT sqr it is 310*32 bits = 9920 bits of operand, and for NTT mul it is 1396*32 bits = 44672 bits of result (sum of bits of operands).

Karatsuba code repaired thanks to @greybeard

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
    // Recursion for Karatsuba
    // z[2n] = x[n]*y[n];
    // n=2^m
    int i;
    for (i=0; i<n; i++)
        if (x[i]) {
            i=-1;
            break;
        } // x==0 ?

    if (i < 0)
        for (i = 0; i<n; i++)
            if (y[i]) {
                i = -1;
                break;
            } // y==0 ?

    if (i >= 0) {
        for (i = 0; i < n + n; i++)
            z[i]=0;
        return;
    } // 0.? = 0

    if (n == 1) {
        alu.mul(z[0], z[1], x[0], y[0]);
        return;
    }

    if (n< 1)
        return;
    int n2 = n>>1;
    _mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
    _mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
    DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
    BYTE cx,cy;
    if (q == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        return;
    }
    #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
    #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
    qq = q;
    q0 = x + n2;
    q1 = x;
    i = n2 - 1;
    _add;
    cx = alu.cy; // =x0+x1

    qq = q + n2;
    q0 = y + n2;
    q1 = y;
    i = n2 - 1;
    _add;
    cy = alu.cy; // =y0+y1

    _mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)

    if (cx) {
        qq = q + n;
        q0 = qq;
        q1 = q + n2;
        i = n2 - 1;
        _add;
        cx = alu.cy;
    }// += cx*(y0 + y1) << n2

    if (cy) {
        qq = q + n;
        q0 = qq;
        q1 = q;
        i = n2 -1;
        _add;
        cy = alu.cy;
    }// +=cy*(x0+x1)<<n2

    qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
    qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2

    DWORD ccc=0;

    if (alu.cy)
        ccc++;    // Handle carry from last operation
    if (cx || cy)
        ccc++;    // Handle carry from before last operation
    if (ccc)
    {
        i = n2 - 1;
        alu.add(z[i], z[i], ccc);
        for (i--; i>=0; i--)
            if (alu.cy)
                alu.inc(z[i]);
            else
                break;
    }

    delete[] q;
    #undef _add
    #undef _sub
}

//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585))
    // Karatsuba multiplication
    //
    int s = x.sig*y.sig;
    arbnum a, b;
    a = x;
    b = y;
    a.sig = +1;
    b.sig = +1;
    int i, n;
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
        ;
    a._realloc(n);
    b._realloc(n);
    _alloc(n + n);
    for (i=0; i < siz; i++)
        dat[i]=0;
    _mul_karatsuba(dat, a.dat, b.dat, n);
    bits = siz << 5;
    sig = s;
    exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
    //    _normalize();
}
//---------------------------------------------------------------------------

My arbnum number representation:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
  • dat[siz] is the mantissa. MSDW/LSDW mean most/least significant DWORD.

  • exp is the exponent of the MSB of dat[0]

  • The first nonzero bit is present in the mantissa!!!

      // |-----|---------------------------|---------------|------|
      // | sig | MSB     mantissa      LSB |   exponent    | bits |
      // |-----|---------------------------|---------------|------|
      // | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
      // | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
      // |-----|---------------------------|---------------|------|
      // | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
      // | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
      // |-----|---------------------------|---------------|------|
      // | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
      // | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
      // |-----|---------------------------|---------------|------|
    

There are 3 answers

Answer by masotann (best answer)

If I understand your algorithm correctly, it seems O(n^2) where n is the number of digits.

Have you looked at the Karatsuba algorithm? It speeds up multiplication using a divide-and-conquer approach, so it may be worth a look.

Answer by VoronoiPotato

If you're looking to write a new, better exponent function, you might have to write it in assembly. This is the code from Go:

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s

Answer by Arty

Great question, thanks!

I decided to implement a large C++ solution for you from scratch, based on the Number Theoretic Transform (NTT) and the Discrete Fourier Transform (computed with FFT).

To say it up front: my FFT/NTT code achieves a 330x speedup on an old 2-core laptop compared to naive schoolbook multiplication, for arrays of 2^16 32-bit words. Even bigger arrays, above 2^20 words, will give speedups of millions of times.

Squaring a number of 2^22 32-bit words (i.e. 4 million words) takes 7 seconds with my NTT and 13 seconds with my FFT, on an old 2 GHz 2-core laptop with SSE2 only.

As a reminder, FFT and NTT give O(N * log(N)) multiplication time, while the naive schoolbook algorithm takes O(N^2). That's why the speedups described above are so huge.

Both, together with code, are well described in this article; it was my main inspiration when writing the code below. Another good article is Nayuki's NTT article.

I was convinced that for quite large numbers these two transforms would beat any other method, such as Karatsuba.

Besides the basic approach described in the article, I also did dozens of optimizations:

  1. For NTT, computed my own set of primitive roots and moduli, and used the biggest modulus, the one closest to 2^62.

  2. Used multithreading (through OpenMP) in almost every loop of the NTT and FFT computation.

  3. For squaring, of course, used 2 transforms instead of the 3 used for multiplication. This gives a 33% speed boost.

  4. For NTT, used Montgomery reduction on all arrays when computing the modulus. This gave about a 2x-3x speedup.

  5. Used constexpr functions and values and template programming wherever I could. Turning runtime values into compile-time values where possible gives a lot of speedup.

  6. Re-designed the swap/shuffle function used at the start of every FFT/NTT transform. Used a precomputed table and caching to reuse previous results. Also did the swapping in blocks to make reads/writes cache-friendly. Also, the bit twiddling is done not in a loop but with a precomputed bit table.

  7. In the main loop of the transform, factored the computation of the W multiplier out into a separate loop, together with precomputation/caching. This gave about a 2x speedup.

  8. Used the Intel SIMD instruction sets, currently SSE2 and AVX. These are used only for FFT, as NTT uses 128-bit integer division and multiplication and add/sub-with-carry, which are not available in SIMD. Also, for SIMD in FFT I designed loop unrolling with a special cache-friendly storage of complex numbers in std::array<>.

  9. Measured the time/performance of NTT/FFT multiplication versus naive.

  10. Analyzed the error rate inside the FFT. As a reminder, NTT has no errors at all.
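Point 4 (Montgomery reduction) can be illustrated with a scaled-down sketch. Arty's code works with moduli near 2^62 and 128-bit products; the version below instead assumes an odd modulus p < 2^31 with R = 2^32, so only 64-bit products are needed. `Montgomery32` and its members are hypothetical names. Operands are converted once into Montgomery form (a*R mod p); after that, each modular multiplication is one 64-bit multiply plus the REDC step, with no division.

```cpp
#include <cstdint>

// Montgomery arithmetic modulo an odd p < 2^31 with R = 2^32 (hypothetical helper).
struct Montgomery32 {
    uint32_t p;       // odd modulus, assumed p < 2^31
    uint32_t nprime;  // -p^(-1) mod 2^32
    uint32_t r2;      // R^2 mod p, used to enter Montgomery form

    explicit Montgomery32(uint32_t mod) : p(mod)
    {
        uint32_t inv = mod;                    // mod*mod == 1 (mod 8): 3 correct bits
        for (int i = 0; i < 4; ++i)
            inv *= 2u - mod * inv;             // each Newton step doubles the correct bits
        nprime = 0u - inv;
        const uint64_t r = (1ull << 32) % mod; // R mod p
        r2 = (uint32_t)(r * r % mod);
    }

    // REDC: returns t * R^(-1) mod p, valid for t < p * R
    uint32_t reduce(uint64_t t) const
    {
        const uint32_t m = (uint32_t)t * nprime;         // makes t + m*p divisible by R
        const uint64_t u = (t + (uint64_t)m * p) >> 32;  // < 2p; no overflow since p < 2^31
        return (uint32_t)(u >= p ? u - p : u);
    }

    uint32_t to_mont(uint32_t a) const { return reduce((uint64_t)(a % p) * r2); }
    uint32_t from_mont(uint32_t a) const { return reduce(a); }
    // product of two values in Montgomery form; the result stays in Montgomery form
    uint32_t mul(uint32_t a, uint32_t b) const { return reduce((uint64_t)a * b); }
};
```

Inside a transform, all array elements would stay in Montgomery form, so the to_mont/from_mont conversions are paid only once per element, at the beginning and at the end.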

My code is self-contained: if you compile and run it, it runs tests measuring speed. Inside the test function you can see how to use my library. The test runs FFT/NTT/naive multiplication, measures the time and checks that all multiplication results are correct, i.e. equal to the naive version.

Note: no matter how hard I tried to speed up the FFT with SIMD, my NTT is so optimized that it is 1.3-1.8x faster than the FFT. As you know, FFT gives errors that grow with the size of the big number. Taking into account that my NTT is also faster, NTT is the only option for you!

It turned out that FFT can be used only for array sizes up to about 2^16 32-bit words; beyond that the error becomes critical and destroys the final result. Alternatively, you can decrease the size of the input numbers from 32 bits to 10-12 bits; this helps reduce the errors, yet you still can't go beyond an array size of 2^18 without critical error. You have to measure the error size experimentally to figure out what works best.
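Such an experimental error measurement could look like the sketch below. It is a textbook double-precision complex FFT, not Arty's SIMD code, and `sqr_fft`, the `bits` digit size and the `max_err` output are illustrative assumptions. It squares a number given as base-2^bits digits and reports the largest distance of any output coefficient from the nearest integer; as long as that stays well below 0.5, rounding recovers the exact result.

```cpp
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using cd = std::complex<double>;

// Textbook iterative radix-2 FFT; a.size() must be a power of two.
static void fft(std::vector<cd>& a, bool inverse)
{
    const size_t n = a.size();
    for (size_t i = 1, j = 0; i < n; ++i) {               // bit-reversal permutation
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    const double pi = std::acos(-1.0);
    for (size_t len = 2; len <= n; len <<= 1) {
        const double ang = (inverse ? 2.0 : -2.0) * pi / (double)len;
        for (size_t i = 0; i < n; i += len)
            for (size_t k = 0; k < len / 2; ++k) {
                // twiddle computed directly (more accurate than repeated multiplication)
                const cd w = std::polar(1.0, ang * (double)k);
                const cd u = a[i + k], v = a[i + k + len / 2] * w;
                a[i + k] = u + v;
                a[i + k + len / 2] = u - v;
            }
    }
    if (inverse)
        for (auto& v : a) v /= (double)n;
}

// Square LSB-first digits x[i] < 2^bits (bits <= 16); max_err receives the
// largest |coefficient - round(coefficient)| seen during the final rounding.
std::vector<uint32_t> sqr_fft(const std::vector<uint32_t>& x, int bits, double& max_err)
{
    max_err = 0.0;
    if (x.empty()) return {};
    size_t n = 1;
    while (n < 2 * x.size()) n <<= 1;
    std::vector<cd> a(n, cd(0.0, 0.0));
    for (size_t i = 0; i < x.size(); ++i) a[i] = cd((double)x[i], 0.0);
    fft(a, false);
    for (auto& v : a) v = v * v;              // squaring: one forward transform only
    fft(a, true);
    std::vector<uint32_t> y(2 * x.size(), 0);
    const uint64_t mask = (1ull << bits) - 1;
    uint64_t carry = 0;
    for (size_t i = 0; i < y.size(); ++i) {
        const double r = std::round(a[i].real());
        max_err = std::max(max_err, std::fabs(a[i].real() - r));
        carry += (uint64_t)std::max(r, 0.0); // exact coefficients are never negative
        y[i] = (uint32_t)(carry & mask);
        carry >>= bits;
    }
    return y;
}
```

Running it with growing lengths and several `bits` values shows where `max_err` approaches 0.5 on a given machine, which is the practical size limit for that digit size.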

The code can be compiled with Clang/MSVC/GCC, and maybe other compilers too. It has no external library dependencies at all, except perhaps the OpenMP library, which usually ships with the compiler. Only the computation of primitive roots (NTT moduli) requires the Boost library, and only for MSVC, where it uses just Boost's 128-bit integer.

CODE GOES HERE. Because the code is 65 KB, I can't inline it in this post (the Stack Overflow post size limit is 30,000 characters), so I'm providing it via the GitHub Gist link below. You can also click the Try it online! link to run my code on GodBolt's online server.

Try it online!

Github Gist source code

Example console output:

Using SIMD SSE2
Test FindNttMod 
FindNttEntry<T>{.k = 57, .c = 29, .p = 4179340454199820289, .g = 3, .root = 68630377364883, .plog2 = 61.86},
FindNttEntry<T>{.k = 54, .c = 177, .p = 3188548536178311169, .g = 7, .root = 3055434446054240334, .plog2 = 61.47},
FindNttEntry<T>{.k = 54, .c = 163, .p = 2936346957045563393, .g = 3, .root = 83050791888939419, .plog2 = 61.35},
FindNttEntry<T>{.k = 55, .c = 69, .p = 2485986994308513793, .g = 5, .root = 1700750308946223057, .plog2 = 61.11},
FindNttEntry<T>{.k = 54, .c = 127, .p = 2287828610704211969, .g = 3, .root = 878887558841786394, .plog2 = 60.99},
FindNttEntry<T>{.k = 55, .c = 57, .p = 2053641430080946177, .g = 7, .root = 640559856471874596, .plog2 = 60.83},
FindNttEntry<T>{.k = 56, .c = 27, .p = 1945555039024054273, .g = 5, .root = 1613915479851665306, .plog2 = 60.75},
FindNttEntry<T>{.k = 53, .c = 161, .p = 1450159080013299713, .g = 3, .root = 359678689516082930, .plog2 = 60.33},
FindNttEntry<T>{.k = 53, .c = 143, .p = 1288029493427961857, .g = 3, .root = 531113314168589713, .plog2 = 60.16},
FindNttEntry<T>{.k = 55, .c = 35, .p = 1261007895663738881, .g = 6, .root = 397650301651152680, .plog2 = 60.13},
0.025 sec
Test CompareNttMultWithReg 
Time NTT 0.035 FFT 0.081 Reg 11.614 Boost_NTT 333.588x (FFT 142.644)
Swap 0.776 (Slow 0.000) ToMontg 0.079 Main 3.056 (0.399, 2.656) Invert 0.000 All 3.911
MidMul 0.110
Swap 0.510 (Slow 0.000) ToMontg 0.000 Main 2.535 (0.336, 2.198) Invert 0.094 All 3.139
AssignComplex 0.495
Swap 1.373 FromComplex 0.309 Main 4.875 (0.382, 4.493) Invert 0.000 ToComplex 0.224 All 6.781
MidMul 0.147
Swap 1.106 FromComplex 0.296 Main 4.209 (0.277, 3.931) Invert 0.166 ToComplex 0.199 All 5.975
Round 0.143
Time NTT 7.457 FFT 14.097 Boost_NTT 1.891x
Run Time: 33.719 sec