luogu4770 [NOI2018]你的名字 后缀自动机 + 线段树合并


其实很水的一道题吧....

题意是:每次给定一个串(T)以及(l, r),询问有多少个字符串(s)满足,(s)(T)的子串,但不是(S[l .. r])的子串


统计(T)本质不同的串,建个后缀自动机

然后自然的可以想到,对于每个(T)的子串,它对应了一个(right)集合

那么,它应该会被这个(right)集合所限制

考虑对于每个(i),求出最小的(l)使得(T[l .. i])存在于(S[l..r])

这个可以套个线段树转移

然后就没了.....


如果不需要统计(T)本质不同的串,又怎么做呢?

统计的时候乘上(right)集合大小就行


#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

#define ri register int
#define ll long long
#define rep(io, st, ed) for(ri io = st; io <= ed; io ++)
#define drep(io, ed, st) for(ri io = ed; io >= st; io --)

#define gc getchar
inline int read() {
    int p = 0, w = 1; char c =  gc();
    while(c > '9' || c < '0') { if(c == '-') w = -1; c = gc(); }
    while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
    return p * w;
}

const int sid = 1005000;
const int eid = 30000000 + 5;

struct SAM {
    
    int id, fa[sid], mx[sid];
    int go[sid][26], mc[sid];
    
    inline int newnode() {
        ++ id;
        fa[id] = mx[id] = mc[id] = 0;
        memset(go[id], 0, sizeof(go[id]));
        return id;
    }
    
    inline void init() {
        id = 0;
        newnode();
    }
    
    inline int extend(int lst, int c, int pos) {
        int np = newnode(), p = lst;
        mx[np] = mx[p] + 1; mc[np] = pos;
        for( ; p && !go[p][c]; p = fa[p]) 
            go[p][c] = np;
        if(!p) fa[np] = 1;
        else {
            int q = go[p][c];
            if(mx[p] + 1 == mx[q]) fa[np] = q;
            else {
                int nq = newnode(); mx[nq] = mx[p] + 1;
                fa[nq] = fa[q]; fa[np] = fa[q] = nq;
                memcpy(go[nq], go[q], sizeof(go[q]));
                for( ; p && go[p][c] == q; p = fa[p]) 
                    go[p][c] = nq;
            }
        }
        return np;
    }
    
} S, T;

int q, n, m, seg;
char s[sid], t[sid];
int nc[sid], ip[sid], w[sid], val[sid];
int rt[sid], ls[eid], rs[eid];

inline int merge(int x, int y) {
    if(!x || !y) return x + y;
    int o = ++ seg;
    ls[o] = merge(ls[x], ls[y]);
    rs[o] = merge(rs[x], rs[y]);
    return o;
}

inline void ins(int &o, int l, int r, int p) {
    o = ++ seg;
    if(l == r) return;
    int mid = (l + r) >> 1;
    if(p <= mid) ins(ls[o], l, mid, p);
    else ins(rs[o], mid + 1, r, p);
}

inline bool qry(int o, int l, int r, int ml, int mr) {
    if(ml > r || mr < l || ml > mr || !o) return 0;
    if(ml <= l && mr >= r) return 1;
    int mid = (l + r) >> 1;
    if(qry(ls[o], l, mid, ml, mr)) return 1;
    else return qry(rs[o], mid + 1, r, ml, mr);
}

inline void init() {
    S.init();
    int lst = 1;
    rep(i, 1, n) lst = S.extend(lst, s[i] - 'a', i);
    int id = S.id;
    rep(i, 1, id) nc[S.mx[i]] ++;
    rep(i, 1, n) nc[i] += nc[i - 1];
    rep(i, 1, id) ip[nc[S.mx[i]] --] = i;
    rep(i, 1, id) 
        if(S.mc[i]) 
            ins(rt[i], 1, n, S.mc[i]);
    drep(i, id, 1) {
        int o = ip[i], f = S.fa[o];
        rt[f] = merge(rt[f], rt[o]);
    }
}

void Match(int l, int r) {
    int o = 1, nl = 0;
    rep(i, 1, m) {
        int c = t[i] - 'a';
        while(1) 
        {
            int nxt = S.go[o][c], f = S.fa[o];
            if(nxt && qry(rt[nxt], 1, n, l + nl, r)) 
            {
                nl ++; o = nxt;
                break;
            }
            if(!nl) break; nl --;
            if(nl == S.mx[f]) o = f;
        }
        w[i] = nl;
    }
}

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    init(); q = read();
    rep(i, 1, q) {
        
        scanf("%s", t + 1);
        m = strlen(t + 1);
        
        T.init();
        int lst = 1;
        rep(j, 1, m) lst = T.extend(lst, t[j] - 'a', j);
        
        int l = read(), r = read();
        Match(l, r);
            
        int id = T.id;
        rep(i, 1, id) nc[i] = val[i] = 0;
        rep(i, 1, id) nc[T.mx[i]] ++;
        rep(i, 1, id) nc[i] += nc[i - 1];
        rep(i, 1, id) ip[nc[T.mx[i]] --] = i;
        drep(i, id, 1) {
            int o = ip[i], f = T.fa[o];
            if(T.mc[o]) val[o] = w[T.mc[o]];
            val[f] = max(val[f], val[o]);
        }
        
        ll ans = 0;
        rep(i, 1, id) ans += max(T.mx[i] - max(T.mx[T.fa[i]], val[i]), 0);
        printf("%lld
", ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/reverymoon/p/10029355.html