AVX512 Vector Multiplication Speed

434 views Asked by At

I have a function like this:

// Split a 512-bit vector of 16 x int32 into its low half (lanes 0-7, into
// ymmA) and high half (lanes 8-15, into ymmB). _mm512_extracti32x8_epi32
// requires AVX-512DQ. Wrapped in do/while(0) so the macro behaves as a single
// statement (e.g. under an unbraced `if`); the source is evaluated once.
#define SPLIT(zmm, ymmA, ymmB)                                  \
    do {                                                        \
        const auto split_src_ = (zmm);                          \
        ymmA = _mm512_castsi512_si256(split_src_);              \
        ymmB = _mm512_extracti32x8_epi32(split_src_, 1);        \
    } while (0)

// Regroup two 8 x int32 vectors by 128-bit halves:
//   dst1 = { src1[0..3], src2[4..7] }   (blend, mask 0b11110000)
//   dst2 = { src1[4..7], src2[0..3] }   (permute2x128, imm 0b00100001)
// Both sources are captured before either destination is written, so dst1 or
// dst2 may alias src1/src2. do/while(0) makes the macro a single statement.
#define PAIR_AND_BLEND(src1, src2, dst1, dst2)                              \
    do {                                                                    \
        const auto pab_src1_ = (src1);                                      \
        const auto pab_src2_ = (src2);                                      \
        dst1 = _mm256_blend_epi32(pab_src1_, pab_src2_, 0b11110000);        \
        dst2 = _mm256_permute2x128_si256(pab_src1_, pab_src2_, 0b00100001); \
    } while (0)

// Multiply-accumulate one 64-byte code held in zmm##i against the query.
// Reads zmm30 (query bytes) and zmm31 (int16 ones) from the caller's scope.
// _mm512_maddubs_epi16 treats zmm30 bytes as unsigned and zmm##i bytes as
// signed and sums adjacent products with int16 saturation. _mm512_madd_epi16
// with ones then widens adjacent int16 pairs to 16 int32 partial sums.
// Overwrites zmm##i and splits it into ymmA (lanes 0-7) and ymmB (lanes 8-15).
// do/while(0) makes the macro a single statement.
#define OPERATE_ROW_2(i, ymmA, ymmB)                        \
    do {                                                    \
        zmm##i = _mm512_maddubs_epi16(zmm30, zmm##i);       \
        zmm##i = _mm512_madd_epi16(zmm##i, zmm31);          \
        SPLIT(zmm##i, ymmA, ymmB);                          \
    } while (0)

/*
 * Multiply the query with each code in codes (byte-wise inner product).
 *
 * Uses AVX-512F/BW/DQ intrinsics (_mm512_maddubs_epi16 needs BW,
 * _mm512_extracti32x8_epi32 needs DQ) and the OPERATE_ROW_2 / PAIR_AND_BLEND
 * macros defined above.
 *
 * Arithmetic is that of _mm512_maddubs_epi16(query, code):
 *   - query bytes are read as unsigned;
 *   - code bytes are read as signed (int8_t);
 *   - each adjacent pair of byte products is summed with signed 16-bit
 *     saturation before being widened to int32 and accumulated.
 *
 * @param n: number of codes
 * @param query: 64 x uint8_t array data
 * @param codes: n x 64 uint8_t array data, row-major (code i starts at codes + 64 * i)
 * @param output: n x int32_t array data; output[i] receives the result for code i
 */
void avx_IP_distance_64_2(size_t n,
                          const uint8_t *query,
                          const uint8_t *codes,
                          int32_t *output){
    // OPERATE_ROW_2 refers to zmm0..zmm7, zmm30 and zmm31 by name, so those
    // identifiers must exist in this scope.
    __m512i zmm0, zmm1, zmm2, zmm3,
            zmm4, zmm5, zmm6, zmm7;
    __m512i zmm30, zmm31;

    __m256i ymm0, ymm1, ymm2, ymm3,
            ymm4, ymm5, ymm6, ymm7,
            ymm8, ymm9, ymm10, ymm11,
            ymm12, ymm13, ymm14, ymm15;

    zmm30 = _mm512_loadu_si512(query);
    // madd against int16 ones sums adjacent int16 pairs into int32 lanes.
    zmm31 = _mm512_set1_epi16(1);

    // Keep the counts in size_t: narrowing n / 8 to int overflows for large n.
    const size_t blocks = n / 8;
    const size_t left = n % 8;
    for (size_t blk = 0; blk < blocks; ++blk){
        // Eight consecutive 64-byte codes.
        zmm0 = _mm512_loadu_si512(codes);
        zmm1 = _mm512_loadu_si512(codes + 64 * 1);
        zmm2 = _mm512_loadu_si512(codes + 64 * 2);
        zmm3 = _mm512_loadu_si512(codes + 64 * 3);
        zmm4 = _mm512_loadu_si512(codes + 64 * 4);
        zmm5 = _mm512_loadu_si512(codes + 64 * 5);
        zmm6 = _mm512_loadu_si512(codes + 64 * 6);
        zmm7 = _mm512_loadu_si512(codes + 64 * 7);

        // ymm(2k) / ymm(2k+1) <- int32 partial sums of code k, lanes 0-7 / 8-15.
        OPERATE_ROW_2(0, ymm0, ymm1);
        OPERATE_ROW_2(1, ymm2, ymm3);
        OPERATE_ROW_2(2, ymm4, ymm5);
        OPERATE_ROW_2(3, ymm6, ymm7);
        OPERATE_ROW_2(4, ymm8, ymm9);
        OPERATE_ROW_2(5, ymm10, ymm11);
        OPERATE_ROW_2(6, ymm12, ymm13);
        OPERATE_ROW_2(7, ymm14, ymm15);

        // ymm(2k) <- 8 partial sums of code k.
        ymm0 = _mm256_add_epi32(ymm0, ymm1);
        ymm2 = _mm256_add_epi32(ymm2, ymm3);
        ymm4 = _mm256_add_epi32(ymm4, ymm5);
        ymm6 = _mm256_add_epi32(ymm6, ymm7);
        ymm8 = _mm256_add_epi32(ymm8, ymm9);
        ymm10 = _mm256_add_epi32(ymm10, ymm11);
        ymm12 = _mm256_add_epi32(ymm12, ymm13);
        ymm14 = _mm256_add_epi32(ymm14, ymm15);

        // Pair code k with code k+4 so the 128-bit halves can be summed:
        // one vector gets {code k lanes 0-3, code k+4 lanes 4-7}, the other
        // {code k lanes 4-7, code k+4 lanes 0-3}.
        PAIR_AND_BLEND(ymm0, ymm8, ymm1, ymm9);
        PAIR_AND_BLEND(ymm2, ymm10, ymm3, ymm11);
        PAIR_AND_BLEND(ymm4, ymm12, ymm5, ymm13);
        PAIR_AND_BLEND(ymm6, ymm14, ymm7, ymm15);

        // ymm1/3/5/7 <- 4 partial sums each of codes {0,4}/{1,5}/{2,6}/{3,7}
        // (low 128 bits / high 128 bits).
        ymm1 = _mm256_add_epi32(ymm1, ymm9);
        ymm3 = _mm256_add_epi32(ymm3, ymm11);
        ymm5 = _mm256_add_epi32(ymm5, ymm13);
        ymm7 = _mm256_add_epi32(ymm7, ymm15);

        // Two rounds of in-lane horizontal adds leave the totals for codes
        // 0..7 in lanes 0..7, in order.
        ymm1 = _mm256_hadd_epi32(ymm1, ymm3);
        ymm5 = _mm256_hadd_epi32(ymm5, ymm7);

        ymm1 = _mm256_hadd_epi32(ymm1, ymm5);
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(output), ymm1);

        codes += 8 * 64;
        output += 8;
    }

    // Remaining n % 8 codes: one code at a time through the same maddubs/madd
    // pipeline, so results match the 8-wide path, then a full horizontal sum.
    for (size_t i = 0; i < left; ++i){
        __m512i row = _mm512_loadu_si512(codes + 64 * i);
        row = _mm512_maddubs_epi16(zmm30, row);
        row = _mm512_madd_epi16(row, zmm31);
        output[i] = _mm512_reduce_add_epi32(row);
    }
}


#define LOOP 10

/*
 * Benchmark: fill n random 64-byte codes and one query, then time LOOP calls
 * of avx_IP_distance_64_2 with the project Timer.
 */
int main(){
    const int d = 64;
    const int q = 1;
    const int n = 100000;

    std::mt19937 rng;
    // Default-constructed: uniform on [0, 1).
    std::uniform_real_distribution<> distrib;

    uint8_t *codes = new uint8_t[d * n];
    uint8_t *query = new uint8_t[d * q];

    int32_t *output = new int32_t[n];

    // Scale before truncating: int(distrib(rng)) on its own is always 0,
    // because distrib yields values in [0, 1).
    for (int i = 0; i < n; ++i){
        for (int j = 0; j < d; ++j){
            // codes[d*i+j] = j;
            codes[d * i + j] = static_cast<uint8_t>(int(distrib(rng) * 127));  // 0..126
        }
    }

    for (int i = 0; i < q; ++i){
        for (int j = 0; j < d; ++j){
            // query[d*i+j] = j;
            // -64..62; negative values wrap modulo 256 when stored as uint8_t.
            // NOTE(review): confirm this is the intended query encoding.
            query[d * i + j] = static_cast<uint8_t>(int(distrib(rng) * 127) - 64);
        }
    }

    Timer timer;
    timer.start();
    for (int i = 0; i < LOOP; ++i){
        avx_IP_distance_64_2(n, query, codes, output);
    }
    timer.end("Second type");

    delete[] output;
    delete[] query;
    delete[] codes;
    return 0;
}

When n = 10k, the measured time is 0.143917 ms.

When n = 100k, the measured time is 3.2002 ms, about 22 times longer for 10 times as many codes.

When n is less than 10k, the time increases roughly linearly with n.

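Each code is 64 bytes, so the codes array is about 640 KB at n = 10k but about 6.4 MB at n = 100k. I suspect it’s a caching problem (the larger array no longer fitting in cache), but I’m not sure.

Why does the time not increase linearly with n?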

0

There are 0 answers