组合数:从 $n$ 个不同元素中取出 $k$ 个元素的所有不同组合的个数,叫做从 $n$ 个不同元素中取出 $k$ 个元素的组合数,记作 $C_n^k$。

在不同数据规模下,我们使用不同算法来对其求值。由于组合数往往极大,下面提到的算法除最后一个,均将结果取模 $p$.

组合数公式

组合数公式:

$$ C_n^k={n\choose k}=\frac{P_{n}^{k}}{k!}=\frac{n!}{k!(n-k)!} $$

我们直接利用该公式,用阶乘求值,每次可以求出一个组合数。在运算前,我们通常预处理得出范围内的所有阶乘值,和对应的乘法逆元,在计算时即可在 $O(1)$ 时间得到结果。除法需要使用逆元处理,参见模逆元.

  • 时间复杂度:$O(n\log p)$ (预处理阶乘和逆元) ,$O(1)$ (得到组合数)
  • 空间复杂度:$O(n)$
long long fast_pow(long long, long long) // 快速幂
long long inv(long long); // 求逆元

void init()
{
    fact[0] = invf[0] = 1;
    for (int i = 1; i < MAXA; i++)
    {
        fact[i] = fact[i - 1] * i % MOD;
        invf[i] = invf[i - 1] * inv(i) % MOD;
    }
}

int main()
{
    init();
    long long a, b;
    cin >> a >> b;
    cout << fact[a] * invf[b] % MOD * invf[a - b] % MOD << endl;
    return 0;
}

递推公式

组合数公式的递推公式:

$$ C_n^k=C_{n-1}^{k}+C_{n-1}^{k-1} $$

理解这个公式的方式可以是:将 $n$ 个物品分成 $n-1$ 个和 $1$ 个,若不选单独的那一个,则为 $C_{n-1}^{k}$,若选单独的哪一个,则为 $C_{n-1}^{k-1}$.

利用这个递推公式,我们可以算出指定范围内的所有情况的组合数。注意需要先将 $C_n^0$ 全部初始化为 $1$.

  • 时间复杂度:$O(n^2)$
  • 空间复杂度:$O(n^2)$
const int MAXA = 2010;
const long long MOD = 1e9 + 7;
long long ans[MAXA][MAXA];

void init()
{
    for (int i = 0; i < MAXA; i++)
        ans[i][0] = 1;
    for (int i = 1; i < MAXA; i++)
        for (int j = 1; j < MAXA; j++)
            ans[i][j] = (ans[i - 1][j] + ans[i - 1][j - 1]) % MOD;
}

卢卡斯定理

对于非负整数 $m$ 和 $n$ 和素数 $p$, 同余式:

$$ {\displaystyle {\binom {m}{n}}\equiv \prod _{i=0}^{k}{\binom {m_{i}}{n_{i}}}{\pmod {p}},} $$

成立。其中:

$$ {\displaystyle m=m_{k}p^{k}+m_{k-1}p^{k-1}+\cdots +m_{1}p+m_{0},} $$

并且

$$ {\displaystyle n=n_{k}p^{k}+n_{k-1}p^{k-1}+\cdots +n_{1}p+n_{0}} $$

是 $m$ 和 $n$ 的 $p$ 进制展开。当 $m < n$ 时,二项式系数 ${\displaystyle {\tbinom {m}{n}}=0}$.

来源:Wikipedia (使用 CC BY-SA 3.0 协议)

上面的公式有点难看明白,我们先改成熟悉的 $C$ 符号的记法,再简化一下,将一共 $k$ 项缩减为 $2$ 项,即得到:

$$ C_{m_1p+m_2}^{n_1p+n_2}\equiv C_{m_1}^{n_1}\cdot C_{m_2}^{n_2}\pmod p $$

虽然改成了 $2$ 项,但我们可以递归执行这个计算,因此和上面那个公式是一样的。

此时还是不够直观,我们再变形一下:

$$ C_{m}^{n}\equiv C_{m/p}^{n/p}\cdot C_{m\bmod p}^{n\bmod p}\pmod p $$

