第 1 章 SIMD 和 SSE/AVX

SIMD (Single Instruction, Multiple Data),即单指令多数据,顾名思义,是通过一条指令对多条数据进行同时操作。据维基百科说,最早得到广泛应用的SIMD指令集是Intel的MMX指令集,共包含57条指令。MMX提供了8个64位的寄存器(MM0 - MM7),每个寄存器可以存放两个32位整数或4个16位整数或8个8位整数,寄存器中“打包”的多个数据可以通过一条指令同时处理,不再需要分成几次分别处理。

例如做8个8位整数加法(c[i] = a[i] + b[i]. i ∈ {0,1,…,7})。

标量算法流程:
i = 0
将a[i]读入寄存器0
将b[i]读入寄存器1
对寄存器0和1内的8位整数求和并将结果存储在寄存器0中
将寄存器0中的八位整数写入c[i]的内存位置
i = i+1
比较i和8
如果小于则重复以上步骤

矢量算法流程:
将a[0-7]读入矢量寄存器0
将b[0-7]读入矢量寄存器1
对矢量寄存器0和1按照8位一个整数同时求和,结果存储在矢量寄存器0中
将适量寄存器0中的64位数据写入c的内存位置

如上例,对于8位整数的加法,理论上最大可以有8倍的提速。

之后,SSE出现了,提供了8个128位寄存器(XMM0 - XMM7),并且有了处理浮点数的能力。可以同时处理两个双精度浮点数或四个单精度浮点数,或者同时处理四个32位整数或者八个16位整数又或者十六个8位整数。

再后来,又升级了AVX。AVX将SSE的每个128位寄存器扩展到256位,并增加了8个256寄存器。16个256位寄存器称作(YMM0 - YMM15)。再后来Intel又推出了AVX512,把YMM扩展到512位,又新增16个寄存器,共32个512位寄存器(ZMM0 - ZMM31)。(前几天Linus还怒斥了Intel的AVX512🤣)

x40服务器是支持AVX512的,但是本指南不介绍AVX512使用方法(主要是因为我也没用过)。但是原则上与AVX大同小异。而且,大概率您正在使用的桌面电脑或笔记本电脑也支持AVX指令集,但不大可能支持AVX512,因此本章的代码您可以在自己的电脑上运行/测试。

1.1 一个简单的程序

単純な馬鹿にありたい。——『日常』

为便于演示SIMD并比较速度,我们创建一个非常非常简单的程序:拿三个double数组,第一组乘上第二组再加上第三组,结果存储在第四个数组中。想必各位读者都能很快实现这样的算法。下面列出本指南使用的程序(后续章节会基于这个程序进行改造)。为了加大处理压力,使运行时间可以观测到,这里强行让程序算1,000,000遍。可以在输出中看到计算耗时在10s左右。在输出运行时间后,挑出一些结果来对照一下看看对不对(嗯,在这个程序里上面和下面一毛一样,怎么会不对呢!)

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

__attribute__ ((noinline))
void muladd(double* a, double* b, double* c,
            double* d, unsigned long long N){
    unsigned long long i;
    for(i = 0; i < N; i++){
        d[i] = a[i] * b[i] + c[i];
    }
}

int main(){
    double* a; 
    double* b; 
    double* c; 
    double* d;

    a = (double*)(malloc(8192*sizeof(double)));
    b = (double*)(malloc(8192*sizeof(double)));
    c = (double*)(malloc(8192*sizeof(double)));
    d = (double*)(malloc(8192*sizeof(double)));

    //Prepare data
    unsigned long long i;
    for(i = 0; i < 8192; i++){
        a[i] = (double)(rand()%2000) / 200.0;
        b[i] = (double)(rand()%2000) / 200.0;
        c[i] = ((double)i)/10000.0;
    }
    
    clock_t start, stop;
    double elapsed;
    start = clock();

    for(i = 0; i < 1000000; i++){
        muladd(a, b, c, d, 8192);
    }

    stop = clock();
    elapsed = (double)(stop-start) / CLOCKS_PER_SEC;
    printf("Elapsed time = %8.6f s\n", elapsed);
    for(i = 0; i < 8192; i++){
        if(i % 1001 == 0){
            printf("%5llu: %12.8f * %12.8f + %12.8f = %12.8f (%d)\n",
                   i, a[i], b[i], c[i], d[i], d[i]==a[i]*b[i]+c[i]);
        }
    }

    free(a);
    free(b);
    free(c);
    free(d);
}

