Modulo by a Constant

\newcommand{\bmax}{ {b_\max} } 最近做 PE 的时候也顺便在玩 Rust。我在研究 Rust 编译出的 ASM 的时候,发现一段有趣的代码:

#[inline(never)]
fn opt_mod(x: u64) -> u64 {
    x % 1000_000_007
}

编译出的 ASM 长这样:

playground::opt_mod:
	movabsq	$-8543223828751151131, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	shrq	$29, %rdx
	imulq	$1000000007, %rdx, %rax
	subq	%rax, %rdi
	movq	%rdi, %rax
	retq

我震惊的发现,这里面居然没有除法、取模?有点意思……

我猜这里的 mulq %rcx 是会溢出的,最后高 64 位进入 %rdx。如果输入是 nn,那么输出为那么给出的结果为

y=nna264+bd\labeleq:divtomod\begin{equation} y = n - \floor{\frac{n a}{2^{64 + b}}} d \label{eq:div-to-mod} \end{equation}

其中 a=2648543223828751151131a = 2^{64}-8543223828751151131 (由于这是个 u64,所以要加上 2642^{64}),b=29b = 29。这里面的那个取整很像在求 n/d\floor{n / d},但是巧妙的避免避免除法,改用一次乘法和一次位移来实现。另一件有趣的事是,它非常自信这样得出的结果为 n/d\floor{n/d} ,不会 offset-by-1,因为这里面甚至没有判断 yy 是否大于 dd 或小于 0,如果大于 dd 或小于 0,应该再来一次减法或加法。看起来这后面还藏着一些数学,我们来研究一下。

Disclaimer: 我对汇编几乎不了解,理解基本上是靠猜。我也不了解每个指令的时间、CPU 的流水线作业、分支预测等因素对齐性能的影响。这里贴出汇编代码只是为了确认没有奇怪的操作。

Notations

这里我们约定,我们想计算的是 nmodd=nx/ddn \bmod d = n - \floor{x/d} d,其中 n,dn, d 都是 u64 内的无符号整形。由于对于同一个 dd ,我们有很多个 nn 需要做,所以我们想对 dd 做预处理,看能不能比直接用 div 这种汇编指令更快。NN 定义为常数 2642^{64}为方便处理,我们假定 dd 不是 2 的幂。

在这篇文章中,mod\bmod 运算优先级低于加减乘除取负等运算,但高于括号,且 mod\bmod 的结果均为正数,即 1mod3=2-1 \bmod 3 = 2

Algorithm 1.0

很显然,方程 \eqref{eq

} 成立的充要条件是

na2bN=n/d,ndna2bN<nd+1,\labeleq:strict\begin{equation} \floor{\frac{na}{2^b N}} = \floor{n/d}, \quad \Longleftrightarrow \quad \frac{n}{d} \leq \frac{na}{2^b N} < \floor{\frac{n}{d}} + 1, \label{eq:strict} \end{equation}

我们将不等号最右边缩到 n+1d\frac{n+1}{d}

ndna2bN<n+1d.\labeleq:3\begin{equation} \frac{n}{d} \leq \frac{na}{2^b N} < \frac{n+1}{d}. \label{eq:3} \end{equation}

注意到不等式的三项都是关于 nn 的线性函数,所以这个不等式在区间 [0,N1][0, N - 1] 上成立等价于在区间两端成立。当 n=0n = 0 时,原式显然成立,所以我们只需关注 n=N1n = N - 1 的情况:

N1da(N1)2bN<Nd,2bNad<2bN2N1.\frac{N - 1}{d} \leq \frac{a (N-1)}{2^bN} < \frac{N}{d}, \quad \Longleftrightarrow \quad 2^b N \leq ad < \frac{2^bN^2}{N-1}.

将等式最后边再次放缩成 2b(N+1)2^b (N+1) 后可得

2bNad2bN+2b.2^b N \leq ad \leq 2^bN+2^b.

此时可以看出,aa 这个不等式在

2bNmodd2b\labeleq:req1\begin{equation} -2^b N \bmod d \leq 2^b \label{eq:req1} \end{equation}

时有解 2bNd\ceil{\frac{2^b N}{d}},否则无解。带进去算一下,Rust 选择的这组 (a,b)(a, b) 满足要求。注意到 nana 不会溢出 u128 ,故 %rdx 这里也不会溢出,难怪如此自信。

我就用 Rust 写了一下:

