2023牛客暑期多校训练营6

A - Tree

https://ac.nowcoder.com/acm/contest/57360/A

题意

给定 n 个点的一棵边带权的树,点有黑白二色(0,1 表示),现在可以以 ai 的价值翻转第 i 个点的颜色,一对异色点 (u,v) 的价值为树上路径的最大边权值。问经过任意颜色翻转后,价值减去代价的最大值。1n3×1031ai109

题解

该题有关树上两点间路径的最大边权,考虑借助 Kruskal 重构树解决这道题。

首先将题目给出的带边权的树转换为带点权的 Kruskal 重构树,vali 是重构树的点权。

那么如果对于重构树的一个内部节点 p,如果 a 点在它的左子树,b 点在它的右子树,那么 a,b 两点的最近公共祖先就是 p,根据重构树性质,a,b 两点路径的最大边权就是 valp.

如果左子树大小为 nl,右子树大小为 nr,左子树有 nl,black 个黑点,左子树有 nl,white 个白点,右子树有 nr,black 个黑点,右子树有 nr,white 个白点,那么由 p 节点代表的这条边的贡献就是:

valp(nl,blacknr,white+nl,whitenr,black)

然后,遍历所有的黑点分配方式,便能算出最大的共享值。当然,这只计算了经过这个 p 节点产生的贡献。要解决整个问题,要使用动态规划的思想。

dpi,j 代表对于重构树上 i 节点,左右子树加起来共 j 个黑点的最大贡献。

转移方式为,其中左子树根节点下标 l,右子树根节点下标 r,左子树大小为 nl,右子树大小为 nr

dpi,j=maxk=max{0,jnr}min{nl,j}(dpl,nl,black+dpr,nr,black+valp(nl,blacknr,white+nl,whitenr,black)),{nl,black=knl,white=nlnl,blacknr,black=jknr,white=nrnr,black

代码

#include <bits/stdc++.h>
#define endl '\n'
#define int long long

using namespace std;

constexpr int MAXN = 3e3 + 10;
int n;
int a[MAXN], cost[MAXN];
int fa[MAXN << 1];
int val[MAXN << 1], siz[MAXN << 1]; // 重构树点权, 重构树子树大小
vector<int> krus[MAXN << 1];        // Kruskal 重构树
int dp[MAXN << 1][MAXN];

void init()
{
    for (int i = 0; i < (MAXN << 1); i++)
        fa[i] = i;
}

int find(int x)
{
    return (fa[x] == x) ? x : (fa[x] = find(fa[x]));
}

void dfs(int now)
{
    if (now <= n)
    {
        siz[now] = 1;
        dp[now][a[now]] = 0;
        dp[now][1 - a[now]] = -cost[now];
        return;
    }
    int l_son = krus[now].front(), r_son = krus[now].back();
    dfs(l_son);
    dfs(r_son);
    siz[now] = siz[l_son] + siz[r_son];
    for (int blk = 0; blk <= siz[now]; blk++)
    {
        dp[now][blk] = INT64_MIN;
        for (int l = max(0LL, blk - siz[r_son]); l <= min(blk, siz[l_son]); l++)
        {
            int l_black = l, l_white = siz[l_son] - l_black;
            int r_black = blk - l, r_white = siz[r_son] - r_black;
            int tmp = val[now] * (l_black * r_white + l_white * r_black);
            tmp += dp[l_son][l_black];
            tmp += dp[r_son][r_black];
            dp[now][blk] = max(dp[now][blk], tmp);
        }
    }
}

void solve()
{
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= n; i++)
        cin >> cost[i];
    vector<array<int, 3>> edge(n + 10);
    for (int i = 1; i < n; i++)
        cin >> edge[i][0] >> edge[i][1] >> edge[i][2];
    sort(edge.begin() + 1, edge.begin() + n, [](array<int, 3> a, array<int, 3> b) { return a[2] < b[2]; });
    int pos = n + 1;
    for (int i = 1; i < n; i++)
    {
        auto &[u, v, w] = edge[i];
        int fa_u = find(u), fa_v = find(v);
        if (fa_u == fa_v)
            continue;
        krus[pos].push_back(fa_u);
        krus[pos].push_back(fa_v);
        val[pos] = w;
        fa[fa_u] = pos;
        fa[fa_v] = pos;
        pos++;
    }
    int root = find(1);
    dfs(root);
    int ans = 0;
    for (int i = 0; i <= n; i++)
        ans = max(ans, dp[root][i]);
    cout << ans << endl;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t = 1;
    // cin >> t;
    init();
    while (t--)
        solve();
    return 0;
}