2023牛客暑期多校训练营6

B - Distance

https://ac.nowcoder.com/acm/contest/57360/B

题意

给定长度为 n 的两个序列 {s}i=1n{t}i=1n, 从中选出两个大小相等的集合 AS,BT, 每次可以选择 xAyB 并执行 xx+1yy+1。记让 A=B 的最小操作次数为 C(A,B), 求 ASBTC(A,B)mod9982443531n3×103,1ai,bi109

题解

可以选择 AB 中的任意一个数对它 +1,由于题目目标是将 A,B 变成一样的,所以这个操作其实等价于:选择 A 中的任意一个数 ai 对它 +11. 那么将 ai 变成 x 的代价就是 |aix|.

对于一对大小为 mA,B,首先分别对其排序,根据上面的转化,那么将 A,B 变成一样的总代价就是 C(A,B)=i=0m|aibi|,即考虑按下标一一对应进行变化时代价是最小的。

其实不按照下标一一对应,也可以是代价一样最小的变换方式。但是按下标一一对应进行变换能够方便后续的计算,因此我们不妨设变换方式就是按照下标一一对应。

然后,该题思路是对于枚举每一对 si,tj,计算当它们在 A,B 同一下标时对答案的贡献。若用 cnt 代表 si,tjA,B 同一下标的情况总数,那么对答案的贡献便是:cnt|sitj|. 因此,现在问题的关键就变成了怎么计算这个情况总数 cnt.

考虑下面这个示例,被框住的数字是目前正在考虑的数字:

s:13469t:367910

那么能让 4,6 处于同一下标的子区间可能性有:

  • 左边选 0 个,右边选 0 个:长度为 1,选中数字处于下标 1,共 1 种。
  • 左边选 0 个,右边选 1 个:长度为 2,选中数字处于下标 1,共 6 种。
  • 左边选 0 个,右边选 2 个:长度为 3,选中数字处于下标 1,共 3 种。
  • 左边选 1 个,右边选 0 个:长度为 2,选中数字处于下标 2,共 2 种。
  • 左边选 1 个,右边选 1 个:长度为 3,选中数字处于下标 2,共 12 种。
  • 左边选 1 个,右边选 2 个:长度为 4,选中数字处于下标 2,共 6 种。

30 种情况,因此这种情况的贡献为 |46|×30=60.

接下来分析一下为什么是这么多。对于一对 i,jsi 左侧有 i1 个数,右侧有 ni 个数,tj 左侧有 j1 个数,右侧有 nj 个数。

那么选择的区间左侧最多放 min{i1,j1} 个数,右侧最多放 min{ni,nj} 个数,否则就会破坏下标对应关系。

(左右侧的计算是等价的,下面都只以右侧为例)那么右侧的选择数的情况有:

k=0min{ni,nj}Cnik

但上面只考虑了选择的方式,实际上还要考虑选出来的数怎么对应。例如上面这个例子,如果 si 右侧选择 6,9,那么 tj 右侧可以选择 7,9/7,10/9,10 这三种进行对应。

因此,右侧的情况数应该是:

k=0min{ni,nj}CnikCnjk

计算这个情况数的时间复杂度是 O(n) 的,那总时间复杂度会来到 O(n3) 而超时,因此必须要化简。

根据范德蒙德卷积公式

i=0kCniCmki=Cn+mk

它有推论:

i=0kCniCmi=Cn+mm

这个推论正好和上面的情况数对应上了,于是情况数转化成:

k=0min{ni,nj}CnikCnjk=C(ni)+(nj)ni

这下就能在 O(1) 时间得到一对 si,tj 的贡献了,那么该题的总答案为:

answer=i=1nj=1n(|sitj|C(i1)+(j1)i1C(ni)+(nj)ni)

预处理组合数后,时间复杂度为 O(n2).

代码

#include <bits/stdc++.h>
#define endl '\n'
#define int long long

using namespace std;

constexpr int MAXN = 4e3 + 10, MOD = 998244353;
int comb[MAXN][MAXN];

void init_comb()
{
    for (int i = 0; i < MAXN; i++)
        comb[i][0] = 1;
    for (int i = 1; i < MAXN; i++)
        for (int j = 1; j < MAXN; j++)
            comb[i][j] = (comb[i - 1][j] + comb[i - 1][j - 1]) % MOD;
}

int n;
int s[MAXN], t[MAXN];

void solve()
{
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> s[i];
    for (int i = 1; i <= n; i++)
        cin >> t[i];
    sort(s + 1, s + 1 + n);
    sort(t + 1, t + 1 + n);
    int ans = 0;
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= n; j++)
        {
            int li = i - 1, lj = j - 1;
            int ri = n - i, rj = n - j;
            int lcnt = comb[li + lj][li];
            int rcnt = comb[ri + rj][ri];
            int delta = abs(s[i] - t[j]) % MOD;
            ans = (ans + lcnt * rcnt % MOD * delta % MOD) % MOD;
        }
    }
    cout << ans << endl;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    init_comb();
    int t = 1;
    // cin >> t;
    while (t--)
        solve();
    return 0;
}