LOJ#6031. 「雅礼集训 2017 Day1」字符串 根号分治+SAM+倍增

怎么想都没想出来 $log n$ 做法,那么这道题基本就是根号分治了.    

题目描述中保证 $sum k leqslant 10^5$,然后 $k$ 在每次询问中又是相同的,那么就考虑对 $k$ 根号分治.     

先对 $s$ 建立后缀自动机,然后把倍增数组求出来.   

我们设块的大小为 $B$,那么当 $k leqslant B$ 时可以对 $k$ 的每一个子串在 $s$ 上都求一遍出现次数 (暴力跳祖先).  

其中对于编号为 $[a,b]$ 的限制我们可以直接开一个二维的 vector 存储查询为 $(l,r)$ 的编号,然后 lowerbound 一下就行.   

这个的复杂度是 $O(k^2 log n)$ 的,总复杂度是 $O(k^2 Q log n)$,即 $O(B nlog n)$.   

对于 $k>B$ 时询问次数不会超过 $frac{10^5}{B}$ 个,那么可以直接对询问按照右端点离线,然后将 $w$ 在 $s$ 上匹配,最后再倍增一下.  

这部分的复杂度是 $O(frac{Q}{k} n log n)$ 的.  

这个 $B$ 取到 400 或 $sqrt n$ 即可.  

code:  

#include <cstdio>  
#include <vector>
#include <cstring>
#include <algorithm>    
#define N 100009   
#define ll long long 
#define pb push_back
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
const int B=403;    
char str[N]; 
int n,m,Q,k,tot,last,edges;        
ll cnt[N<<1];  
int pre[N<<1],ch[N<<1][27],mx[N<<1]; 
int hd[N<<1],to[N<<1],nex[N<<1],fa[20][N<<1];  
void add(int u,int v) {  
    nex[++edges]=hd[u]; 
    hd[u]=edges,to[edges]=v; 
}
struct oper {  
    int l,r,id;    
    bool operator<(const oper b) const {    
        return r<b.r;    
    }
}a[N];  
void init() {    
    last=tot=1;   
}
void extend(int c) {     
    int np=++tot,p=last;  
    mx[np]=mx[p]+1,last=np;  
    for(;p&&!ch[p][c];p=pre[p]) {  
        ch[p][c]=np; 
    } 
    if(!p) {  
        pre[np]=1; 
    } 
    else {  
        int q=ch[p][c];  
        if(mx[q]==mx[p]+1) pre[np]=q;   
        else {
            int nq=++tot; 
            mx[nq]=mx[p]+1;  
            pre[nq]=pre[q],pre[np]=pre[q]=nq;  
            memcpy(ch[nq],ch[q],sizeof(ch[q]));    
            for(;p&&ch[p][c]==q;p=pre[p]) { 
                ch[p][c]=nq;  
            }
        }
    }
    ++cnt[np];   
}     
void dfs(int x) {  
    fa[0][x]=pre[x];  
    for(int i=1;i<20;++i) fa[i][x]=fa[i-1][fa[i-1][x]];      
    for(int i=hd[x];i;i=nex[i]) {  
        dfs(to[i]); 
        cnt[x]+=cnt[to[i]];  
    }
}
int len;   
int trans(int x,int c) {  
    while(x&&!ch[x][c]) x=pre[x],len=mx[x];        
    if(ch[x][c]) { x=ch[x][c],++len;  return x; }     
    else return 1;   
}
namespace sol1 {     
    vector<int>q[B][B];   
    int main() {          
        for(int i=1;i<=m;++i) { 
            if(a[i].r<=k) {
                q[a[i].l][a[i].r].pb(i);  
            }
        }  
        int x,y,z;  
        for(int T=1;T<=Q;++T) {  
            scanf("%s%d%d",str+1,&x,&y);        
            ++x,++y; 
            ll cur=0; 
            z=1,len=0;       
            for(int j=1;j<=k;++j) {    
                z=trans(z,str[j]-'a');   
                for(int p=z,o=len;o;--o) {        
                    // [j-len+1,j]   
                    int l=j-o+1;  
                    int a1=lower_bound(q[l][j].begin(),q[l][j].end(),x)-q[l][j].begin();  
                    int a2=upper_bound(q[l][j].begin(),q[l][j].end(),y)-q[l][j].begin();                  
                    cur+=(a2-a1)*cnt[p];     
                    if(o-1==mx[pre[p]]) p=pre[p];          
                }
            }
            printf("%lld
",cur);  
        }
        return 0; 
    }
};     
int get_up(int x,int kth) {  
    for(int i=19;i>=0;--i) { 
        if(mx[fa[i][x]]>=kth) {   
            x=fa[i][x];  
        }
    }       
    return x;  
}
namespace sol2 { 
    int main() {     
        int x,y,z; 
        for(int i=1;i<=m;++i) a[i].id=i;  
        sort(a+1,a+1+m);          
        for(int T=1;T<=Q;++T) {     
            scanf("%s%d%d",str+1,&x,&y);  
            ++x,++y;    
            z=1,len=0;    
            ll cur=0; 
            int lst=1;  
            for(int j=1;j<=k;++j) {     
                z=trans(z,str[j]-'a');       
                while(a[lst].r<=j&&lst<=m) {            
                    if(a[lst].r-a[lst].l+1<=len&&a[lst].id>=x&&a[lst].id<=y) {     
                        int p=get_up(z,a[lst].r-a[lst].l+1);  
                        cur+=cnt[p];  
                    } 
                    ++lst;  
                }
            }
            printf("%lld
",cur); 
        }
        return 0;  
    }
};  
int main() { 
    // setIO("input");
    // freopen("input.out","w",stdout);              
    scanf("%d%d%d%d%s",&n,&m,&Q,&k,str+1);          
    init();  
    for(int i=1;i<=n;++i) {  
        extend(str[i]-'a');  
    }   
    for(int i=2;i<=tot;++i) {  
        add(pre[i],i); 
    }  
    dfs(1);  
    for(int i=1;i<=m;++i) {      
        scanf("%d%d",&a[i].l,&a[i].r);  
        ++a[i].l;  
        ++a[i].r;  
    }        
    if(k<403) sol1::main();  
    else sol2::main();  
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/13346673.html