SSE4和字符串函数

简而言之,SSE4.2加入了7个指令,CRC32、PCMPESTRI、PCMPESTRM、PCMPISTRI、PCMPISTRM、PCMPGTQ和POPCNT。
【未完待续】

SSE4指令集的字符串计算库

  1. __m128i _mm_load_si128 (__m128i *p) = MOVDQA
    负责将p位置的值加载到一个寄存器中(用返回的__m128i标记)。
    p必须是16bytes对齐的,否则应当使用_mm_loadu_si128
    根据Milo的博文_mm_loadu_si128能可能造成一定概率的崩溃。
    一般的解决方案是通过(ptr + 15) & ~15得到下一个16bytes对齐的位置。注意到,如果ptr本身就是对齐的,那么它会返回ptr本身。

  2. int _mm_cvtsi128_si32 (__m128i a) = MOVD
    负责将将低32位移到int中。

  3. __m128i _mm_cmpistrm (__m128i a, __m128i b, const int mode)
    封装了PCMPISTRM这个指令,这里最后的M表示Mask。
    按照mode对两组字符进行比较,其中mode参数包含a/b类型(byte或者word)、比较方式、返回方式

    1. SIDD_UBYTE_OPS
      a是pattern,对b中每个字符x,查看a中是否存在。结果rr[i]表示b[i]的结果(所以打印出来看起来是倒过来的)。
    2. _SIDD_CMP_EQUAL_EACH
      a和b逐比特比较。
    3. _SIDD_CMP_RANGES
      a是pattern,如azAZ表示从a到z和从A到Z的字符,对b中每个字符,查看是否在a的范围中。
    4. _SIDD_CMP_EQUAL_ORDERED
      在b中搜索a,1标记第一个位置。
    5. _SIDD_NEGATIVE_POLARITY
      这个按比特会翻转一下结果。
    6. _SIDD_UNIT_MASK_SIDD_BIT_MASK
      一般用_SIDD_BIT_MASK 就行了,根据stackoverflow_SIDD_UNIT_MASK返回16个bytes,而不是bits。。。所以没必要用这个是吧。。。
  4. __m128i _mm_cmpestrm (__m128i a, int la, __m128i b, int lb, const int mode)
    同_mm_cmpistrm,但可以指定长度lalb

  5. int _mm_cmpistri(__m128i a, __m128i b, const int mode)int _mm_cmpestri(__m128i a, int la, __m128i b, int lb, const int mode)
    这两个函数封装了PCMPISTRI这个指令,这里最后的I表示Index,用来返回最高位或者最低位的1的index。_SIDD_LEAST_SIGNIFICANT参数表示从最右起,_SIDD_MOST_SIGNIFICANT表示从最左起,要是没有就返回MaxSize。
    以下面的代码为例,对于_SIDD_LEAST_SIGNIFICANT返回1,对于_SIDD_MOST_SIGNIFICANT返回3。倘若将pat_str改为"c",那么返回值都是16,也就是MaxSize。

    1
    2
    3
    4
    5
    6
    static const char pat_str[] = "a";
    static const char test_str[] = "badab";
    const int i = _mm_cmpistri(pat_w, test_w,
    _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT
    );
    printf("%d\n", i);

    注意我们现在使用的是_SIDD_CMP_EQUAL_ANY,如果换成_SIDD_CMP_EQUAL_EACH的话就是逐位比较,其mask结果是111111111111100000,那么我们返回值分别是5和15。

  6. ACOSZ
    有没有很熟悉呢?这五个辅助函数被用来读取EFLAGS寄存器的值。这些指令一般在_mm_cmpistrm等四个函数之后调用,这样其功能相当于直接获取寄存器的值。有的编译器可能会在调用_mm_cmpistrm后再调用一次pcmpistri,这个看起来是冗余的。具体案例可以查看SoF
    int _mm_cmpistrz(__m128i a, __m128i b, const int mode)为例,这个指令用来检测b中是否存在\0

利用SSE4优化基础字符串函数

为什么strcmp不使用SSE4.2优化

在一些C库中strcmp并没有实现SSE优化。据stackoverflow,原因是需要首先知道字符串的长度,这就需要遍历一遍字符串。但如果已经遍历一遍了,就已经完成了非优化版的strcmp了,就没必要再用SSE2了。

但是SSE4.2指令集具有下列特性

  1. 不要求字符串对齐
  2. 能正确处理zero-terminated string和Pascal-style string
    所谓的Pascal-style就是第一个byte用来表示字符串长度。
  3. 可以对Unicode支付,signed/unsigned bytes使用
  4. 带四个聚合操作

实现strlen

简单解释下last_1_off,x是r的lowbit。这个函数表示最低的1是从右开始数第几位。

1
2
3
4
5
6
7
8
9
int last_1_off(unsigned int r) {
int ans = -1;
unsigned x = r & ((~r) + 1); // ~r + 1 == -n
while (x) {
x >>= 1;
ans++;
}
return ans;
}

下面是主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
inline int sse4_strlen(const char * s) {
int l = 0;
static const char pat_str[16] = "\0";
const __m128i pat_w = _mm_load_si128((const __m128i *)&pat_str[0]);
while (1) {
const __m128i test_w = _mm_load_si128((const __m128i *)&s[l]);
unsigned r = _mm_cvtsi128_si32(_mm_cmpistrm(pat_w, test_w,
_SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH | _SIDD_BIT_MASK
));
if (r == 0) {
l += 16;
}
else {
return l + last_1_off(r);
}
}
}

上面的代码在G++下编译可能出现段错误,这是由于对齐所致,此时应当使用_mm_loadu_si128来代替。

实现strcmp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
inline int sse4_strcmp(const char * a, const char * b) {
int l = 0;
while (1) {
const __m128i a_w = _mm_load_si128((const __m128i *)&a[l]);
const __m128i b_w = _mm_load_si128((const __m128i *)&b[l]);
unsigned r = _mm_cvtsi128_si32(_mm_cmpistrm(a_w, b_w,
_SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH | _SIDD_BIT_MASK | _SIDD_NEGATIVE_POLARITY
));
if (r == 0) {
// 两个字符串相同
if (a[l] == 0) {
// 两个字符串都结束,注意如果只有一个字符串结束,那么r不可能为0
return 0;
}
l += 16;
}
else {
l += last_1_off(r);
return a[l] - b[l] < 0 ? -1: 1;
}
}
}

实现strchr

Reference

  1. https://www.strchr.com/strcmp_and_strlen_using_sse_4.2