SAM

下面是广义 SAM(广义后缀自动机)的实现代码。

#include<bits/stdc++.h>
// rep(x, L, R): declares an int `x` and loops it from L up to R inclusive.
// R is evaluated once and cached in a loop-local variable named `_x`.
#define rep(x, L, R) for(int x = (L), _x = (R); x <= _x; x++)
using namespace std;
// N: capacity of the input buffer `s` and of the per-string table `str`.
// md: modulus used by the arithmetic helpers. T = 2 * N: capacity of the per-node arrays.
// iv2 = (md + 1) / 2, i.e. the inverse of 2 modulo md (not referenced in the code shown here).
const int N = 2e6 + 10, md = 998244353, T = N << 1, iv2 = (md + 1) / 2;
// Automaton node storage, indexed by node id (newd hands out ids starting at 1):
// fa = suffix-link parent, ch = transitions over 26 symbols,
// dep = length value given to the node at creation, cnt = id of the last allocated node.
int fa[T], ch[T][26], dep[T], cnt = 0;
// n = number of input strings; p[i] = integer read from input for string i.
// sum[u][0..1] = per-node accumulators: slot 1 is fed while inserting a string
// forwards, slot 0 while inserting it reversed; Dfs merges them up the fa-tree.
int n, p[T], sum[T][2];
// ans[k] = totals produced by Dfs, bucketed by the sum (i + j) of the two slots paired.
// res = final answer, kept reduced modulo md.
int ans[3], res;
// Scratch buffer for reading one input string (main fills it starting at index 1).
char s[N];
// Adjacency lists of the fa-tree (both directions), rebuilt before every Dfs.
vector<int> g[T];
// str[i] = i-th input string stored as s[j] - 'a' values
// (0..25 assuming lowercase input — the code does not validate this).
vector<int> str[N];
// Returns x reduced by one conditional subtraction of md.
// The result is in [0, md) provided x is in [0, 2 * md).
int add(int x) {
    if (x >= md) {
        return x - md;
    }
    return x;
}
// Returns x lifted by one conditional addition of md.
// The result is in [0, md) provided x is in (-md, md).
int sub(int x) {
    if (x < 0) {
        return x + md;
    }
    return x;
}
 
// In-place modular addition: x becomes x + y, minus md if that reached md or more.
// Stays within [0, md) when both operands start in [0, md).
void Add(int &x, int y) {
    x += y;
    if (x >= md) {
        x -= md;
    }
}
 
// In-place modular subtraction: x becomes x - y, plus md if that went negative.
// Stays within [0, md) when both operands start in [0, md).
void Sub(int &x, int y) {
    x -= y;
    if (x < 0) {
        x += md;
    }
}
 
// Appends v to the adjacency list of u (a single directed edge u -> v).
void adde(int u, int v) {
    g[u].emplace_back(v);
}
 
// Allocates the next node id and resets every per-node field for it:
// transitions, both accumulators and the suffix link become 0, the
// adjacency list is emptied, and dep is set to d. Returns the new id.
int newd(int d) {
    const int id = ++cnt;
    std::fill(std::begin(ch[id]), std::end(ch[id]), 0);
    sum[id][1] = 0;
    sum[id][0] = 0;
    fa[id] = 0;
    dep[id] = d;
    g[id].clear();
    return id;
}
 
// Resets the node counter so the next newd() call hands out id 1 again.
// Per-node data is not wiped here; newd() re-initialises each node on allocation.
void Clear() {
    cnt = 0;
}
 
