2023牛客暑期多校训练营8

I - Make It Square

https://ac.nowcoder.com/acm/contest/57362/I

题意

给定串 ST,问有多少种长度均为 i 的串 P,Q,使得 concat(P,S,Q,T) 为一个 AA 型的串,其中 A 为一任意非空字符串。可用字符集为小写字母。对 i[1,m] 求出答案。1m1061|S|,|T|106

题解

特殊情况

首先特判掉几个特殊情况,从而将问题拆分简化。

  • |S|=|T|

由于 |P|=|Q|,|S|=|T|,那么字符串正好从中间分开,满足 P+S=Q+T,即满足 P=Q,S=T.

那么可以直接判断:如果 ST,答案均为 0;如果 S=T,答案为 26m.

  • |S|+|T| 为奇数

由于 |P|=|Q||S|+|T| 为奇数,那么 |P+S+Q+T| 一定为奇数,不满足题目条件,答案均为 0.

|S|>|T| 的情况

这种情况的示意图如下,为了方便,令 |P|=|Q|=i,|S|=a,|T|=b

由于要求整个字符串从中心分开后,两个子串一致。因此我们把上面这个字符串从中心切成两半,再对齐摆到一起,此时两个字符串的对应位置应该时相等的。

  1. ab2i>0,那么这个图的样子如下(图中标注的均为长度):

上面两行代表的字符串必须对应位置相对,而 ST 有交叉的部分从图上看有两个:长度为 b 的一段、长度为 ab2i 的一段。

因此如果要满足题目条件,那么 ST 对应位置必须满足以下条件(中括号代表按下标取子串):

  • S[ab2(ab2+b1)]=T[0(b1)]
  • S[0(ab2i1)]=S[(aab2+i)(a1)]

如果 ST 不满足上面两个条件,那么这种情况数为 0.

如果 ST 满足上面两个条件,可以发现 PQ 的位置已经被 S 字符串定死了,因此只有一种情况,情况数为 1.

对于如何判断是否符合上面的条件,第一条直接遍历判断就行,复杂度为 O(n),而对于第二条,就不能遍历判断了,因为 i[1,m]1m106,如果对于每个 i 执行一次,那么复杂度会来到 O(n2).

第二条的含义是字符串 S 的长度为 ab2i 的前缀和后缀是否相等,这个正好可以用 KMP 算法在 O(n) 时间内解决。将 S 作为模式串用 KMP 算法计算一遍部分匹配表,然后根据部分匹配表就可以得到 S 串的所有前后缀相对的长度。对于每个 i 直接查询 ab2i 这个长度在不在答案内即可。

  1. ab2i0,那么这个图的样子如下:

可以发现,此时 ST 有交叉的部分只有一个:长度为 b 的一段。

那么如果要满足题目条件, ST 对应位置只用满足:

  • S[ab2(ab2+b1)]=T[0(b1)]

如果 ST 不满足上面的条件,那么这种情况数为 0.

如果 ST 满足上面的条件,可以发现 PQ 的位置一部分被 S 定死了,而中间有一段长度为 iab2 交叉的地方,是可以自由取值的,那么情况数为:26iab2.

|S|<|T| 的情况

这种情况的示意图如下:

这种情况实际上和上面的思考方式完全一致,实际上代码只需要改几个变量名就行了,因此下面只放出来示意图,不再解释。

  1. ba2i>0

  1. ba2i0

代码

#include <bits/stdc++.h>
#define endl '\n'
#define int long long

using namespace std;

constexpr int MOD = 998244353;

class KMP
{
    vector<int> nxt;
    string pat;

public:
    KMP(string &s)
    {
        pat = s;
        int n = pat.length();
        int j = 0;
        nxt.resize(n);
        for (int i = 1; i < n; i++)
        {
            while (j > 0 && pat[i] != pat[j])
                j = nxt[j - 1];
            if (pat[i] == pat[j])
                j++;
            nxt[i] = j;
        }
    }
    vector<int> find(string &txt)
    {
        int n = pat.length(), m = txt.length();
        int j = 0;
        vector<int> ans;
        for (int i = 0; i < m; i++)
        {
            while (j > 0 && txt[i] != pat[j])
                j = nxt[j - 1];
            if (txt[i] == pat[j])
                j++;
            if (j == n)
            {
                ans.push_back(i - n + 1);
                j = nxt[j - 1];
            }
        }
        return ans;
    }
    set<int> get_border()
    {
        set<int> s;
        int cur = nxt.back();
        while (cur)
        {
            s.insert(cur);
            cur = nxt[cur - 1];
        }
        s.insert(0);
        return s;
    }
};

void solve()
{
    int m;
    cin >> m;
    string s, t;
    cin >> s >> t;
    int ls = s.size(), lt = t.size();
    if (ls == lt)
    {
        if (s != t)
        {
            for (int i = 1; i <= m; i++)
                cout << 0 << " \n"[i == m];
        }
        else
        {
            int ans = 1;
            for (int i = 1; i <= m; i++)
            {
                ans = ans * 26 % MOD;
                cout << ans << " \n"[i == m];
            }
        }
        return;
    }
    if ((ls + lt) % 2 == 1)
    {
        for (int i = 1; i <= m; i++)
            cout << 0 << " \n"[i == m];
        return;
    }
    if (ls > lt)
    {
        for (int i = lt - 1, j = ls - 1 - (ls - lt) / 2; i >= 0; i--, j--)
        {
            if (t[i] != s[j])
            {
                for (int i = 1; i <= m; i++)
                    cout << 0 << " \n"[i == m];
                return;
            }
        }
        int ans = 1;
        KMP kmp_s(s);
        auto border_s = kmp_s.get_border();
        for (int i = 1; i <= m; i++)
        {
            if (i > (ls - lt) / 2)
            {
                ans = ans * 26 % MOD;
                cout << ans << " \n"[i == m];
            }
            else
            {
                if (border_s.count((ls - lt) / 2 - i) == 0)
                    cout << 0 << " \n"[i == m];
                else
                    cout << 1 << " \n"[i == m];
            }
        }
    }
    else // if (ls <= lt)
    {
        for (int i = ls - 1, j = lt - 1 - (lt - ls) / 2; i >= 0; i--, j--)
        {
            if (s[i] != t[j])
            {
                for (int i = 1; i <= m; i++)
                    cout << 0 << " \n"[i == m];
                return;
            }
        }
        int ans = 1;
        KMP kmp_t(t);
        auto border_t = kmp_t.get_border();
        for (int i = 1; i <= m; i++)
        {
            if (i > (lt - ls) / 2)
            {
                ans = ans * 26 % MOD;
                cout << ans << " \n"[i == m];
            }
            else
            {
                if (border_t.count((lt - ls) / 2 - i) == 0)
                    cout << 0 << " \n"[i == m];
                else
                    cout << 1 << " \n"[i == m];
            }
        }
    }
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t = 1;
    // cin >> t;
    while (t--)
        solve();
    return 0;
}