[Luogu5327][ZJOI2019]语言(树上差分+线段树合并)

首先可以想到对每个点统计出所有经过它的链的并所包含的点数,然后可以直接得到答案。根据实现不同有下面几种方法。
三个log:假如对每个点都存下经过它的链并S[x],那么每新加一条路径进来的时候,相当于在路径上所有点的S中都加入这条路径。树剖之后,相当于对log个区间中的点都加入log个区间。具体实现有树剖后线段树维护虚树、矩形扫描线、线段树+set存区间等多种方法,这里不再多说。
两个log:先树剖,然后对每个点开一棵线段树存储它的S,由于题中没有修改,所以可以树上差分+线段树合并,这样可以将方法一中“需要修改的区间数”的log去掉了。
一个log:发现就是对每个点求所有经过它的路径的端点的斯坦纳树(这里一个点集的斯坦纳树是指原树上最小的点集,满足包含这个点集且连通)。考虑如何暴力求一个点集的斯坦纳树,那显然就是将它们按DFS序排序后,所有点深度之和减去每对相邻点LCA的深度和。为了方便我们将点集中强制加入根,最后求出结果后再减去所有点LCA的深度的两倍。以DFS序为下标建线段树,每个点维护它所代表的DFS区间中,所有在点集中的点(加上根)构成的斯坦纳树的大小。两个区间的合并就是两边的斯坦纳树大小之和,减去左边区间里在点集中的DFS序最大的点与右边区间里在点集中的DFS序最小的点的LCA的深度,于是再维护区间里在点集中的DFS序最大/小的点分别是谁即可。同样使用树上差分+线段树合并,就可以将方法一中“每个修改区间中要加入的区间数”的log去掉了。

(参考https://www.luogu.org/blog/Sooke/solution-p5327

 1 #include<cstdio>
 2 #include<vector>
 3 #include<algorithm>
 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++)
 5 #define For(i,x) for (int i=h[x],k; i; i=nxt[i])
 6 typedef long long ll;
 7 using namespace std;
 8 
 9 const int N=200010,M=6400010,K=18;
10 int n,m,x,y,tim,cnt,nd,d[N],lg[N],rt[N],fa[N],dfn[N],st[N][20];
11 int v[M],ls[M],rs[M],s[M],t[M],c[M],h[N],to[N],nxt[N];
12 ll ans;
13 vector<int>del[N];
14 
15 void add(int u,int v){ to[++cnt]=v; nxt[cnt]=h[u]; h[u]=cnt; }
16 
17 void dfs(int x){
18     d[x]=d[fa[x]]+1; st[++tim][0]=x; dfn[x]=tim;
19     For(i,x) if ((k=to[i])!=fa[x]) fa[k]=x,dfs(k),st[++tim][0]=x;
20 }
21 
22 void init(){
23     rep(j,1,lg[tim]) rep(i,1,tim-(1<<j)+1){
24         int x=st[i][j-1],y=st[i+(1<<(j-1))][j-1];
25         st[i][j]=d[x]<d[y] ? x : y;
26     }
27 }
28 
29 int lca(int x,int y){
30     if (!x || !y) return 0;
31     x=dfn[x]; y=dfn[y];
32     if (x>y) swap(x,y);
33     int t=lg[y-x+1]; x=st[x][t]; y=st[y-(1<<t)+1][t];
34     return d[x]<d[y] ? x : y;
35 }
36 
37 void upd(int x){
38     v[x]=v[ls[x]]+v[rs[x]]-d[lca(t[ls[x]],s[rs[x]])];
39     s[x]=s[ls[x]] ? s[ls[x]] : s[rs[x]];
40     t[x]=t[rs[x]] ? t[rs[x]] : t[ls[x]];
41 }
42 
43 void mdf(int &x,int L,int R,int p,int k){
44     if (!x) x=++nd;
45     if (L==R){ c[x]+=k; v[x]=(c[x]?d[p]:0); s[x]=t[x]=(c[x]?p:0); return; }
46     int mid=(L+R)>>1;
47     if (dfn[p]<=mid) mdf(ls[x],L,mid,p,k); else mdf(rs[x],mid+1,R,p,k);
48     upd(x);
49 }
50 
51 int merge(int x,int y,int L,int R){
52     if (!x || !y) return x|y;
53     if (L==R){ c[x]+=c[y]; v[x]|=v[y]; s[x]|=s[y]; t[x]|=t[y]; return x; }
54     int mid=(L+R)>>1; ls[x]=merge(ls[x],ls[y],L,mid); rs[x]=merge(rs[x],rs[y],mid+1,R);
55     upd(x); return x;
56 }
57 
58 void solve(int x){
59     For(i,x) if ((k=to[i])!=fa[x]) solve(k);
60     int ed=del[x].size()-1;
61     rep(i,0,ed) mdf(rt[x],1,tim,del[x][i],-1);
62     ans+=v[rt[x]]-d[lca(s[rt[x]],t[rt[x]])]; rt[fa[x]]=merge(rt[fa[x]],rt[x],1,tim);
63 }
64 
65 int main(){
66     freopen("a.in","r",stdin);
67     freopen("a.out","w",stdout);
68     scanf("%d%d",&n,&m);
69     rep(i,2,n<<1) lg[i]=lg[i>>1]+1;
70     rep(i,2,n) scanf("%d%d",&x,&y),add(x,y),add(y,x);
71     dfs(1); init();
72     rep(i,1,m){
73         scanf("%d%d",&x,&y); int l=lca(x,y);
74         mdf(rt[x],1,tim,x,1); mdf(rt[x],1,tim,y,1);
75         mdf(rt[y],1,tim,x,1); mdf(rt[y],1,tim,y,1);
76         del[l].push_back(x); del[l].push_back(y);
77         del[fa[l]].push_back(x); del[fa[l]].push_back(y);
78     }
79     solve(1); printf("%lld
",ans/2);
80     return 0;
81 }
原文地址:https://www.cnblogs.com/HocRiser/p/10805499.html