fn algo1(n: u64) -> u64 {
    const MUL: u64 = 9903520244958400485;
    const SHR: u32 = 29;
    const D: u64 = 1000000007;

    let prod_hi = (((x as u128) * (MUL as u128)) >> 64) as u64;
    let q = (prod_hi >> SHR);
    n - q * D
}

得到的汇编为:(和一开始的几乎一样)

playground::Algo1::modulo:
	movabsq	$-8543223828751151131, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	shrq	$29, %rdx
	imulq	$-1000000007, %rdx, %rax
	addq	%rdi, %rax
	retq

仔细看一下,其实还是有小小不同的……这里面只有 7 个指令,Rust 编译器里面有 8 个,区别在于 algo1 里面都是在 %rax 上操作,最后直接返回了 %rax,而 Rust 编译器里都是在 %rdi 里面操作,最后再 mov %rdi, %rdx。为此我还特意去查了一下 Calling convention(之前学的忘光了),好像是 %rdi 里面存参数,返回值在 %rax 里,而且 %rax/%rdi 以及很多寄存器都是 callee 可以随便改的……

要计算这组 (a,b)(a, b) 的话,直接枚举 bb 就行了,反正也不大……实验中也发现,Rust 给出的就是最小的 bb。现在问题来了:我们希望 a<Na < N,这样才能在一个 u64 内存下 aa。如果找不到这样的 aa 怎么办?枚举了一下 [1,109][1, 10^9] 里面所有奇数(从前的分析来看,这和奇偶数没啥关系),大概有 70% 的数是能找到 aa 的,但是剩下的 30% 也不少。

我们还可以观察到:b = b_\max =\lceil \log_2 p\rceil 时,一定能满足 \eqref{eq

},但是此时的 a=2\bmaxNdNa = \ceil{\frac{2^\bmax N}{d}} \geq N 一定存不进一个 u64,而且由于 d2\bmax<2dd \leq 2^{\bmax} < 2d,可得 a=2\bmaxNd<2Na = \frac{2^\bmax N}{d} < 2N 就多了那么一个 bit,真是尴尬……

Algorithm 1.1

第一种方法,我们干脆再 relax \eqref{eq

} 。如果我们能满足

ndna2bN<nd+1,0na2bNnd<1,\frac{n}{d} \leq \frac{na}{2^b N} < \frac{n}{d} + 1, \quad \Longleftrightarrow \quad 0 \leq \frac{na}{2^b N} - \frac{n}{d} < 1,

的话,那么能保证

ndna2bNnd+1,\floor{\frac{n}{d}} \leq \floor{\frac{na}{2^b N}} \leq \floor{\frac{n}{d}} + 1,

也就是说,我们能求得一个大概(offset-by-1)的解,带入 \eqref{eq

} 后求得的 yy 就可能是负数,还需要一个 if 来判断。注意 yy 是用一个 u64 来存,判断 yy 正负需要特殊的技巧。

事实上,AtCoder 里面的 modint 就是这么处理的,不过它用带符号的数。

Algorithm 1.2

另一种方法,我们再看看 Rust 咋整的……我挑了一个 109+9310^9 + 93 来看看 Rust 生成的 ASM 是什么样:

playground::opt_mod:
	movabsq	$1360294712801925637, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	movq	%rdi, %rax
	subq	%rdx, %rax
	shrq	%rax
	addq	%rdx, %rax
	shrq	$29, %rax
	imulq	$1000000093, %rax, %rax
	subq	%rax, %rdi
	movq	%rdi, %rax
	retq

对比一下之前的,这里多了四行:L5-8,我们去看看这到底是在做啥……还是令 a=1360294712801925637a'=1360294712801925637b=29b = 29,于是他算的东西为

