从 fair coin 说起

这是一个经典老题了,如何用 fair coins 来构造一个 biased coin,以及如何用 fair coins 来构造一个 discrete uniform distribution。基于这两个问题,我又 yy 了另外几个小 follow-up,试图去分析其时间复杂度。

Case 1: Bernoulli(p)\text{Bernoulli}(p)

先来一个 warm-up exercise:在只有 fair coins (即 Bernoulli(0.5)\text{Bernoulli}(0.5))的情况下,给定任意一个 0<p<10 < p < 1,如何从 Bernoulli(p)\text{Bernoulli}(p) 中 sample 一个元素?

这个很好做:我们只要 sample [0,1)[0, 1) 之间的一个实数 xx,并和 pp 比较就好了。我们考虑逐位 sample 并比较:每次得到 xx 的二进制表示第 ii 位,xiBernoulli(p)x_i \sim \text{Bernoulli}(p),如果 xipix_i \neq p_i,则我们就可以得到 xxpp 的相对大小关系了,否则继续下一位。

def sample(p):
    while True:
        x = random.randint(2)
        p_i = int(p * 2)
        if x != p_i:
            return int(x < p_i)
        p = p * 2 - p_i

follow-up: 时间复杂度

上面这个算法有最坏时间复杂度保证吗?

显然没有。

那么存在一个最坏时间复杂度保证的算法吗?

对于几乎所有的 pp 都不存在。假设存在一个算法能保证 nn 步内结束,那么我们考虑用到的所有 nn 个 random bit,一共有 2n2^n 种组合。对于每种组合我们一定会返回一个 0 或者 1,所以算法一定是从 Bernoulli(q2n)\text{Bernoulli}(\frac{q}{2^n}) 中 sample,其中 qq 为整数。而几乎所有的 pp 都不能表示成 q2n\frac{q}{2^n} 这样的形式。

那上面这个算法的期望复杂度是多少?

显然,每一步有 12\frac{1}{2} 的概率直接结束,所以期望就是 2 步结束。挺意外的一个结果,和 pp 无关。

Case 2: U(n)U(n)

nZ+n \in \mathbb{Z}^+,如何用 fair coins 来得到一个 U(n)U(n) 的 sample 呢?这里 U(n)U(n){0,1,,n1}\{0, 1, \dots, n - 1\} 上的 uniform distribution。

我们每次 sample k=log2nk = \lceil \log_2 n \rceil 个 bit,得到一个 U(2k)U(2^k) 的 sample。如果这个数小于 nn,那么返回即可,否则重新开始。

def sample(n):
    k = int(math.ceil(math.log2(n)))
    while True:
        x = 0
        for i in range(k):
            x = x * 2 + random.randint(2)
        if x < n:
            return x

follow-up:时间复杂度

上面这个算法有最坏时间复杂度保证吗?

显然没有。

那存在一个有时间复杂度保证的算法吗?

和之前一模一样的 argument,没有。

那这个算法的期望时间复杂度是多少?

每一轮需要 kk 个 bit,并且有 n2k\frac{n}{2^k} 的概率结束,所以期望时间复杂度就是 k2kn\frac{k 2^k}{n}

follow-up:更优的算法

存在期望复杂度更低的算法吗?

存在。和 warm-up exercise 类似,还是 sample 一个 [0,1)[0, 1) 之间的实数 xx,答案即为 xn\lfloor xn \rfloor。我们从高位到低位枚举 xx 二进制表示,然后用一个区间 [l,r)[l, r) 表示 xx 可能取值,一旦发现 ln=rn\lfloor ln \rfloor = \lfloor r n \rfloor,那么我们就已经得到了 xn\lfloor xn \rfloor,返回即可。

def sample(n):
    l, r = 0, 1
    while True:
        if random.randint(2):
            l = (l + r) / 2
        else:
            r = (l + r) / 2
        if int(l * n) == int(r * n):
            return int(l * n)

follow-up:避免浮点数操作

上面这个做法需要用到两个浮点数 l,rl, r,其精度有限。能避免使用无限精度的数据类型吗?

可以。我们回到一开始的算法,它没有达到最优是因为,如果 nx<2kn \leq x < 2^k,我们只会重来,白白浪费了这里面用到的 randomness。新算法如下:我们维护两个数 x,mx, m,表示现在我们有一个 U(m)U(m) 的 sample xx。如果 mnm \geq n,那么我们观察 xx

  • 如果 x<nx < n,那么 xx 就是 U(n)U(n) 的一个 sample,直接返回即可;
  • 否则我们就知道 xxU(n,m)U(n, m) 的一个 sample,故 x=xnx' = x - n 就是 U(mn)U(m - n) 的一个 sample。 而每次拿到一个新 bit 后,我们可以得到一个 U(2m)U(2m) 下的 sample x=2x+Bernoulli(0.5)x' = 2x + \text{Bernoulli}(0.5),重复操作直到 mnm \geq n 为止。