关于程序随便提一句。__attribute__ ((noinline))的作用是告诉编译器不要对这个函数inline展开。因为我们的muladd函数比较简单,可能会被编译器优化以后直接在调用处原地展开。虽然这样也没什么不行,但是为了后续分析程序行为,这里加上了这个标志。

编译运行一下,可以看到运算过程花了20s左右。

1.2 我的第一个SIMD程序!

May the 4s be with you.

我们来使用AVX,一次算四个。要显式地使用AVX指令集,可以使用所谓“intrinsic”。这些intrinsic与CPU指令有直接的对应。可以参考Intel Intrinsics Guide。要使用这些intrinsic,需要#include <x86intrin.h>. 在编译的时候,还要加上-mavx

gcc -mavx -o muladd_simd muladd_simd.c

改造后的程序是这个样子:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <x86intrin.h>

__attribute__ ((noinline))
void muladd(double* a, double* b, double* c,
            double* d, unsigned long long N){
    unsigned long long i, j;
    unsigned long long M = N>>2;
    for(i = 0; i < M; i++){
        __m256d ymma = _mm256_load_pd(a+i*4);
        __m256d ymmb = _mm256_load_pd(b+i*4);
        __m256d ymmc = _mm256_load_pd(c+i*4);
        __m256d ymmd = _mm256_mul_pd(ymma, ymmb);
        ymmd = _mm256_add_pd(ymmd, ymmc);
        _mm256_store_pd(d+i*4,ymmd);
    }
    for(i = N-N%4; i < N; i++){
        d[i] = a[i] * b[i] + c[i];
    }
}

int main(){
    double* a; 
    double* b; 
    double* c; 
    double* d;

    a = (double*)(aligned_alloc(32,8192*sizeof(double)));
    b = (double*)(aligned_alloc(32,8192*sizeof(double)));
    c = (double*)(aligned_alloc(32,8192*sizeof(double)));
    d = (double*)(aligned_alloc(32,8192*sizeof(double)));

   //Prepare data
    unsigned long long i;
    for(i = 0; i < 8192; i++){
        a[i] = (double)(rand()%2000) / 200.0;
        b[i] = (double)(rand()%2000) / 200.0;
        c[i] = ((double)i)/10000.0;
    }
    
    clock_t start, stop;
    double elapsed;
    start = clock();
    
    for(i = 0; i < 1000000; i++){
        muladd(a, b, c, d, 8192);
    }

    stop = clock();
    elapsed = (double)(stop-start) / CLOCKS_PER_SEC;
    printf("Elapsed time = %8.6f s\n", elapsed);
    for(i = 0; i < 8192; i++){
        if(i % 1001 == 0){
            printf("%5llu: %16.8f * %16.8f + %16.8f = %16.8f (%d)\n",
                   i, a[i], b[i], c[i], d[i], d[i]==a[i]*b[i]+c[i]);
        }
    }

    free(a);
    free(b);
    free(c);
    free(d);
}

关于这段程序又要提一句。malloc(8192*sizeof(double))换成了aligned_alloc(32, 8192*sizeof(double)). 这是因为_mm256_load_pd要求内存是对齐到256bit的(YMM寄存器的宽度是256bit).aligned_alloc函数可以实现开辟对齐到指定地址的内存空间,第一个参数是对齐方式,单位是字节(Byte)(32Byte × 8bit/byte = 256 bit)。第二个参数是内存大小。要注意第二个参数必须是第一个的整数倍。
另外,也可以直接malloc,而把_mm256_load_pd替换为_mm256_load_upd. u代表“unaligned”。

