[loj3046]语言

定义$S_{i}$表示第$i$条链所包含的点的集合,$(x,y)$合法当且仅当$x e y$且$exists i,{x,y}subseteq S_{i}$(答案即$frac{合法点对数}{2}$),显然后者等价于$yin cup_{xin S_{i}}S_{i}$,因此合法点对数为$sum_{x=1}^{n}|cup_{xin S_{i}}S_{i}|-1$

结论:$链并的大小=链端点所构成的虚树点数=frac{按照dfs序排序后相邻(包括首尾)两点距离和}{2}+1$

前者显然,后者证明如下:

对每一条边统计经过次数,设其连结的深度较大的点为$x$,那么记$p_{i}=1$当且仅当$i$在$x$子树内(否则$p_{i}=0$),观察可得两个点$x$和$y$经过这条边当且仅当$p_{x}+p_{y}=1$

考虑dfs序的性质:每一个子树一定是一段区间,因此设端点按dfs序排序后为$a_{1},a_{2},...,a_{k}$,$S={i|p_{a_{i}}=1}$一定是一段区间$[l,r]$,观察可得当$[l,r]=emptyset$或$[l,r]=[1,k]$时该边答案为0,否则答案为2

考虑$[l,r]=emptyset$或$[l,r]=[1,k]$的条件,即等价于这条边不在虚树上,那么$frac{按照dfs序排序后相邻(包括首尾)两点距离和}{2}$即为边数,根据树的性质,加1即为点数

根据这个结论,将每条链差分并用线段树合并来找到所有端点,线段树上维护:1.个数(判断是否存在);2.区间最小点;3.区间最大点;4.区间相邻点距离和(最左和最右可以在外面算)即可,如果用st表维护lca可以做到$o(nlog_{2}n)$

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 100005
  4 #define mid (l+r>>1)
  5 struct ji{
  6     int nex,to;
  7 }edge[N<<1];
  8 int V,E,n,m,x,y,head[N],dfn[N],id[N],s[N],f[N][21],r[N],ls[N*100],rs[N*100],vis[N*100],mn[N*100],mx[N*100],sum[N*100];
  9 long long ans;
 10 void add(int x,int y){
 11     edge[E].nex=head[x];
 12     edge[E].to=y;
 13     head[x]=E++;
 14 }
 15 void dfs(int k,int fa,int sh){
 16     dfn[k]=++x;
 17     id[x]=k;
 18     s[k]=sh;
 19     f[k][0]=fa;
 20     for(int i=1;i<=20;i++)f[k][i]=f[f[k][i-1]][i-1];
 21     for(int i=head[k];i!=-1;i=edge[i].nex)
 22         if (edge[i].to!=fa)dfs(edge[i].to,k,sh+1);
 23 }
 24 int lca(int x,int y){
 25     if (s[x]<s[y])swap(x,y);
 26     for(int i=20;i>=0;i--)
 27         if (s[f[x][i]]>=s[y])x=f[x][i];
 28     if (x==y)return x;
 29     for(int i=20;i>=0;i--)
 30         if (f[x][i]!=f[y][i]){
 31             x=f[x][i];
 32             y=f[y][i];
 33         }
 34     return f[x][0];
 35 }
 36 int dis(int x,int y){
 37     return s[x]+s[y]-2*s[lca(x,y)];
 38 }
 39 void up(int k){
 40     mn[k]=min(mn[ls[k]],mn[rs[k]]);
 41     mx[k]=max(mx[ls[k]],mx[rs[k]]);
 42     sum[k]=sum[ls[k]]+sum[rs[k]];
 43     if ((mx[ls[k]])&&(mn[rs[k]]<=n))sum[k]+=dis(id[mx[ls[k]]],id[mn[rs[k]]]);
 44 }
 45 void update(int &k,int l,int r,int x,int y){
 46     if (!k){
 47         k=++V;
 48         mn[k]=n+1;
 49     }
 50     if (l==r){
 51         vis[k]+=y;
 52         if (vis[k]>0)mn[k]=mx[k]=l;
 53         else{
 54             mn[k]=n+1;
 55             mx[k]=0;
 56         }
 57         return;
 58     }
 59     if (x<=mid)update(ls[k],l,mid,x,y);
 60     else update(rs[k],mid+1,r,x,y);
 61     up(k);
 62 }
 63 int merge(int k1,int k2){
 64     if ((!k1)||(!k2))return k1+k2;
 65     if ((!ls[k1])&&(!rs[k1])){
 66         vis[k1]+=vis[k2];
 67         if (vis[k1]>0){
 68             mn[k1]=min(mn[k1],mn[k2]);
 69             mx[k1]=max(mx[k1],mx[k2]);
 70         }
 71         else{
 72             mn[k1]=n+1;
 73             mx[k1]=0;
 74         }
 75         return k1;
 76     }
 77     ls[k1]=merge(ls[k1],ls[k2]);
 78     rs[k1]=merge(rs[k1],rs[k2]);
 79     up(k1);
 80     return k1;
 81 }
 82 void dfs(int k,int fa){
 83     for(int i=head[k];i!=-1;i=edge[i].nex)
 84         if (edge[i].to!=fa){
 85             dfs(edge[i].to,k);
 86             r[k]=merge(r[k],r[edge[i].to]); 
 87         }
 88     if (mn[r[k]]!=mx[r[k]])ans+=sum[r[k]]+dis(id[mn[r[k]]],id[mx[r[k]]]);
 89 }
 90 int main(){
 91     scanf("%d%d",&n,&m);
 92     memset(head,-1,sizeof(head));
 93     for(int i=1;i<n;i++){
 94         scanf("%d%d",&x,&y);
 95         add(x,y);
 96         add(y,x);
 97     }
 98     x=0;
 99     dfs(1,0,1);
100     mn[0]=n+1;
101     for(int i=1;i<=m;i++){
102         scanf("%d%d",&x,&y);
103         int z=lca(x,y);
104         update(r[x],1,n,dfn[x],1);
105         update(r[x],1,n,dfn[y],1);
106         update(r[y],1,n,dfn[x],1);
107         update(r[y],1,n,dfn[y],1);
108         update(r[f[z][0]],1,n,dfn[x],-2);
109         update(r[f[z][0]],1,n,dfn[y],-2);
110     }
111     dfs(1,0);
112     printf("%lld",ans/4);
113 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/13657458.html