2020牛客暑期多校训练营(第二场)A

Description

定义 (f(s,t)) 为最大的 (i) 满足 (s) 的长为 (i) 的前缀和 (t) 的长为 (i) 的后缀相等。给定 (n) 个字符串 (s_1,s_2,...,s_n),求 (sum_{i=1}^n sum_{j=1}^n f(s_i,s_j)^2)

Solution

考虑哈希,将每个后缀存入桶中。先不考虑 (f(s,t)) 的最大,而考虑所有满足条件的和,那么暴力扫描每个前缀,设桶中有 (k) 个后缀与它匹配,那么对答案的贡献为 (kl^2)

但由于我们要算的是最大的,这样其中会有重复,观察发现这种情况会发生,是因为我们枚举的前缀有一个非空的 Border,即对这个前缀 (s[1..i]),有 (s[1..j]=s[i-j+1..i], j<i),因此我们需要减去 (s[1..j]) 的贡献(减去 (kj^2),这里的 (k) 是本次的匹配数目),因为这些已经不是最长的了。

因此一个 Hash 加一个 Kmp 即可解决问题。

#include <bits/stdc++.h>
using namespace std;

#define int long long 
#define ull unsigned long long
const int N = 1000005;
const ull bas = 131;
const int mod = 998244353;

ull basPower[N];

// 定义为 fail[i] 表示 Border(s[0..i])
vector<int> kmp(string s)
{
    int n=s.length();
    s+=' ';
    vector<int> fail(n+1);
    for(int i=1;i<=n;i++)
    {
        fail[i]=fail[i-1];
        while(s[fail[i]]!=s[i] && fail[i]) fail[i]=fail[fail[i]-1];
        if(s[fail[i]]==s[i]) ++fail[i];
    }
    return fail;
}

struct HashString 
{
    string str;
    vector<ull> hash;
    void presolve(string srcString)
    {
        str=srcString;
        hash.clear();
        int n=str.length();
        hash.resize(n);
        hash[0]=str[0];
        for(int i=1;i<n;i++) hash[i]=str[i]+hash[i-1]*bas;
    }
    HashString()
    {

    }
    HashString(string srcString)
    {
        presolve(srcString);
    }
    ull getHash(int l,int r)
    {
        return hash[r]-(l?hash[l-1]:0)*basPower[r-l+1];
    }
};

signed main()
{
    ios::sync_with_stdio(false);

    basPower[0]=1;
    for(int i=1;i<N;i++)
    {
        basPower[i]=basPower[i-1]*bas;
    }

    /*string str;
    cin>>str;
    vector <int> fail = kmp(str);
    HashString hashString;
    hashString.presolve(str);*/

    int n;
    cin>>n;
    vector <string> strSet(n);
    vector <HashString> hashstrSet(n);
    map <ull,int> mp;       // 每种后缀的出现次数
    for(int i=0;i<n;i++)
    {
        cin>>strSet[i];
        hashstrSet[i].presolve(strSet[i]);
        int len=strSet[i].length();
        for(int j=0;j<len;j++)
        {
            mp[hashstrSet[i].getHash(j,len-1)]++;
        }
    }
    
    int ans=0;

    for(int i=0;i<n;i++)
    {
        string &str=strSet[i];
        HashString &hashstr=hashstrSet[i];
        int len=str.length();
        vector <int> nextArray=kmp(str);

        for(int j=0;j<len;j++)
        {
            ans+=mp[hashstr.getHash(0,j)]*(j+1)%mod*(j+1);
            if(nextArray[j]>0)
            {
                ans-=mp[hashstr.getHash(0,j)]*nextArray[j]%mod*nextArray[j];
            }
            ans%=mod;
            ans+=mod;
            ans%=mod;
        }
    }

    cout<<ans<<endl;

    return 0;
}
原文地址:https://www.cnblogs.com/mollnn/p/13792114.html