以下太长不看
读取双精度浮点数据到YMM有两个指令,VMOVAPD 和 VMOVUPD。VMOVAPD要求内存地址是对齐到256位的,而VMOVUPD无此要求。据说在早期的支持AVX的cpu上VMOVAPD要快一些,似乎在新一些的平台上性能已经相当了。(但我没有详细考证过。)以及,_mm256_load_upd似乎并不会被翻译为vmovupd,而是变成了两条指令,大致是先读入两个double,再把另外两个插入进来。emm…

赶紧编译运行一下,这回可快了,只要……嗯?要16s?4个double同时算,这提升也太小了吧?咱们看一下编译出来的内容:

$ gdb muladd_simd
  (gdb) disassemble muladd
  Dump of assembler code for function muladd:
   0x0000000000400657 <+0>:     lea    0x8(%rsp),%r10
   0x000000000040065c <+5>:     and    $0xffffffffffffffe0,%rsp
   0x0000000000400660 <+9>:     pushq  -0x8(%r10)
   0x0000000000400664 <+13>:    push   %rbp
   0x0000000000400665 <+14>:    mov    %rsp,%rbp
   0x0000000000400668 <+17>:    push   %r10
   0x000000000040066a <+19>:    sub    $0x150,%rsp
   0x0000000000400671 <+26>:    mov    %rdi,-0x198(%rbp)
   0x0000000000400678 <+33>:    mov    %rsi,-0x1a0(%rbp)
   0x000000000040067f <+40>:    mov    %rdx,-0x1a8(%rbp)
   0x0000000000400686 <+47>:    mov    %rcx,-0x1b0(%rbp)
   0x000000000040068d <+54>:    mov    %r8,-0x1b8(%rbp)
   0x0000000000400694 <+61>:    mov    -0x1b8(%rbp),%rax
   0x000000000040069b <+68>:    shr    $0x2,%rax
   0x000000000040069f <+72>:    mov    %rax,-0x20(%rbp)
   0x00000000004006a3 <+76>:    movq   $0x0,-0x18(%rbp)
   0x00000000004006ab <+84>:    jmpq   0x4007e5 <muladd+398>
   0x00000000004006b0 <+89>:    mov    -0x18(%rbp),%rax
   0x00000000004006b4 <+93>:    shl    $0x5,%rax
   0x00000000004006b8 <+97>:    mov    %rax,%rdx
   0x00000000004006bb <+100>:   mov    -0x198(%rbp),%rax
   0x00000000004006c2 <+107>:   add    %rdx,%rax
   0x00000000004006c5 <+110>:   mov    %rax,-0x188(%rbp)
   0x00000000004006cc <+117>:   mov    -0x188(%rbp),%rax
   0x00000000004006d3 <+124>:   vmovapd (%rax),%ymm0
   0x00000000004006d7 <+128>:   vmovapd %ymm0,-0x50(%rbp)
   0x00000000004006dc <+133>:   mov    -0x18(%rbp),%rax
   0x00000000004006e0 <+137>:   shl    $0x5,%rax
   0x00000000004006e4 <+141>:   mov    %rax,%rdx
   0x00000000004006e7 <+144>:   mov    -0x1a0(%rbp),%rax
   0x00000000004006ee <+151>:   add    %rdx,%rax
   0x00000000004006f1 <+154>:   mov    %rax,-0x180(%rbp)
   0x00000000004006f8 <+161>:   mov    -0x180(%rbp),%rax
   0x00000000004006ff <+168>:   vmovapd (%rax),%ymm0
   0x0000000000400703 <+172>:   vmovapd %ymm0,-0x70(%rbp)
   0x0000000000400708 <+177>:   mov    -0x18(%rbp),%rax
   0x000000000040070c <+181>:   shl    $0x5,%rax
   0x0000000000400710 <+185>:   mov    %rax,%rdx
   0x0000000000400713 <+188>:   mov    -0x1a8(%rbp),%rax
   0x000000000040071a <+195>:   add    %rdx,%rax
   0x000000000040071d <+198>:   mov    %rax,-0x178(%rbp)
   0x0000000000400724 <+205>:   mov    -0x178(%rbp),%rax
   0x000000000040072b <+212>:   vmovapd (%rax),%ymm0
   0x000000000040072f <+216>:   vmovapd %ymm0,-0x90(%rbp)
   0x0000000000400737 <+224>:   vmovapd -0x50(%rbp),%ymm0
   0x000000000040073c <+229>:   vmovapd %ymm0,-0x150(%rbp)
   0x0000000000400744 <+237>:   vmovapd -0x70(%rbp),%ymm0
   0x0000000000400749 <+242>:   vmovapd %ymm0,-0x170(%rbp)
   0x0000000000400751 <+250>:   vmovapd -0x150(%rbp),%ymm0
   0x0000000000400759 <+258>:   vmulpd -0x170(%rbp),%ymm0,%ymm0
   0x0000000000400761 <+266>:   vmovapd %ymm0,-0xb0(%rbp)
   0x0000000000400769 <+274>:   vmovapd -0xb0(%rbp),%ymm0
   0x0000000000400771 <+282>:   vmovapd %ymm0,-0x110(%rbp)
   0x0000000000400779 <+290>:   vmovapd -0x90(%rbp),%ymm0
   0x0000000000400781 <+298>:   vmovapd %ymm0,-0x130(%rbp)
   0x0000000000400789 <+306>:   vmovapd -0x110(%rbp),%ymm0
   0x0000000000400791 <+314>:   vaddpd -0x130(%rbp),%ymm0,%ymm0
   0x0000000000400799 <+322>:   vmovapd %ymm0,-0xb0(%rbp)
   0x00000000004007a1 <+330>:   mov    -0x18(%rbp),%rax
   0x00000000004007a5 <+334>:   shl    $0x5,%rax
   0x00000000004007a9 <+338>:   mov    %rax,%rdx
   0x00000000004007ac <+341>:   mov    -0x1b0(%rbp),%rax
   0x00000000004007b3 <+348>:   add    %rdx,%rax
   0x00000000004007b6 <+351>:   mov    %rax,-0xb8(%rbp)
   0x00000000004007bd <+358>:   vmovapd -0xb0(%rbp),%ymm0
   0x00000000004007c5 <+366>:   vmovapd %ymm0,-0xf0(%rbp)
   0x00000000004007cd <+374>:   mov    -0xb8(%rbp),%rax
   0x00000000004007d4 <+381>:   vmovapd -0xf0(%rbp),%ymm0
   0x00000000004007dc <+389>:   vmovapd %ymm0,(%rax)
   0x00000000004007e0 <+393>:   addq   $0x1,-0x18(%rbp)
   0x00000000004007e5 <+398>:   mov    -0x18(%rbp),%rax
   0x00000000004007e9 <+402>:   cmp    -0x20(%rbp),%rax
   0x00000000004007ed <+406>:   jb     0x4006b0 <muladd+89>
   0x00000000004007f3 <+412>:   mov    -0x1b8(%rbp),%rax
   0x00000000004007fa <+419>:   and    $0xfffffffffffffffc,%rax
   0x00000000004007fe <+423>:   mov    %rax,-0x18(%rbp)
   0x0000000000400802 <+427>:   jmp    0x400879 <muladd+546>
   0x0000000000400804 <+429>:   mov    -0x18(%rbp),%rax
   0x0000000000400808 <+433>:   lea    0x0(,%rax,8),%rdx
   0x0000000000400810 <+441>:   mov    -0x198(%rbp),%rax
   0x0000000000400817 <+448>:   add    %rdx,%rax
   0x000000000040081a <+451>:   vmovsd (%rax),%xmm1
   0x000000000040081e <+455>:   mov    -0x18(%rbp),%rax
   0x0000000000400822 <+459>:   lea    0x0(,%rax,8),%rdx
   0x000000000040082a <+467>:   mov    -0x1a0(%rbp),%rax
   0x0000000000400831 <+474>:   add    %rdx,%rax
   0x0000000000400834 <+477>:   vmovsd (%rax),%xmm0
   0x0000000000400838 <+481>:   vmulsd %xmm0,%xmm1,%xmm0
   0x000000000040083c <+485>:   mov    -0x18(%rbp),%rax
   0x0000000000400840 <+489>:   lea    0x0(,%rax,8),%rdx
   0x0000000000400848 <+497>:   mov    -0x1a8(%rbp),%rax
   0x000000000040084f <+504>:   add    %rdx,%rax
   0x0000000000400852 <+507>:   vmovsd (%rax),%xmm1
   0x0000000000400856 <+511>:   mov    -0x18(%rbp),%rax
   0x000000000040085a <+515>:   lea    0x0(,%rax,8),%rdx
   0x0000000000400862 <+523>:   mov    -0x1b0(%rbp),%rax
   0x0000000000400869 <+530>:   add    %rdx,%rax
   0x000000000040086c <+533>:   vaddsd %xmm1,%xmm0,%xmm0
   0x0000000000400870 <+537>:   vmovsd %xmm0,(%rax)
   0x0000000000400874 <+541>:   addq   $0x1,-0x18(%rbp)
   0x0000000000400879 <+546>:   mov    -0x18(%rbp),%rax
   0x000000000040087d <+550>:   cmp    -0x1b8(%rbp),%rax
   0x0000000000400884 <+557>:   jb     0x400804 <muladd+429>
   0x000000000040088a <+563>:   nop
   0x000000000040088b <+564>:   add    $0x150,%rsp
   0x0000000000400892 <+571>:   pop    %r10
   0x0000000000400894 <+573>:   pop    %rbp
   0x0000000000400895 <+574>:   lea    -0x8(%r10),%rsp
   0x0000000000400899 <+578>:   retq
