组合数:n 个不同元素中取出 k 个元素的所有不同组合的个数,叫做从 n 个不同元素中取出 k 个元素的组合数,记作 Cnk

在不同数据规模下,我们使用不同算法来对其求值。由于组合数往往极大,下面提到的算法除最后一个,均将结果取模 p.

组合数公式

组合数公式:

Cnk=(nk)=Pnkk!=n!k!(nk)!

我们直接利用该公式,用阶乘求值,每次可以求出一个组合数。在运算前,我们通常预处理得出范围内的所有阶乘值,和对应的乘法逆元,在计算时即可在 O(1) 时间得到结果。除法需要使用逆元处理,参见模逆元.

  • 时间复杂度:O(nlogp) (预处理阶乘和逆元) ,O(1) (得到组合数)
  • 空间复杂度:O(n)
long long fast_pow(long long, long long) // 快速幂
long long inv(long long); // 求逆元

void init()
{
    fact[0] = invf[0] = 1;
    for (int i = 1; i < MAXA; i++)
    {
        fact[i] = fact[i - 1] * i % MOD;
        invf[i] = invf[i - 1] * inv(i) % MOD;
    }
}

int main()
{
    init();
    long long a, b;
    cin >> a >> b;
    cout << fact[a] * invf[b] % MOD * invf[a - b] % MOD << endl;
    return 0;
}

递推公式

组合数公式的递推公式:

Cnk=Cn1k+Cn1k1

理解这个公式的方式可以是:将 n 个物品分成 n1 个和 1 个,若不选单独的那一个,则为 Cn1k,若选单独的哪一个,则为 Cn1k1.

利用这个递推公式,我们可以算出指定范围内的所有情况的组合数。注意需要先将 Cn0 全部初始化为 1.

  • 时间复杂度:O(n2)
  • 空间复杂度:O(n2)
const int MAXA = 2010;
const long long MOD = 1e9 + 7;
long long ans[MAXA][MAXA];

void init()
{
    for (int i = 0; i < MAXA; i++)
        ans[i][0] = 1;
    for (int i = 1; i < MAXA; i++)
        for (int j = 1; j < MAXA; j++)
            ans[i][j] = (ans[i - 1][j] + ans[i - 1][j - 1]) % MOD;
}

卢卡斯定理

对于非负整数 mn 和素数 p, 同余式:

(mn)i=0k(mini)(modp),

成立。其中:

m=mkpk+mk1pk1++m1p+m0,

并且

n=nkpk+nk1pk1++n1p+n0

mnp 进制展开。当 m<n 时,二项式系数 (mn)=0.

来源:Wikipedia (使用 CC BY-SA 3.0 协议)

上面的公式有点难看明白,我们先改成熟悉的 C 符号的记法,再简化一下,将一共 k 项缩减为 2 项,即得到:

Cm1p+m2n1p+n2Cm1n1Cm2n2(modp)

虽然改成了 2 项,但我们可以递归执行这个计算,因此和上面那个公式是一样的。

此时还是不够直观,我们再变形一下:

CmnCm/pn/pCmmodpnmodp(modp)

这个公式就比较直观了,我们可以看到,它能将 nm 较大的组合数分解为两个更小的组合数,通过不断地递归分解,最后可以分解为 n,m<p 的组合数相乘。该算法适用于 n,m 规模极大(如 1018)的情况。

<p 的组合数使用组合数公式得出,因此需要预处理阶乘和逆元,下面时间复杂度中的 plogp 是预处理耗时。

  • 时间复杂度:O(plogplogpn)
  • 空间复杂度:O(p)
#include <bits/stdc++.h>

using namespace std;

const int MAXA = 1e5 + 10;
long long fact[MAXA];

void init(int mod)
{
    fact[0] = 1;
    for (int i = 1; i <= mod; i++)
        fact[i] = fact[i - 1] * i % mod;
}

long long fast_pow(long long a, long long b, long long p)
{
    b %= p;
    long long ans = 1;
    while (b)
    {
        if (b % 2)
            ans = a * ans % p;
        a = a * a % p;
        b /= 2;
    }
    return ans;
}

inline long long inv(long long x, long long p)
{
    return fast_pow(x, p - 2, p);
}

long long comb(long long a, long long b, long long p)
{
    if (b > a)
        return 0;
    if (a < p && b < p)
        return fact[a] * inv(fact[b], p) % p * inv(fact[a - b], p) % p;
    return comb(a % p, b % p, p) * comb(a / p, b / p, p) % p;
}

int main()
{
    int n;
    cin >> n;
    while (n--)
    {
        long long a, b, p;
        cin >> a >> b >> p;
        init(p);
        cout << comb(a, b, p) << endl;
    }
    return 0;
}

高精度算法

若题目要求不取模,并且规模较大,那就使用高精度算法直接求解。例如我们可以使用 Python 暴力求解该问题:

a, b = input().split(' ')
a = int(a)
b = int(b)
res = 1
for i in range(a - b + 1, a + 1):
    res *= i
for i in range(1, b + 1):
    res //= i
print(res)

为了优化高精度的效率,我们可以将其先分解质因数为质数的幂的乘积这种形式:

Cnm=2a13a25a37a4

最后只需使用高精度乘法即可解决问题。

求解 x 的因子中质数 p 的次数 a 的方法是:

a=xp+xp2+xp3+

#include <bits/stdc++.h>

using namespace std;

const int MAXN = 5010;
bool is_prime[MAXN];
int prime[MAXN], idx;
int npow[MAXN];

void init_prime(int x)
{
    memset(is_prime, true, sizeof(is_prime));
    is_prime[0] = is_prime[1] = false;
    for (int i = 2; i <= x; i++)
    {
        if (is_prime[i])
            prime[idx++] = i;
        for (int j = 0; i * prime[j] <= x && j < idx; j++)
        {
            is_prime[i * prime[j]] = false;
            if (!(i % prime[j]))
                break;
        }
    }
}

int get(int x, int p)
{
    int res = 0;
    while (x)
    {
        res += x / p;
        x /= p;
    }
    return res;
}

vector<int> mul(vector<int> &a, int b)
{
    vector<int> ans;
    int t = 0;
    for (int i = 0; i < a.size(); i++)
    {
        t += a[i] * b;
        ans.push_back(t % 10);
        t /= 10;
    }
    while (t)
    {
        ans.push_back(t % 10);
        t /= 10;
    }
    return ans;
}

int main()
{
    int a, b;
    cin >> a >> b;
    init_prime(a);
    for (int i = 0; i < idx; i++)
        npow[i] = get(a, prime[i]) - get(b, prime[i]) - get(a - b, prime[i]);
    vector<int> ans;
    ans.push_back(1);
    for (int i = 0; i < idx; i++)
        for (int j = 0; j < npow[i]; j++)
            ans = mul(ans, prime[i]);
    for (int i = ans.size() - 1; i >= 0; i--)
        cout << ans[i];
    cout << endl;
    return 0;
}