def sample(n):
    x, m = 0, 1
    while True:
        x, m = x * 2 + random.randint(2), m * 2
        if m >= n:
            if x < n:
                return x
            x, m = x - n, m - n

follow-up:时间复杂度分析

上面这个算法期望时间复杂度,f(n)f(n),具体为多少?

显然对 (x,m)(x, m) 整一个高斯消元是可行的,但是这有点复杂了。聪明一点的同学可以考虑 fmf_m 表示在已经有了一个 U(m)U(m) 的 sample 后得到 U(n)U(n) 的 sample 的期望复杂度,我们有:

fm={mnmfmn,mnf2m+1,m<nf_m = \begin{cases} \frac{m-n}{m} f_{m-n}, & m \geq n \\ f_{2m} + 1 , & m < n\end{cases}

这样未知数会少一些……

但是 somehow 我们有更聪明的方法。令 YY 表示几轮之后算法结束,我们先来一个常用技巧:

E[Y]=y0Pr[Yy].\mathbb{E}[Y] = \sum_{y \geq 0}^\infty \Pr[Y \geq y].

于是问题转化为计算 yy 步之后算法还没有结束的概率是多少。我们把算法作一点点修改:之前是如果 x<nx < n 那么我们返回 xx,现在是如果 xmnx \geq m - n 我们返回 xmodnx \bmod n,在 xU(m)x \sim U(m) 的情况下这两者是等价的。

def sample(n):
    x, m = 0, 1
    while True:
        x, m = x * 2 + random.randint(2), m * 2
        if m >= n:
            if x >= m - n:
                return x % n
            m = m - n

我们可以观察到,第 yy 轮结束之后的 mm 刚好就是 2ymodn2^y \bmod n,于是假设有无限精度,再改动一点点点点:

def sample(n):
    x, m = 0, 1
    while True:
        x, m = x * 2 + random.randint(2), m * 2
        if x >= m % n:
            return x % n

然后就可以很清晰的看到,yy 轮之后仍然不结束的概率是

Pr[Yy]=Pr[x<(2ymodn)]=2ymodn2y.\Pr[Y \geq y] = \Pr[x < (2^y \bmod n)] = \frac{2^y \bmod n}{2^y}.

故算法的期望时间复杂度为

f(n)=E[Y]=y=02ymodn2y.f(n) = \mathbb{E}[Y] = \sum_{y=0}^\infty \frac{2^y \bmod n}{2^y}.

虽然这个求和有无数项,但是只要注意到 2ymodn2^y \bmod n 有循环节就好办了。

follow-up:时间复杂度上下界

刚才我们得到了 f(n)f(n) 的精确表达式,但是看起来不存在一个对于任何一个 nn 均成立的 closed-form。那我们能否得到一个 f(n)f(n) 的上下界?

一个显然的,不需要任何推导的下界是:

f(n)log2n,f(n) \geq \log_2 n,

因为 U(n)U(n) 包含 log2n\log_2 n bit 的信息,所以我们期望至少 log2n\log_2 n 次才能得到一个 sample……

稍微一分析,我们就可以得到一个 f(n)f(n) 的更强的下界。令 k=log2nk = \lceil \log_2 n\rceil,而 [0,k1][0, k-1] 内的 yy 均有 2ymodn2y=1\frac{2^y \bmod n}{2^y} = 1,故有:

f(n)klog2n.f(n) \geq k \geq \log_2 n.

那上界呢?我能想到的最好的一个是:我们把 yy 分成三组来估计 2ymodn2y\frac{2^y \bmod n}{2^y},

  • 0y<k0 \leq y < k:很显然,此时 2ymodn2y=1\frac{2^y \bmod n}{2^y} = 1,这一组内所有的和为 kk
  • y=ky = k:有 2kmodn2k2kn2k=1n2k\frac{2^k \bmod n}{2^k} \leq \frac{2^k - n}{2^k} = 1 - \frac{n}{2^k}
  • y>ky > k:有 2ymodn2y<n2y\frac{2^y \bmod n}{2^y} < \frac{n}{2^y},这一组内所有的和有上界 y>kn2y=n2k\sum_{y > k}^\infty \frac{n}{2^y} = \frac{n}{2^k}

f(n)f(n) 有上界

f(n)=y=02ymodn2y<k+1n2k+n2k=k+1.f(n) = \sum_{y=0}^\infty \frac{2^y \bmod n}{2^y} < k + 1 - \frac{n}{2^k} + \frac{n}{2^k} = k + 1.

综上,我对 f(n)f(n) 的最好的 bound 就是

log2nlog2nf(n)<log2n+1,\log_2 n \leq \lceil \log_2 n \rceil \leq f(n) < \lceil \log_2 n \rceil + 1,

前面的 log2n\log_2 n 就是信息论的 lower bound。可以看到,这个算法最多比理论下界多 2 bits,还算是不错的吧……