End of assembler dump.

对比一下标量版得到的汇编:

$ gdb muladd
  (gdb) disassemble muladd
  Dump of assembler code for function muladd:
   0x0000000000400637 <+0>:     push   %rbp
   0x0000000000400638 <+1>:     mov    %rsp,%rbp
   0x000000000040063b <+4>:     mov    %rdi,-0x18(%rbp)
   0x000000000040063f <+8>:     mov    %rsi,-0x20(%rbp)
   0x0000000000400643 <+12>:    mov    %rdx,-0x28(%rbp)
   0x0000000000400647 <+16>:    mov    %rcx,-0x30(%rbp)
   0x000000000040064b <+20>:    mov    %r8,-0x38(%rbp)
   0x000000000040064f <+24>:    movq   $0x0,-0x8(%rbp)
   0x0000000000400657 <+32>:    jmp    0x4006c2 <muladd+139>
   0x0000000000400659 <+34>:    mov    -0x8(%rbp),%rax
   0x000000000040065d <+38>:    lea    0x0(,%rax,8),%rdx
   0x0000000000400665 <+46>:    mov    -0x18(%rbp),%rax
   0x0000000000400669 <+50>:    add    %rdx,%rax
   0x000000000040066c <+53>:    movsd  (%rax),%xmm1
   0x0000000000400670 <+57>:    mov    -0x8(%rbp),%rax
   0x0000000000400674 <+61>:    lea    0x0(,%rax,8),%rdx
   0x000000000040067c <+69>:    mov    -0x20(%rbp),%rax
   0x0000000000400680 <+73>:    add    %rdx,%rax
   0x0000000000400683 <+76>:    movsd  (%rax),%xmm0
   0x0000000000400687 <+80>:    mulsd  %xmm1,%xmm0
   0x000000000040068b <+84>:    mov    -0x8(%rbp),%rax
   0x000000000040068f <+88>:    lea    0x0(,%rax,8),%rdx
   0x0000000000400697 <+96>:    mov    -0x28(%rbp),%rax
   0x000000000040069b <+100>:   add    %rdx,%rax
   0x000000000040069e <+103>:   movsd  (%rax),%xmm1
   0x00000000004006a2 <+107>:   mov    -0x8(%rbp),%rax
   0x00000000004006a6 <+111>:   lea    0x0(,%rax,8),%rdx
   0x00000000004006ae <+119>:   mov    -0x30(%rbp),%rax
   0x00000000004006b2 <+123>:   add    %rdx,%rax
   0x00000000004006b5 <+126>:   addsd  %xmm1,%xmm0
   0x00000000004006b9 <+130>:   movsd  %xmm0,(%rax)
   0x00000000004006bd <+134>:   addq   $0x1,-0x8(%rbp)
   0x00000000004006c2 <+139>:   mov    -0x8(%rbp),%rax
   0x00000000004006c6 <+143>:   cmp    -0x38(%rbp),%rax
   0x00000000004006ca <+147>:   jb     0x400659 <muladd+34>
   0x00000000004006cc <+149>:   nop
   0x00000000004006cd <+150>:   pop    %rbp
   0x00000000004006ce <+151>:   retq