// Extends the automaton from state p with symbol c and returns the state
// reached afterwards. Node 1 is the root; id 0 means "no node".
//
// Case A: p already has a c-transition to `target`.
//   - If dep[target] == dep[p] + 1 the existing state is reused as is.
//   - Otherwise `target` is split: a clone with dep[p] + 1 takes over its
//     transitions and suffix link, `target` links to the clone, and every
//     ancestor of p (along fa) that pointed to `target` is redirected.
// Case B: no c-transition from p. A fresh state is created, the missing
//   transitions on the suffix-link chain are filled in, and the suffix link of
//   the fresh state is resolved (root, an existing state, or a new clone).
int ins(int p, int c) {
    const int target = ch[p][c];
    if (target != 0) {
        if (dep[target] == dep[p] + 1) {
            return target;
        }
        const int clone = newd(dep[p] + 1);
        memcpy(ch[clone], ch[target], sizeof(ch[target]));
        fa[clone] = fa[target];
        fa[target] = clone;
        while (p != 0 && ch[p][c] == target) {
            ch[p][c] = clone;
            p = fa[p];
        }
        return clone;
    }

    const int cur = newd(dep[p] + 1);
    while (p != 0 && ch[p][c] == 0) {
        ch[p][c] = cur;
        p = fa[p];
    }
    if (p == 0) {
        // Walked past the root: the new state links straight to the root.
        fa[cur] = 1;
        return cur;
    }

    const int q = ch[p][c];
    if (dep[q] == dep[p] + 1) {
        fa[cur] = q;
        return cur;
    }

    // q is too long to serve as the suffix link: split it.
    const int clone = newd(dep[p] + 1);
    memcpy(ch[clone], ch[q], sizeof(ch[q]));
    fa[clone] = fa[q];
    fa[cur] = clone;
    fa[q] = clone;
    while (p != 0 && ch[p][c] == q) {
        ch[p][c] = clone;
        p = fa[p];
    }
    return cur;
}
 
 
void Dfs(int u, int fa) {
    rep(i, 0, 1) {
        rep(j, 0, 1) {
            Add(ans[i + j], 1ll * sum[u][i] * sum[u][j] % md * dep[u] % md);
        }
    }
    for(auto v : g[u]) {
        if(v == fa) continue;
        Dfs(v, u);
        rep(i, 0, 1) {
            rep(j, 0, 1) {
                Add(ans[i + j], 1ll * sum[u][i] * sum[v][j] % md * dep[u] % md * 2 % md);
            }
        }
        rep(i, 0, 1) Add(sum[u][i], sum[v][i]);
    }
    return;
}
 
void calc() {
    Clear();
    memset(ans, 0, sizeof(ans));
    newd(0);
    for(int i = 1; i <= n; i++) {
        int len = str[i].size();
        int rt = 1;
        for(int j = 0; j < len; j++) {
            rt = ins(rt, str[i][j]);
            Add(sum[rt][1], p[i]);
        }
        rt = 1;
        for(int j = len - 1; j >= 0; j--) {
            rt = ins(rt, str[i][j]);
            Add(sum[rt][0], sub(1 - p[i]));
        }
    }
    for(int i = 2; i <= cnt; i++) adde(i, fa[i]), adde(fa[i], i);
    Dfs(1, 0);
    Add(res, add(add(ans[0] + ans[1]) + ans[2]));
    return;
}
 
// Adjusts res for a single string `str` with associated value `p`.
//
// A fresh automaton is built from `str` alone (forwards into slot 1, reversed
// into slot 0, each visit adding 1), and Dfs fills ans[0..2]. With q = (1 - p) mod md,
// res then loses ans[0]*q*q, ans[2]*p*p and ans[1]*p*q, and gains ans[0]*q and ans[2]*p.
void sub(vector<int> &str, int p) {
    Clear();
    ans[0] = ans[1] = ans[2] = 0;

    int node = newd(0);
    for (const int letter : str) {
        node = ins(node, letter);
        Add(sum[node][1], 1);
    }

    node = 1;
    for (auto it = str.rbegin(); it != str.rend(); ++it) {
        node = ins(node, *it);
        Add(sum[node][0], 1);
    }

    for (int v = 2; v <= cnt; ++v) {
        adde(v, fa[v]);
        adde(fa[v], v);
    }
    Dfs(1, 0);

    const int q = sub(1 - p);
    Sub(res, 1ll * ans[0] * q % md * q % md);
    Sub(res, 1ll * ans[2] * p % md * p % md);
    Sub(res, 1ll * ans[1] * p % md * q % md);
    Add(res, 1ll * ans[0] * q % md);
    Add(res, 1ll * ans[2] * p % md);
}
 
 
int main() {
//  freopen("in.in", "r", stdin);
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) scanf("%d", &p[i]);
    for(int i = 1; i <= n; i++) {
        scanf("%s", s + 1);
        int len = strlen(s + 1);
        for(int j = 1; j <= len; j++) str[i].push_back(s[j] - 'a');
    }
    calc();
    for(int i = 1; i <= n; i++) sub(str[i], p[i]);
    printf("%d
", res);
    return 0;
}

一定要记住 memcpy 和 copy 在参数顺序上的区别:memcpy 的目标地址是第一个参数,即 memcpy(dst, src, n);而 std::copy 的目标迭代器是最后一个参数,即 copy(first, last, dst)。

原文地址:https://www.cnblogs.com/SegmentTree/p/13052814.html