2023牛客暑期多校训练营8

H - Insert 1, Insert 2, Insert 3, ...

https://ac.nowcoder.com/acm/contest/57362/H

题意

给定长度为 $n$ 的序列 $\{a\}_{i=1}^n$,问有多少个连续子区间 $[l,r]$ 满足这个子区间可以通过若干次依次按顺序插入 $1,2,3\cdots,k$ 的子序列构成。$1\leq n\leq 10^6$,$1\leq a_i\leq n$。

题解

思路

一个合法的区间一定以 $1$ 开头,那么就可以考虑对于每个右端点 $r$,计算出有多少个左侧的 $1$ 能和它组成合法区间,统计总数便是题目答案。

对于一个数 $x$,它必须与左侧的最近的一个还没有被匹配的 $x-1$ 匹配,如果没有能匹配的则不合法。如:

$$ \hat1,\bar1,\bar2,\hat2,\hat3,\bar3,3 $$

数字顶上标记相同的代表匹配为一组,可以看到最后一个 $3$ 需要在左边找到一个 $2$ 匹配,但是在左边的 $2$ 已经全部被其他 $3$ 匹配占用了,没有剩余的还没有被匹配的 $2$。因此这个 $3$ 不合法,它的答案是 $0$.

如果匹配上了,那答案的数量如何计算。对于每个匹配的一组数,找到它们起点(也就是 $1$)和起点前面一共有几个 $1$,数量便是合法区间的数量了。如:

$$ 1,\bar1,\hat1,\hat2,\hat3,\bar2,\bar3 $$

对于第一个 $3$,它的起点是第三个 $1$,前面一共 $3$ 个 $1$,因此答案是 $3$。对于第二个 $3$,它的起点是第二个 $1$,因此答案是 $2$。

不过有一个要点是,如果有一个不合法的数将序列分开,那么左边的 $1$ 就不能再被匹配了,因为如果想要匹配左边的 $1$ 必须跨过那个不合法的数,那这个区间肯定就不合法了。如:

$$ 1,1,1,1,1,99,1,2,3 $$

对于最后一个 $3$,它的起点前面的 $1$ 被 $99$ 这个不合法的数挡住了,所以总共只有 $1$ 个能够匹配。

维护方式

使用 $10^6$ 个栈(用 vector 而不是 stack,要不然 MLE)维护每个数的位置,对于一个数 $x$,在第 $x-1$ 个栈中找栈顶,如果是空栈则这个数不成立,如果找到栈中有数,则将其弹出,并在第 $x$ 个栈中压入 $x$ 的下标。如此就可以快速找到左侧的最近的一个还没有被匹配的 $x-1$ 的位置。

对于 $1$ 的个数的维护,使用单调栈完成。考虑到当从左到右考虑时,能匹配的 $1$ 的位置从右往左移动,且是一个单调的。因此可以用单调栈弹掉不合法的 $1$,栈中剩下的数目便是答案数。对于碰见不合法的数,直接把栈弹空就行。

代码

#include <bits/stdc++.h>
#define endl '\n'
#define int long long

using namespace std;

constexpr int MAXN = 1e6 + 10;
vector<int> p[MAXN], s;

void solve()
{
    int n;
    cin >> n;
    int ans = 0;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        if (x == 1)
        {
            p[1].push_back(i);
            s.push_back(i);
            ans += s.size();
        }
        else if (p[x - 1].empty())
        {
            s.clear();
        }
        else
        {
            int y = p[x - 1].back();
            p[x - 1].pop_back();
            p[x].push_back(y);
            while (s.size() && s.back() > y)
                s.pop_back();
            ans += s.size();
        }
    }
    cout << ans << endl;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t = 1;
    // cin >> t;
    while (t--)
        solve();
    return 0;
}