End of assembler dump.

差好多!仔细看一下,SIMD版本中<+89>到<+128>这部分。简单解释一下:

<位置> 汇编 含义
< +89> mov -0x18(%rbp),%rax 将循环变量(i) 读入寄存器rax
< +93> shl $0x5,%rax rax中的数左移5位(相当于*32).
< +97> mov %rax,%rdx 将rax中的数拷贝到rdx寄存器中
<+100> mov -0x198(%rbp),%rax 将数组A的地址读入寄存器rax
<+107> add %rdx,%rax rax = rax + rdx
<+124> vmovapd (%rax),%ymm0 从rax寄存器所示地址处读取256位到YMM0
<+128> vmovapd %ymm0,-0x50(%rbp) 将YMM0中的数据保存在栈上的一个位置

如果您曾经学习过《汇编语言程序设计》或者《微机原理》等课程,可能会对上述汇编代码有疑问。这是由于通常微机原理课程使用的是Intel 格式的汇编(ins dst, src),而GCC使用AT&T格式(ins src, dst),一般来讲操作数的顺序和Intel格式都是反过来的。
如果您并不懂汇编,也可以直接看结论。如果对汇编感兴趣,可以参考附录。

简单来说,就是每一次循环把需要的值取到寄存器中,再存到栈上(可以理解为函数自有的内存)。三个数据都存好后,再把他们依次从栈里取到不同的YMM寄存器里,做求积和求和,再存到相应位置。(这只是大致过程,实际看一看,它干的傻事还不少)。

