Modulo by a Constant
最近做 PE 的时候也顺便在玩 Rust。我在研究 Rust 编译出的 ASM 的时候,发现一段有趣的代码:
#[inline(never)]
fn opt_mod(x: u64) -> u64 {
x % 1000_000_007
}
编译出的 ASM 长这样:
playground::opt_mod:
movabsq $-8543223828751151131, %rcx
movq %rdi, %rax
mulq %rcx
shrq $29, %rdx
imulq $1000000007, %rdx, %rax
subq %rax, %rdi
movq %rdi, %rax
retq
我震惊的发现,这里面居然没有除法、取模?有点意思……
我猜这里的 mulq %rcx 是会溢出的,最后高 64 位进入 %rdx。如果输入是 ,那么输出为那么给出的结果为
其中 (由于这是个 u64,所以要加上 ),。这里面的那个取整很像在求 ,但是巧妙的避免避免除法,改用一次乘法和一次位移来实现。另一件有趣的事是,它非常自信这样得出的结果为 ,不会 offset-by-1,因为这里面甚至没有判断 是否大于 或小于 0,如果大于 或小于 0,应该再来一次减法或加法。看起来这后面还藏着一些数学,我们来研究一下。
Disclaimer: 我对汇编几乎不了解,理解基本上是靠猜。我也不了解每个指令的时间、CPU 的流水线作业、分支预测等因素对齐性能的影响。这里贴出汇编代码只是为了确认没有奇怪的操作。
Notations
这里我们约定,我们想计算的是 ,其中 都是 u64 内的无符号整形。由于对于同一个 ,我们有很多个 需要做,所以我们想对 做预处理,看能不能比直接用 div 这种汇编指令更快。 定义为常数 。为方便处理,我们假定 不是 2 的幂。
在这篇文章中, 运算优先级低于加减乘除取负等运算,但高于括号,且 的结果均为正数,即 。
Algorithm 1.0
很显然,方程 \eqref{eq
} 成立的充要条件是我们将不等号最右边缩到 :
注意到不等式的三项都是关于 的线性函数,所以这个不等式在区间 上成立等价于在区间两端成立。当 时,原式显然成立,所以我们只需关注 的情况:
将等式最后边再次放缩成 后可得
此时可以看出, 这个不等式在
时有解 ,否则无解。带进去算一下,Rust 选择的这组 满足要求。注意到 不会溢出 u128 ,故 %rdx 这里也不会溢出,难怪如此自信。
我就用 Rust 写了一下:
fn algo1(n: u64) -> u64 {
const MUL: u64 = 9903520244958400485;
const SHR: u32 = 29;
const D: u64 = 1000000007;
let prod_hi = (((x as u128) * (MUL as u128)) >> 64) as u64;
let q = (prod_hi >> SHR);
n - q * D
}
得到的汇编为:(和一开始的几乎一样)
playground::Algo1::modulo:
movabsq $-8543223828751151131, %rcx
movq %rdi, %rax
mulq %rcx
shrq $29, %rdx
imulq $-1000000007, %rdx, %rax
addq %rdi, %rax
retq
仔细看一下,其实还是有小小不同的……这里面只有 7 个指令,Rust 编译器里面有 8 个,区别在于 algo1 里面都是在 %rax 上操作,最后直接返回了 %rax,而 Rust 编译器里都是在 %rdi 里面操作,最后再 mov %rdi, %rdx。为此我还特意去查了一下 Calling convention(之前学的忘光了),好像是 %rdi 里面存参数,返回值在 %rax 里,而且 %rax/%rdi 以及很多寄存器都是 callee 可以随便改的……
要计算这组 的话,直接枚举 就行了,反正也不大……实验中也发现,Rust 给出的就是最小的 。现在问题来了:我们希望 ,这样才能在一个 u64 内存下 。如果找不到这样的 怎么办?枚举了一下 里面所有奇数(从前的分析来看,这和奇偶数没啥关系),大概有 70% 的数是能找到 的,但是剩下的 30% 也不少。
我们还可以观察到:b = b_\max =\lceil \log_2 p\rceil 时,一定能满足 \eqref{eq
},但是此时的 一定存不进一个 u64,而且由于 ,可得 就多了那么一个 bit,真是尴尬……Algorithm 1.1
第一种方法,我们干脆再 relax \eqref{eq
} 。如果我们能满足的话,那么能保证
也就是说,我们能求得一个大概(offset-by-1)的解,带入 \eqref{eq
} 后求得的 就可能是负数,还需要一个 if 来判断。注意 是用一个 u64 来存,判断 正负需要特殊的技巧。事实上,AtCoder 里面的 modint 就是这么处理的,不过它用带符号的数。
Algorithm 1.2
另一种方法,我们再看看 Rust 咋整的……我挑了一个 来看看 Rust 生成的 ASM 是什么样:
playground::opt_mod:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
movq %rdi, %rax
subq %rdx, %rax
shrq %rax
addq %rdx, %rax
shrq $29, %rax
imulq $1000000093, %rax, %rax
subq %rax, %rdi
movq %rdi, %rax
retq
对比一下之前的,这里多了四行:L5-8,我们去看看这到底是在做啥……还是令 ,,于是他算的东西为
原来如此……我们之前的问题不是 出了 u64 吗,但是之前也提到了,必定存在一个 满足要求,那我们就令 ,其中 就可以用一个 u64 存下来了。这里有地方需要注意:虽然 在数学上等于 ,但是后者在 大的时候有可能溢出,前者不会。
对比一下 Algorithm 1.1,这里 Rust 用 4 ops 当做一个 if (以及里面的 subq),如果用 y -= (y >= p) as u64 * p 的话多一个乘法,具体会快多少就不知道了……
这里我也给个 Rust 代码:
fn algo1_2(n: u64) -> u64 {
const MUL: u64 = 1360294712801925637;
const SHR: u32 = 29;
const D: u64 = 1000000093;
let hi = (((n as u128) * (MUL as u128)) >> 64) as u64;
let q = (((n - hi) >> 1) + hi) >> SHR;
n - q * D
}
得到的汇编为也和之前的几乎一样:
playground::algo1_2:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
movq %rdi, %rax
subq %rdx, %rax
shrq %rax
addq %rdx, %rax
shrq $29, %rax
imulq $-1000000093, %rax, %rax
addq %rdi, %rax
retq
Algorithm 2
如果还坚持用等式 \eqref{eq
} 来优化,有点优化不动了。这里我们另辟蹊径:我们不再要求 \eqref{eq} 成立,而是要求成立。我们继续把右边放缩成 ,然后用之前同样的 trick,所有项都是线性的,则 内满足要求等价于区间两端满足要求:
注意到前者第二个不等号等价于后者第二个不等号,故可以合并起来,化简为
可以看出, 在
时存在解 。注意到这个条件和 \eqref{eq
} 很像。事实上,我们可以证明,当 取 时,两者至少有一个满足:首先我们有:对于任意 ,均有又有对于任意 均有 ,故有
所以 和 中至少有一个不超过 。
接下来便不难想到,这个算法可以和 Algorithm 1.0 互补:若 Algorithm 1.0 找不到合适的 ,则 Algorithm 2 一定能找到合适的 。剩下的问题就是怎么有效的实现 Algorithm 2 了。
Algorithm 2 有一个问题使得我们无法直接计算 : 可能会溢出 u64。这也就是我们即将解决的问题。注意到 ,所以 不会溢出 u128。
Algorithm 2.1
第一种方法,我们使用乘法分配律 ,将 拆成 。
fn algo2_1(n: u64) -> u64 {
const MUL: u64 = 9903519393255738626;
const SHR: u32 = 29;
const D: u64 = 1000000093;
let prod = (n as u128) * (MUL as u128) + (MUL as u128);
let prod_hi = (prod >> 64) as u64;
let q = prod_hi >> SHR;
n - q * D
}
得到的汇编为:
playground::algo2_1:
movabsq $-8543224680453812990, %rcx
movq %rdi, %rax
mulq %rcx
addq %rcx, %rax
adcq $0, %rdx
shrq $29, %rdx
imulq $-1000000093, %rdx, %rax
addq %rdi, %rax
retq
Algorithm 2.2
第二种方法,我们注意到, 的溢出只会发生在 这种情况下,而这种情况下,由于 不是 2 的幂,故有 ,即
也就是说,这个时候我用 还是 是没有区别的,那我不加不就好了……所以 Algorithm 2.2 很简单,就是将 \eqref{eq
} 换成这里的 听说可以直接两个汇编指令搞定,但是我搞不出来,只能手写汇编了……以下便是代码:
fn algo2_2(n: u64) -> u64 {
const MUL: u64 = 9903519393255738626;
const SHR: u32 = 29;
const D: u64 = 1000000093;
let mut saturated = n;
unsafe {
asm!(
"add {0}, 1",
"sbb {0}, 0",
inout(reg) saturated,
options(nostack),
);
}
let prod = (saturated as u128) * (MUL as u128);
let prod_hi = (prod >> 64) as u64;
let q = prod_hi >> SHR;
n - q * D
}
汇编代码为:
playground::algo2_2:
movq %rdi, %rax
addq $1, %rax
sbbq $0, %rax
movabsq $-8543224680453812990, %rcx
mulq %rcx
shrq $29, %rdx
imulq $-1000000093, %rdx, %rax
addq %rdi, %rax
retq
比较一下 Algorithm 1.2 和 Algorithm 2.2,可以发现 Algorithm 2.2 少用了几个寄存器,而且少几个 op……algo2_2 如果不加 L13 ,就会多出一个 pushq %rax 和 popq %rcx,我也不知道这是啥,看起来像是保存寄存器,但是不知道为啥是 popq %rcx,希望有懂的人告诉我这是在干啥……
我试图用 n.saturating_add(1) 但是汇编出来的代码不是基于 sbb 而是 cmovne。
incq %rcx
movq $-1, %rax
cmovneq %rcx, %rax
Algorithm 3
以上算法都是基于 \eqref{eq
},先算了 后再算 。能不能绕过 直接算 呢?这篇 paper 提供了一个思路。在 Algorithm 1.0 里面,如果我们找到了满足条件的 ,我们其实可以用来干一些事:由于
故
搞定。注意到等式右边上的数都很大,为 级。一般情况下 ,则上面的数有 这么大,所以这个算法对 u64 内的数并不实用,只适用于 u32 以内的数……作者提供的代码也是针对 u32 的……
这代码也不是很好写。我这里取 ,那么我们需要做 30 位 u64 和 90 位 u128 的乘法,而且结果也不是取高、低 64 位,写起来很蛋疼。当然我这里是支持了 在 u64 内,如果只支持 u32 的 应该就舒服很多了,MUL也可以在一个 u64 内就存下了。
fn algo3(n: u64) -> u64 {
const MUL: u128 = 19807038786511477253u128;
const SHR: u32 = 30;
const D: u64 = 1000000093;
let prod = (n as u128) * (MUL as u128);
((prod % (1u128 << (64 + SHR)) * D as u128) >> (64 + SHR)) as u64
}
汇编如下:注意到这里用了 3 个 mul 指令……
playground::algo3:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
addl %edx, %edi
movl $1000000093, %ecx
mulq %rcx
andl $1073741823, %edi
imulq $1000000093, %rdi, %rax
addq %rdx, %rax
shrq $30, %rax
retq
Related Work
这类方式也叫做 Barrett reduction。我在调研的时候,还找到另外一种 reduction 方式:Montgomery reduction,用的比 Barrett reduction 多。它把每个数都搞了个中间表示(, 是一个常数,通常选用 2 的幂),并且他有更高效的方法处理乘法。
baihacker 的 PE 库中有一个 NTT 的 benchmark,里面比较了 FLINT/NTL/LibBF 以及 Min_25 的代码,发现 Min_25 的是最快。我去看了一下 Min_25 的代码, UnsafeMod 里面的 reduce 函数和 Wikipedia 里面的 REDC 算法很像(但是不是其实就是,只不过没有归一化到 之间罢了,可能这就是为什么这叫 UnsafeMod 的原因吧),不过我也没仔细去看这个函数到底是干啥的了 = = 这个类好奇怪啊,减法里面为啥要加 3 * mod……NTL 的源代码中提到了这篇 paper,我也没看这是在说啥……Division algorithm 的 Wikipedia 中有一小节提到了:
However, unless D itself is a power of two, there is no X and Y that satisfies the conditions above. Fortunately, (N·X)/Y gives exactly the same result as N/D in integer arithmetic even when (X/Y) is not exactly equal to 1/D, but “close enough” that the error introduced by the approximation is in the bits that are discarded by the shift operation.
然后又 ref 了 3 个 link,有一个 link 就是 NTL 提到的 Granlund-Moeller 算法。
Acknowledgement
在 Division algorithm 的 Wikipedia 中,有一个 reference 是一篇博客。我写完之后才发现有这么一个 blog,仔细一看,妈呀写的是我的超集,真是太尴尬了……不过他公式有点丑……想了想,我这也不算抄袭,毕竟这是我自己看源代码看出来的东西。后来我认真读了一下这两篇 post,把他的一些新思路也加了进去,再写了一下自己的理解。我的证明和他稍有不同,但是本质上是一样的。原文中有这么一段
As is well known (and seen in a previous post), compilers optimize unsigned division by constants into multiplication by a “magic number.” But not all constants are created equal, and approximately 30% of divisors require magic numbers that are one bit too large, which necessitates special handling.
这个 30% 怎么看起来这么眼熟啊……行吧 orz……
后来我又想,既然有 fast division,那会不会有 fast modulo 呢?于是我找到了这么一篇 blog 以及对应的 paper。我也向这两篇 blog 的作者表示感谢。
Epilogue
写这篇文章用的时间远比我想象中的长,从查资料,读文献,到写代码,猜汇编,到最后整理成文字。虽然每件事都不用话太久,但是加在一起还是花了很多精力。事实上,我还有一些东西想写没写完,例如写一个线性同余发生器 benchmark 一下各个算法,例如看看 Granlund-Moeller 到底是在干啥,例如讲讲 libdivide 是怎么搞的,但是我感觉我花的时间已经够多的了。我也说不清鼓捣这些东西有什么用,对我毕业毫无用处,但是我就是感到快乐。码农的快乐,往往就是这么朴实无华且枯燥。
顺便一提:我现在发现的最方便的看 Rust 如何编译一小段代码的方式就是:独立写一个函数,加 #[inline(never)],然后 cargo rustc --release -- --emit asm 找对应函数。另外,这些优化应该都是 LLVM 做的……以及,NTL/FLINT 这些库应该都做了这些优化吧,何必自己折腾呢 orz……可惜没有好用的 Rust binding……