这个公式就比较直观了,我们可以看到,它能将 $n$、$m$ 较大的组合数分解为两个更小的组合数,通过不断地递归分解,最后可以分解为 $n,m<p$ 的组合数相乘。该算法适用于 $n,m$ 规模极大(如 $10^{18}$)的情况。

$<p$ 的组合数使用组合数公式得出,因此需要预处理阶乘和逆元,下面时间复杂度中的 $p\log p$ 是预处理耗时。

  • 时间复杂度:$O(p\log p\cdot\log_p{n})$
  • 空间复杂度:$O(p)$
#include <bits/stdc++.h>

using namespace std;

const int MAXA = 1e5 + 10;
long long fact[MAXA];

void init(int mod)
{
    fact[0] = 1;
    for (int i = 1; i <= mod; i++)
        fact[i] = fact[i - 1] * i % mod;
}

long long fast_pow(long long a, long long b, long long p)
{
    b %= p;
    long long ans = 1;
    while (b)
    {
        if (b % 2)
            ans = a * ans % p;
        a = a * a % p;
        b /= 2;
    }
    return ans;
}

inline long long inv(long long x, long long p)
{
    return fast_pow(x, p - 2, p);
}

long long comb(long long a, long long b, long long p)
{
    if (b > a)
        return 0;
    if (a < p && b < p)
        return fact[a] * inv(fact[b], p) % p * inv(fact[a - b], p) % p;
    return comb(a % p, b % p, p) * comb(a / p, b / p, p) % p;
}

int main()
{
    int n;
    cin >> n;
    while (n--)
    {
        long long a, b, p;
        cin >> a >> b >> p;
        init(p);
        cout << comb(a, b, p) << endl;
    }
    return 0;
}

高精度算法

若题目要求不取模,并且规模较大,那就使用高精度算法直接求解。例如我们可以使用 Python 暴力求解该问题:

a, b = input().split(' ')
a = int(a)
b = int(b)
res = 1
for i in range(a - b + 1, a + 1):
    res *= i
for i in range(1, b + 1):
    res //= i
print(res)

为了优化高精度的效率,我们可以将其先分解质因数为质数的幂的乘积这种形式:

$$ C_n^m=2^{a_1}\cdot3^{a_2}\cdot5^{a_3}\cdot7^{a_4}\dots $$

最后只需使用高精度乘法即可解决问题。

求解 $x$ 的因子中质数 $p$ 的次数 $a$ 的方法是:

$$ a=\lfloor\frac{x}{p}\rfloor+\lfloor\frac{x}{p^2}\rfloor+\lfloor\frac{x}{p^3}\rfloor+\dots $$

#include <bits/stdc++.h>

using namespace std;

const int MAXN = 5010;
bool is_prime[MAXN];
int prime[MAXN], idx;
int npow[MAXN];

void init_prime(int x)
{
    memset(is_prime, true, sizeof(is_prime));
    is_prime[0] = is_prime[1] = false;
    for (int i = 2; i <= x; i++)
    {
        if (is_prime[i])
            prime[idx++] = i;
        for (int j = 0; i * prime[j] <= x && j < idx; j++)
        {
            is_prime[i * prime[j]] = false;
            if (!(i % prime[j]))
                break;
        }
    }
}

int get(int x, int p)
{
    int res = 0;
    while (x)
    {
        res += x / p;
        x /= p;
    }
    return res;
}

vector<int> mul(vector<int> &a, int b)
{
    vector<int> ans;
    int t = 0;
    for (int i = 0; i < a.size(); i++)
    {
        t += a[i] * b;
        ans.push_back(t % 10);
        t /= 10;
    }
    while (t)
    {
        ans.push_back(t % 10);
        t /= 10;
    }
    return ans;
}

int main()
{
    int a, b;
    cin >> a >> b;
    init_prime(a);
    for (int i = 0; i < idx; i++)
        npow[i] = get(a, prime[i]) - get(b, prime[i]) - get(a - b, prime[i]);
    vector<int> ans;
    ans.push_back(1);
    for (int i = 0; i < idx; i++)
        for (int j = 0; j < npow[i]; j++)
            ans = mul(ans, prime[i]);
    for (int i = ans.size() - 1; i >= 0; i--)
        cout << ans[i];
    cout << endl;
    return 0;
}