仔细想想,其实编译器这么做是有道理的。在程序里,我们写了__m256d ymma = _mm256_load_pd(a+i*4);,这意味着我们要有一个ymma变量。编译器并不知道我们会不会对这个变量做 ·一·些·奇·怪·的·事· 因此便把它又写进了栈里。那么有没有办法让编译器明白我们不会乱动ymma呢?有!

1.3 搞个寄存器变量

我买几个橘子去。你就在此地,不要走动。——《背影》

只要在声明变量的时候,在类型前面加上register关键字,编译器就明白这个东西应该常驻在寄存器里面,就不会来回读写内存了。需要注意不能对寄存器变量取地址!顺便,这个浓眉大眼的循环次数也可以放到寄存器里面的样子,把这些东西加上register关键字:

    register unsigned long long M = N>>2;
    for(i = 0; i < M; i++){
        register __m256d ymma = _mm256_load_pd(a+i*4);
        register __m256d ymmb = _mm256_load_pd(b+i*4);
        register __m256d ymmc = _mm256_load_pd(c+i*4);
        register __m256d ymmd = _mm256_mul_pd(ymma, ymmb);
        ymmd = _mm256_add_pd(ymmd, ymmc);
        _mm256_store_pd(d+i*4,ymmd);
    }

再来运行下。别忘了编译时加-mavx
大概12秒左右。已经比最原始的版本快了近一倍了!四舍五入一个亿啊!

在新标准的c++中,据说register关键字已经没有意义了,加了register编译器也不一定用寄存器,不加也不一定就不用。但是在我用c++测试的时候,发现还是有用的。
在g++里增加-std=c++11或者-std=c++14都可以顺利编译,运行速度一致。如果加-std=c++17则会报几个warning,说是新标准不许regester,但是还是会比不加register关键字要快。(啊哈,你个c++17,嘴上说着warning,身体倒是很老实嘛!)

