Palindrome Mouse(2019年牛客多校第六场C题+回文树+树状数组)

题目链接

传送门

题意

(s)串中所有本质不同的回文子串中有多少对回文子串满足(a)(b)的子串。

思路

参考代码:传送门

本质不同的回文子串肯定是要用回文树的啦~

在建好回文树后分别对根结点为(0,1)的子树进行(dfs),处理出以每个结点为根结点的子树的大小(sz)(也就是说有多少个回文子串以其为中心)和其(dfs)序,回文子串包含除了作为其他回文子串的中心被包含外,还可以不作为中心被包含,而这一部分则需要靠回文树的(fail)数组来进行处理。

我们先用(vector)存下有多少个结点的(fail)数组指向(i),然后把这些结点按照其对应的回文串长度进行排序,用树状数组来防止去重,加入这个结点对应的(dfs)序没被覆盖,那么就加上这个结点的(sz),否则就不加。此处举个例子帮助理解:(cac,cedcacdec)(fail)数组都指向了(c),但是(cedcacdec)(cac)子树中的结点,我们在加(cac)的时候已经把(cedcacdec)的贡献计算过了,如果再加一次就会重复,因此如果某个结点的(dfs)序被前面长度短的结点包含过,那么就不用加进答案中。

代码

#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> pil;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;

#define lson (rt<<1),L,mid
#define rson (rt<<1|1),mid + 1,R
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********
")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("/home/dillonh/CLionProjects/Dillonh/in.txt","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)

const double eps = 1e-8;
const int mod = 1000000007;
const int maxn = 100000 + 7;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;

int _, n, cnt;
char s[maxn];
vector<int> vec[maxn];
int sz[maxn], ls[maxn], rs[maxn], vis[maxn];

struct PAM {
    //len数组表示以i为结尾的最长回文子串长度
    //tot为结点数,lst为上一个字符加的位置
    int N;
    int str[maxn];
    int ch[maxn][30], fail[maxn], len[maxn], cnt[maxn], tot, lst;
    void init() {
        for(int i = 0; i <= n + 1; ++i) {
            cnt[i] = len[i] = fail[i] = 0;
            for(int j = 0; j <= 26; ++j) ch[i][j] = 0;
        }
        N = lst = 0; tot = 1; fail[0] = fail[1] = 1; len[1] = -1;
    }
    inline void add(int c) {
        int p = lst;
        str[++N] = c;
        while(str[N - len[p] - 1] != str[N]) p = fail[p];
        if(!ch[p][c]) {
            int now = ++tot, k = fail[p];
            len[now] = len[p] + 2;
            while(str[N - len[k] - 1] != str[N]) k = fail[k];
            fail[now] = ch[k][c]; ch[p][c] = now;
        }
        lst = ch[p][c]; cnt[lst]++;
    }
    inline void solve() {
        for(int i = tot; i; i--) {
            cnt[fail[i]] += cnt[i];
        }
    }
}pam;

int tree[maxn];

void add(int x, int val) {
    while(x < maxn) {
        tree[x] += val;
        x += lowbit(x);
    }
}

int query(int x) {
    int ans = 0;
    while(x) {
        ans += tree[x];
        x -= lowbit(x);
    }
    return ans;
}

void dfs(int u) {
    sz[u] = 1;
    ls[u] = ++cnt;
    for(int i = 1; i <= 26; ++i) {
        if(pam.ch[u][i]) {
            dfs(pam.ch[u][i]);
            sz[u] += sz[pam.ch[u][i]];
        }
    }
    rs[u] = cnt;
}

int main() {
#ifndef ONLINE_JUDGE
    FIN;
#endif
    scanf("%d", &_);
    for(int __ = 1; __ <= _; ++__) {
        scanf("%s", s + 1);
        n = strlen(s + 1);
        pam.init();
        for(int i = 1; i <= n; ++i) pam.add(s[i] - 'a' + 1);
        cnt = 0;
        dfs(1);
        dfs(0);
        LL ans = 0;
        for(int i = 2; i <= pam.tot; ++i) vec[i].clear();
        for(int i = 2; i <= pam.tot; ++i) {
            if(pam.fail[i] >= 2) vec[pam.fail[i]].emplace_back(i);
        }
        for(int i = 2; i <= pam.tot; ++i) {
            vec[i].emplace_back(i);
            sort(vec[i].begin(), vec[i].end(), [](int x, int y) {return pam.len[x] < pam.len[y];});
            LL sum = 0;
            for(int j = 0; j < (int)vec[i].size(); ++j) {
                int u = vec[i][j];
                if(query(ls[u]) == 0) {
                    sum += sz[u];
                    add(ls[u], 1);
                    add(rs[u] + 1, -1);
                    vis[j] = 1;
                }
            }
            ans += sum - 1;
            for(int j = 0; j < (int)vec[i].size(); ++j) {
                if(!vis[j]) continue;
                int u = vec[i][j];
                add(ls[u], -1);
                add(rs[u] + 1, 1);
                vis[j] = 0;
            }
        }
        printf("Case #%d: %lld
", __, ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Dillonh/p/11397252.html