y=nnnaN2+naN2bd=nn(1+naN)2b+1d,\labeleq:subshradd\begin{equation} y = n -\floor{\frac{\floor{\frac{n - \floor{\frac{na'}{N}}}{2}} + \floor{\frac{na'}{N}}}{2^b}} d = n - \floor{\frac{n(1 + \frac{na'}{N})}{2^{b+1}}}d, \label{eq:sub-shr-add} \end{equation}

原来如此……我们之前的问题不是 aNa \geq N 出了 u64 吗,但是之前也提到了,必定存在一个 Na<2NN \leq a < 2N 满足要求,那我们就令 a=N+aa = N + a',其中 a<Na' < N 就可以用一个 u64 存下来了。这里有地方需要注意:虽然 nnaN2+naN\floor{\frac{n - \floor{\frac{na'}{N}}}{2}} + \floor{\frac{na'}{N}} 在数学上等于 n+naN2\floor{\frac{n + \floor{\frac{na'}{N}}}{2}} ,但是后者在 nn 大的时候有可能溢出,前者不会。

对比一下 Algorithm 1.1,这里 Rust 用 4 ops 当做一个 if (以及里面的 subq),如果用 y -= (y >= p) as u64 * p 的话多一个乘法,具体会快多少就不知道了……

这里我也给个 Rust 代码:

fn algo1_2(n: u64) -> u64 {
    const MUL: u64 = 1360294712801925637;
    const SHR: u32 = 29;
    const D: u64 = 1000000093;

    let hi = (((n as u128) * (MUL as u128)) >> 64) as u64;
    let q = (((n - hi) >> 1) + hi) >> SHR;
    n - q * D
}

得到的汇编为也和之前的几乎一样:

playground::algo1_2:
	movabsq	$1360294712801925637, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	movq	%rdi, %rax
	subq	%rdx, %rax
	shrq	%rax
	addq	%rdx, %rax
	shrq	$29, %rax
	imulq	$-1000000093, %rax, %rax
	addq	%rdi, %rax
	retq

Algorithm 2

如果还坚持用等式 \eqref{eq

} 来优化,有点优化不动了。这里我们另辟蹊径:我们不再要求 \eqref{eq
} 成立,而是要求

(n+1)a2bN=n/d,nd(n+1)a2bN<nd+1,\labeleq:addone\begin{equation} \floor{\frac{(n+1)a}{2^b N}} = \floor{n/d}, \quad \Longleftrightarrow \quad \frac{n}{d} \leq \frac{(n+1)a}{2^b N} < \floor{\frac{n}{d}} + 1, \label{eq:add-one} \end{equation}

成立。我们继续把右边放缩成 n+1d\frac{n+1}{d},然后用之前同样的 trick,所有项都是线性的,则 [0,N1][0, N-1] 内满足要求等价于区间两端满足要求:

0a2bN<1dN1da2b<Nd,0 \leq \frac{a}{2^b N} < \frac{1}{d} \quad \wedge \quad \frac{N-1}{d} \leq \frac{a}{2^b} < \frac{N}{d},

注意到前者第二个不等号等价于后者第二个不等号,故可以合并起来,化简为

2bN2bad<2bN,2^b N - 2^b \leq ad < 2^b N,

可以看出,aa

2bNmodd2b2^b N \bmod d \leq 2^b

时存在解 2bNd\floor{\frac{2^bN}{d}}。注意到这个条件和 \eqref{eq

} 很像。事实上,我们可以证明,当 bb\bmax1\bmax - 1 时,两者至少有一个满足:首先我们有:对于任意 bb,均有

0<(2bNmodd)+(2bNmodd)<2d,0 < (2^b N \bmod d) + (-2^b N \bmod d) < 2d,

又有对于任意 bb 均有 ((2bNmodd)+(2bNmodd))modd=0modd=0((2^b N \bmod d) + (-2^b N \bmod d)) \bmod d = 0 \bmod d = 0,故有

(2\bmax1Nmodd)+(2\bmax1Nmodd)=d2\bmax,(2^{\bmax - 1} N \bmod d) + (-2^{\bmax - 1} N \bmod d) = d \leq 2^\bmax,

所以 (2\bmax1Nmodd)(2^{\bmax - 1} N \bmod d)(2\bmax1Nmodd)(-2^{\bmax - 1} N \bmod d) 中至少有一个不超过 2\bmax12^{\bmax - 1}

接下来便不难想到,这个算法可以和 Algorithm 1.0 互补:若 Algorithm 1.0 找不到合适的 (a,b)(a, b),则 Algorithm 2 一定能找到合适的 (a,b)(a, b)。剩下的问题就是怎么有效的实现 Algorithm 2 了。

Algorithm 2 有一个问题使得我们无法直接计算 (n+1)a2bN\floor{\frac{(n+1)a}{2^b N}}n+1n+1 可能会溢出 u64。这也就是我们即将解决的问题。注意到 a<Na < N,所以 (n+1)aN(N1)(n+1)a \leq N(N-1) 不会溢出 u128。

Algorithm 2.1

第一种方法,我们使用乘法分配律 (a+b)c=ac+bc(a + b) c = ac + bc,将 (n+1)a(n+1)a 拆成 na+ana + a

fn algo2_1(n: u64) -> u64 {
    const MUL: u64 = 9903519393255738626;
    const SHR: u32 = 29;
    const D: u64 = 1000000093;

    let prod = (n as u128) * (MUL as u128) + (MUL as u128);
    let prod_hi = (prod >> 64) as u64;
    let q = prod_hi >> SHR;
    n - q * D
}

得到的汇编为:

playground::algo2_1:
	movabsq	$-8543224680453812990, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	addq	%rcx, %rax
	adcq	$0, %rdx
	shrq	$29, %rdx
	imulq	$-1000000093, %rdx, %rax
	addq	%rdi, %rax
	retq

Algorithm 2.2

第二种方法,我们注意到,n+1n+1 的溢出只会发生在 n=N1n = N- 1 这种情况下,而这种情况下,由于 dd 不是 2 的幂,故有 dNd \nmid N,即

N1d=Nd,\floor{\frac{N - 1}{d}} = \floor{\frac{N}{d}},

也就是说,这个时候我用 nn 还是 n+1n+1 是没有区别的,那我不加不就好了……所以 Algorithm 2.2 很简单,就是将 \eqref{eq

} 换成

min(n+1,N1)a2bN,\floor{\frac{\min(n+1, N - 1)a}{2^b N}},

这里的 min(n+1,N1)\min(n + 1, N - 1) 听说可以直接两个汇编指令搞定,但是我搞不出来,只能手写汇编了……以下便是代码:


fn algo2_2(n: u64) -> u64 {
    const MUL: u64 = 9903519393255738626;
    const SHR: u32 = 29;
    const D: u64 = 1000000093;

    let mut saturated = n;
    unsafe {
        asm!(
            "add {0}, 1",
            "sbb {0}, 0",
            inout(reg) saturated,
            options(nostack),
        );
    }
    let prod = (saturated as u128) * (MUL as u128);
    let prod_hi = (prod >> 64) as u64;
    let q = prod_hi >> SHR;
    n - q * D
}

汇编代码为:

playground::algo2_2:
	movq	%rdi, %rax
	addq	$1, %rax
	sbbq	$0, %rax
	movabsq	$-8543224680453812990, %rcx
	mulq	%rcx
	shrq	$29, %rdx
	imulq	$-1000000093, %rdx, %rax
	addq	%rdi, %rax
	retq

比较一下 Algorithm 1.2 和 Algorithm 2.2,可以发现 Algorithm 2.2 少用了几个寄存器,而且少几个 op……algo2_2 如果不加 L13 ,就会多出一个 pushq %raxpopq %rcx,我也不知道这是啥,看起来像是保存寄存器,但是不知道为啥是 popq %rcx,希望有懂的人告诉我这是在干啥……

我试图用 n.saturating_add(1) 但是汇编出来的代码不是基于 sbb 而是 cmovne

	incq	%rcx
	movq	$-1, %rax
	cmovneq	%rcx, %rax

Algorithm 3

以上算法都是基于 \eqref{eq

},先算了 nd\floor{\frac{n}{d}} 后再算 nmoddn \bmod d。能不能绕过 nd\floor{\frac{n}{d}} 直接算 nmoddn \bmod d 呢?

这篇 paper 提供了一个思路。在 Algorithm 1.0 里面,如果我们找到了满足条件的 (a,b)(a, b),我们其实可以用来干一些事:由于

ndna2bN<n+1d,0na2bNnd<1d,\frac{n}{d} \leq \frac{na}{2^b N} < \frac{n+1}{d}, \quad \Longrightarrow \quad 0 \leq \frac{na}{2^b N} - \frac{n}{d} < \frac{1}{d},

nmodd=d{nd}=d{na2bN}=d(namod2bN)2bN.n \bmod{d} = d \left\{ \frac{n}{d} \right\} = \floor{d \left\{ \frac{na}{2^b N} \right\}} = \floor{\frac{d(na \bmod 2^b N)}{2^b N} }.

搞定。注意到等式右边上的数都很大,为 2bNd2^bNd 级。一般情况下 b\bmaxb \approx \bmax,则上面的数有 Nd2Nd^2 这么大,所以这个算法对 u64 内的数并不实用,只适用于 u32 以内的数……作者提供的代码也是针对 u32 的……

这代码也不是很好写。我这里取 b=30b = 30,那么我们需要做 30 位 u64 和 90 位 u128 的乘法,而且结果也不是取高、低 64 位,写起来很蛋疼。当然我这里是支持了 nn 在 u64 内,如果只支持 u32 的 nn 应该就舒服很多了,MUL也可以在一个 u64 内就存下了。

fn algo3(n: u64) -> u64 {
    const MUL: u128 = 19807038786511477253u128;
    const SHR: u32 = 30;
    const D: u64 = 1000000093;

    let prod = (n as u128) * (MUL as u128);
    ((prod % (1u128 << (64 + SHR)) * D as u128) >> (64 + SHR)) as u64
}

汇编如下:注意到这里用了 3 个 mul 指令……

playground::algo3:
	movabsq	$1360294712801925637, %rcx
	movq	%rdi, %rax
	mulq	%rcx
	addl	%edx, %edi
	movl	$1000000093, %ecx
	mulq	%rcx
	andl	$1073741823, %edi
	imulq	$1000000093, %rdi, %rax
	addq	%rdx, %rax
	shrq	$30, %rax
	retq

Related Work

这类方式也叫做 Barrett reduction。我在调研的时候,还找到另外一种 reduction 方式:Montgomery reduction,用的比 Barrett reduction 多。它把每个数都搞了个中间表示(xxRmodpx \mapsto xR \bmod pR>pR > p 是一个常数,通常选用 2 的幂),并且他有更高效的方法处理乘法。

baihacker 的 PE 库中有一个 NTT 的 benchmark,里面比较了 FLINT/NTL/LibBF 以及 Min_25 的代码,发现 Min_25 的是最快。我去看了一下 Min_25 的代码UnsafeMod 里面的 reduce 函数和 Wikipedia 里面的 REDC 算法很像(但是不是其实就是,只不过没有归一化到 [0,p)[0, p) 之间罢了,可能这就是为什么这叫 UnsafeMod 的原因吧),不过我也没仔细去看这个函数到底是干啥的了 = = 这个类好奇怪啊,减法里面为啥要加 3 * mod……NTL 的源代码中提到了这篇 paper,我也没看这是在说啥……Division algorithm 的 Wikipedia 中有一小节提到了:

However, unless D itself is a power of two, there is no X and Y that satisfies the conditions above. Fortunately, (N·X)/Y gives exactly the same result as N/D in integer arithmetic even when (X/Y) is not exactly equal to 1/D, but “close enough” that the error introduced by the approximation is in the bits that are discarded by the shift operation.

然后又 ref 了 3 个 link,有一个 link 就是 NTL 提到的 Granlund-Moeller 算法。

Acknowledgement

在 Division algorithm 的 Wikipedia 中,有一个 reference 是一篇博客。我写完之后才发现有这么一个 blog,仔细一看,妈呀写的是我的超集,真是太尴尬了……不过他公式有点丑……想了想,我这也不算抄袭,毕竟这是我自己看源代码看出来的东西。后来我认真读了一下这两篇 post,把他的一些新思路也加了进去,再写了一下自己的理解。我的证明和他稍有不同,但是本质上是一样的。原文中有这么一段

As is well known (and seen in a previous post), compilers optimize unsigned division by constants into multiplication by a “magic number.” But not all constants are created equal, and approximately 30% of divisors require magic numbers that are one bit too large, which necessitates special handling.

这个 30% 怎么看起来这么眼熟啊……行吧 orz……

后来我又想,既然有 fast division,那会不会有 fast modulo 呢?于是我找到了这么一篇 blog 以及对应的 paper。我也向这两篇 blog 的作者表示感谢。

Epilogue

写这篇文章用的时间远比我想象中的长,从查资料,读文献,到写代码,猜汇编,到最后整理成文字。虽然每件事都不用话太久,但是加在一起还是花了很多精力。事实上,我还有一些东西想写没写完,例如写一个线性同余发生器 benchmark 一下各个算法,例如看看 Granlund-Moeller 到底是在干啥,例如讲讲 libdivide 是怎么搞的,但是我感觉我花的时间已经够多的了。我也说不清鼓捣这些东西有什么用,对我毕业毫无用处,但是我就是感到快乐。码农的快乐,往往就是这么朴实无华且枯燥。

顺便一提:我现在发现的最方便的看 Rust 如何编译一小段代码的方式就是:独立写一个函数,加 #[inline(never)],然后 cargo rustc --release -- --emit asm 找对应函数。另外,这些优化应该都是 LLVM 做的……以及,NTL/FLINT 这些库应该都做了这些优化吧,何必自己折腾呢 orz……可惜没有好用的 Rust binding……