1.4 我比编译器聪明系列

コンピューターも、天国へいけるかな。——『Lost Universe』

还能不能再快一点?我想要的四倍提速呢?
下面我们使用c语言的内联汇编做这件事。

__attribute__ ((noinline))
void muladd(double* a, double* b, double* c,
            double* d, unsigned long long N){
    unsigned long long i;
    __asm__ __volatile__(
            "movq %0, %%rax \n\t"
            "movq %1, %%rbx \n\t"
            "movq %2, %%rcx \n\t"
            "movq %3, %%rdx \n\t"
            "movq %4, %%r8  \n\t"
            "shr  $2, %%r8  \n\t"
            "movq $0, %%r9  \n\t"
            "jmp  .check_%= \n\t"
            ".loop_%=:         \n\t"
            "shl $2, %%r9   \n\t"
            "vmovupd (%%rax, %%r9, 8), %%ymm0 \n\t"
            "vmovupd (%%rbx, %%r9, 8), %%ymm1 \n\t"
            "vmovupd (%%rcx, %%r9, 8), %%ymm2 \n\t"
            "vmulpd %%ymm0, %%ymm1, %%ymm3    \n\t"
            "vaddpd %%ymm2, %%ymm3, %%ymm3    \n\t"
            "vmovupd %%ymm3, (%%rdx, %%r9, 8) \n\t"
            "shr $2, %%r9                  \n\t"
            "add $1, %%r9                  \n\t"
            ".check_%=:                    \n\t"
            "cmpq %%r8, %%r9               \n\t"
            "jl .loop_%=                   \n\t"
            :
            :"m"(a), "m"(b), "m"(c), "m"(d), "m"(N)
            :"%rax", "%rbx", "%rcx", "%rdx", "%r8", "%r9",
             "%ymm0", "%ymm1", "%ymm2", "%ymm3", "memory"
            );
    if(N%4!=0){
        for(i = N-N%4; i<N; i++){
            d[i] = a[i]*b[i]+c[i];
        }
    }
}

关于内联汇编的介绍,请参见附录。这里用一段汇编代替了前面的循环部分。当N不是4的整数倍时余下的1-3个运算使用普通c代码,并不会对性能产生太大影响。以及,这次编译不需要-mavx了。

编译,运行一下。不到4s!达到了预期的速度。

当然,也可以找一下AVX512指令怎么写(其实跟AVX是基本一样的),获得进一步加速。

1.5 编译器比我聪明系列

Doing nothing often leads to the very best something.——Winnie-the-Pooh

虽然这里手写了汇编,但是在一般的程序里不建议这么做,除非有信心做得好并且能加速。

实测,gcc -mavx -O3 -o muladd muladd.c编译出来的程序与上述内联汇编的程序速度相差无几。在-O3优化下gcc是有矢量化优化的,也是把运算优化到了四个一组同时做。

同理,gcc -mavx512f -O3 -o muladd muladd.c会进行AVX512矢量化,八个一组,再快一倍(可能不到点一倍)。

此外,使用intrinsic的同时,如果打开-O2优化,速度也会很快,甚至比手写汇编快那么一点。看一下反汇编,是少了一些取数到寄存器的操作(好吧我承认我学艺不精)。但是-O2优化本身是不进行矢量化的,也就是说在-O2优化下如果使用我们最开始的muladd.c的话是利用不到SIMD来加速的。因此用-O2配合intrinsic可能是一个好的选择。

顺便提一下,无论是汇编还是intrinsic,可读性都比较差,最好多加一些注释。

1.6 本章小结

SIMD大概是粒度最小的并行了吧?如同上面的例子,适当使用已经可以极大地提高程序运行速度了。

要注意一点,这里面“矢量运算”只是同时做好几个运算地意思,如果你想真的把它当作矢量,甚至想求个外积,那就超出处理器的能力范围了。只能把外积运算分解成加减乘除再进行计算。

这一章到这里就应该结束了,但我觉得阅读指南的你脑中一定充满了问号。所以在这里尝试自问自答一下。


