Gym

题意:有一棵树,树上每个结点上有一个字母,有两种操作:

1)询问树上两点u,v间有向路径上有多少个字母和某个固定的字符串相匹配

2)将结点u的字母修改为x

树剖+线段,暴力维护前缀和后缀哈希值(正反都要维护)以及区间内匹配的个数,合并两区间时判断一下跨过分界点的情况就行了。由于被匹配的字符串长度不超过100,所以最多只需维护长度为100的前缀/后缀。

但即使这样复杂度也足足有$O(100nlog^2n)$啊,这常数是得有多小才能过掉...

注意各种条件判断和细节处理,还有就是这题内存比较吃紧,使用动态开点可以节省一半的内存。

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef unsigned long long ll;
  4 const int N=1e5+10,M=19260817;
  5 int hd[N],n,ne,m,Len,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt,ls[N<<1],rs[N<<1],tot2;
  6 ll H,pm[N];
  7 char s1[N],s2[N];
  8 struct D {
  9     int len,sum;
 10     ll L[101],R[101];
 11     D(int ch=0) {
 12         len=1;
 13         sum=(Len==1&&ch==H);
 14         L[1]=R[1]=ch;
 15     }
 16 } tr[N<<1][2];
 17 struct E {int v,nxt;} e[N<<1];
 18 D operator+(const D& a,const D& b) {
 19     if(a.L[1]==0)return b;
 20     if(b.L[1]==0)return a;
 21     D c;
 22     c.len=a.len+b.len;
 23     c.sum=a.sum+b.sum;
 24     for(int i=1; i<=min(100,c.len); ++i) {
 25         if(i<=a.len)c.L[i]=a.L[i];
 26         else c.L[i]=a.L[a.len]*pm[i-a.len]+b.L[i-a.len];
 27         if(i<=b.len)c.R[i]=b.R[i];
 28         else c.R[i]=a.R[i-b.len]*pm[b.len]+b.R[b.len];
 29     }
 30     for(int i=max(1,Len-b.len); i<=a.len&&i<=Len-1; ++i)if(a.R[i]*pm[Len-i]+b.L[Len-i]==H)c.sum++;
 31     return c;
 32 }
 33 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 34 void dfs1(int u,int f,int d) {
 35     fa[u]=f,son[u]=0,siz[u]=1,dep[u]=d;
 36     for(int i=hd[u]; ~i; i=e[i].nxt) {
 37         int v=e[i].v;
 38         if(v==fa[u])continue;
 39         dfs1(v,u,d+1),siz[u]+=siz[v];
 40         if(siz[v]>siz[son[u]])son[u]=v;
 41     }
 42 }
 43 void dfs2(int u,int tp) {
 44     top[u]=tp,dfn[u]=++tot,rnk[dfn[u]]=u;
 45     if(son[u])dfs2(son[u],top[u]);
 46     for(int i=hd[u]; ~i; i=e[i].nxt) {
 47         int v=e[i].v;
 48         if(v==fa[u]||v==son[u])continue;
 49         dfs2(v,v);
 50     }
 51 }
 52 #define mid ((l+r)>>1)
 53 void pu(int u) {
 54     tr[u][0]=tr[ls[u]][0]+tr[rs[u]][0];
 55     tr[u][1]=tr[rs[u]][1]+tr[ls[u]][1];
 56 }
 57 void upd(int p,int x,int& u=rt,int l=1,int r=tot) {
 58     if(l==r) {tr[u][0]=tr[u][1]=D(x); return;}
 59     p<=mid?upd(p,x,ls[u],l,mid):upd(p,x,rs[u],mid+1,r);
 60     pu(u);
 61 }
 62 void build(int& u=rt,int l=1,int r=tot) {
 63     u=++tot2;
 64     if(l==r) {tr[u][0]=tr[u][1]=D(s2[rnk[l]-1]); return;}
 65     build(ls[u],l,mid),build(rs[u],mid+1,r),pu(u);
 66 }
 67 D qry(int L,int R,int f,int u=rt,int l=1,int r=tot) {
 68     if(l>=L&&r<=R)return tr[u][f];
 69     if(l>R||r<L)return D(0);
 70     if(f==0)return qry(L,R,f,ls[u],l,mid)+qry(L,R,f,rs[u],mid+1,r);
 71     else return qry(L,R,f,rs[u],mid+1,r)+qry(L,R,f,ls[u],l,mid);
 72 }
 73 int qry2(int u,int v) {
 74     D L=D(0),R=D(0),M;
 75     for(; top[u]!=top[v];) {
 76         if(dep[top[u]]>dep[top[v]])L=L+qry(dfn[top[u]],dfn[u],1),u=fa[top[u]];
 77         else R=qry(dfn[top[v]],dfn[v],0)+R,v=fa[top[v]];
 78     }
 79     if(dep[u]>dep[v])M=qry(dfn[v],dfn[u],1);
 80     else M=qry(dfn[u],dfn[v],0);
 81     return (L+M+R).sum;
 82 }
 83 int main() {
 84     pm[0]=1;
 85     for(int i=1; i<N; ++i)pm[i]=pm[i-1]*M;
 86     memset(hd,-1,sizeof hd),ne=0;
 87     scanf("%d%d",&n,&m);
 88     scanf("%s%s",s1,s2),Len=strlen(s1);
 89     H=0;
 90     for(int i=0; i<Len; ++i)H=H*M+s1[i];
 91     for(int i=1; i<n; ++i) {
 92         int u,v;
 93         scanf("%d%d",&u,&v);
 94         addedge(u,v);
 95         addedge(v,u);
 96     }
 97     dfs1(1,0,1),dfs2(1,1);
 98     build(rt);
 99     while(m--) {
100         int f,u,v;
101         char ch;
102         scanf("%d",&f);
103         if(f==1)scanf("%d%d",&u,&v),printf("%d
",qry2(u,v));
104         else scanf("%d %c",&u,&ch),upd(dfn[u],ch);
105     }
106     return 0;
107 }

 当然还有$O(100nlogn)$的LCT毒瘤做法,代码比树剖短,需要特判的地方少,而且更省内存,但是常数巨大,需要优化下常数才能过,比如改个引用传递什么的。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef unsigned long long ll;
 4 const int N=1e5+10,M=19260817;
 5 int n,m,a[N],fa[N],ch[N][2],flp[N],sta[N],tp,Len;
 6 ll H,pm[N];
 7 char s1[N],s2[N];
 8 struct D {
 9     int len,sum;
10     ll L[101],R[101];
11     D(int ch=0) {
12         len=1;
13         sum=(Len==1&&ch==H);
14         L[1]=R[1]=ch;
15     }
16 } tr[N][2];
17 D operator+(const D& a,const D& b) {
18     if(a.L[1]==0)return b;
19     if(b.L[1]==0)return a;
20     D c;
21     c.len=a.len+b.len;
22     c.sum=a.sum+b.sum;
23     for(int i=1; i<=min(100,c.len); ++i) {
24         if(i<=a.len)c.L[i]=a.L[i];
25         else c.L[i]=a.L[a.len]*pm[i-a.len]+b.L[i-a.len];
26         if(i<=b.len)c.R[i]=b.R[i];
27         else c.R[i]=a.R[i-b.len]*pm[b.len]+b.R[b.len];
28     }
29     for(int i=max(1,Len-b.len); i<=a.len&&i<=Len-1; ++i)if(a.R[i]*pm[Len-i]+b.L[Len-i]==H)c.sum++;
30     return c;
31 }
32 #define l(u) ch[u][0]
33 #define r(u) ch[u][1]
34 void rev(int u) {flp[u]^=1,swap(l(u),r(u)),swap(tr[u][0],tr[u][1]);}
35 void pu(int u) {
36     tr[u][0]=tr[l(u)][0]+D(s2[u])+tr[r(u)][0];
37     tr[u][1]=tr[r(u)][1]+D(s2[u])+tr[l(u)][1];
38 }
39 void pd(int u) {if(flp[u])rev(l(u)),rev(r(u)),flp[u]=0;}
40 int sf(int u) {return u==r(fa[u]);}
41 bool isrt(int u) {return u!=l(fa[u])&&u!=r(fa[u]);}
42 void rot(int u) {
43     int v=fa[u],f=sf(u);
44     if(!isrt(v))ch[fa[v]][sf(v)]=u;
45     ch[v][f]=ch[u][f^1],fa[ch[v][f]]=v;
46     fa[u]=fa[v],ch[u][f^1]=v,fa[v]=u,pu(v);
47 }
48 void splay(int u) {
49     sta[tp=0]=u;
50     for(int v=u; !isrt(v); v=fa[v])sta[++tp]=fa[v];
51     for(; ~tp; pd(sta[tp--]));
52     for(; !isrt(u); rot(u))if(!isrt(fa[u])&&sf(fa[u])==sf(u))rot(fa[u]);
53     pu(u);
54 }
55 void access(int u) {for(int v=0; u; splay(u),r(u)=v,pu(u),u=fa[v=u]);}
56 void makert(int u) {access(u),splay(u),rev(u);}
57 void link(int u,int v) {makert(u),fa[u]=v;}
58 void join(int u,int v) {makert(u),access(v),splay(v);}
59 void upd(int u,int ch) {splay(u),s2[u]=ch;}
60 int qry(int u,int v) {join(u,v); return tr[v][0].sum;}
61 int main() {
62     pm[0]=1;
63     for(int i=1; i<N; ++i)pm[i]=pm[i-1]*M;
64     scanf("%d%d",&n,&m);
65     scanf("%s%s",s1,s2+1),Len=strlen(s1);
66     H=0;
67     for(int i=0; i<Len; ++i)H=H*M+s1[i];
68     for(int i=1; i<n; ++i) {
69         int u,v;
70         scanf("%d%d",&u,&v);
71         link(u,v);
72     }
73     while(m--) {
74         int f,u,v;
75         char ch;
76         scanf("%d",&f);
77         if(f==1)scanf("%d%d",&u,&v),printf("%d
",qry(u,v));
78         else scanf("%d %c",&u,&ch),upd(u,ch);
79     }
80     return 0;
81 }
原文地址:https://www.cnblogs.com/asdfsag/p/11623017.html