问:为什么要做1,000,000次8192维数组的运算,而不直接用8,192,000,000维的数组?

答:是的。在我制作这份指南之前也是这么想的。但是发现SIMD以后速度并没有提高。经查,发现原理在于内存带宽吃满了。虽然处理器可以算得飞快,但是内存速度跟不上导致速度被限制了,因此矢量算法和标量算法速度相近。而减小数据规模以后,8192个double都会存储在CPU的高速缓存中,这样就避免了内存带宽的限制,从而可以比较向量和标量指令的速度差异。
在实际使用中,对数据的处理显然不只是先求积再求和这种简单运算。随着运算变得复杂,当处理数据的时间达到甚至超过内存存取用时的时候,SIMD的提速效果自然就可以体现出来了。

此外,\(8,192,000,000 double \times 8 Byte/double = 65536 \times 10^{6} Byte \approx 64GB\)吃内存稍微有点大。


问:既然gcc可以优化得这么好,搞这么复杂就没有意义了吧?

答:这可不好说了。在很久以前,听说gcc的-O3有可能会搞出莫名其妙的bug。虽然我没遇到过这种情况,虽然这种情况肯定非常少见,但是总的来说给人一种不太可靠的感觉。此外,纯凭个人感觉,可能gcc的优化力量也是有限的,如果计算过程比较长比较复杂,可能gcc并不一定能选出最合适的优化方法(可能会找不到数据/变量之间的关联?)。还有,如果你的程序不使用C/C++的话,还有一定的可能你的编译器并不支持AVX,这时就要自己动手丰衣足食(当然也可以用C/C++搞一个动态链接库)。另外-O2不会自动进行矢量化,所以至少intrinsic在-O2下还是有意义的。不开优化开关的话,还是要汇编才快一些。具体怎么写代码用什么优化还是要自己取舍。


问:作者你好厉害哦!咋啥都懂?

答:正常小朋友一般问不出来这种问题。

1.7 补充内容

在这里补充一些比较常用(大概)的AVX指令和对应的Intrinsic。(指令都是AT&T格式)

Intrinsic 指令 描述
_mm256_loadu_pd VMOVUPD m256, ymm 将内存中连续的4个double读入ymm
_mm256_storeu_pd VMOVUPD ymm, m256 将ymm寄存器中的4个double写入连续内存
_mm256_add_pd VADDPD ymm1, ymm2, ymm3 ymm3 = ymm2 + ymm1
_mm256_sub_pd VSUBPD ymm1, ymm2, ymm3 ymm3 = ymm2 - ymm1
_mm256_mul_pd VMULPD ymm1, ymm2, ymm3 ymm3 = ymm2 * ymm1
_mm256_div_pd VDIVPD ymm1, ymm2, ymm3 ymm3 = ymm2 / ymm1
_mm256_permute4x64_pd VPERMPD imm8, ymm1, ymm2 ymm1按照imm8所述重排入ymm2

解释一下除法和减法反过来的问题。前面提到AT&T汇编和Intel汇编是反向的。在MOV之类的指令中,AT&T格式看起来和谐一点,结果就是在减法和除法的地方反过来了。

关于permute,imm8是一个立即数,每两位为一组,表示ymm中的第某个数。

例:
ymm1 = (a, b, c, d)
VPEMPD $0xc5, ymm1, ymm2
=> ymm2 = (b, b, a, d)

是这样的,0xc5 = 11 00 01 01 (b). 00对应a,01对应b,10对应c,11对应d。 而从MSB到LSB的顺序是(3,2,1,0)的顺序。因此目标的第0个double是根据最低两位01取数据,取到b; 第一个同理,01->b;第二个是00,取a;第三个是11,取d。另外,如果内存里连续的double是(a,b,c,d)的话,VMOVUPD取进ymm寄存器的值就是(a,b,c,d).如果参考一些汇编手册,可能会发现手册里写的是反过来的(d,c,b,a).这只是表示方法不同。在汇编手册里习惯按照低位在右高位在左的顺序表示。