题意
有两棵以(1)为根的(n)节点有编号树。
操作时选出(A)树中的边((p,q)),对于(B)树中的((x,y)),若满足在(A)树中,(x, y) 两个顶点中恰好有且只有一个同时在顶点(p,q)的公共子树中,则删去((x,y))。(这里只考虑初始未被删边的两棵树)
开始时删除给出的(A)树中的一条边,然后删去(B)中满足条件的边,再对于上一组删去的边,在(A)树中操作,一直重复,直到没有可删的边。求每次删去边的编号。
(2 leq n leq 2 imes 10^5)
思路
考虑当前边为A中的(p,q),要删去的边为B中的(x,y),那么在A中,((p,q))一定是((x,y))路径(链)上的点。也就是说一棵树上的边对应另一棵树上的一条链,对方树链中任一条边被删,这条边下一次就会被删掉。容易想到要用树链剖分。
建两棵线段树,将每条边在对方树链区间上打上删除标记(这里没有修改,可以使用标记永久化),每条边对应(log)个dfs序区间,对应(log^2)个线段树区间,所以最多(nlog^2n)个标记。每次钦定一条边后,找到该边对应的线段树节点,沿路即可知道对应要删的边。由于每条边只会被删除一次,当一个线段树节点被访问过后,后来的访问一定是无用的(已经被删了),所以时间复杂度就是(O(n log^2n))
#include <bits/stdc++.h>
using std::vector;
const int N=200005;
int cnt,now,n,x;
vector<int> ans[2];
struct{
int son[N],dfn[N],size[N],tp[N],deep[N],id,vis[N<<2],c[N],f[N];
vector<int> e[N],tag[N<<2];
void dfs(int x){
size[x]=1;
for (auto u:e[x]){
deep[u]=deep[x]+1,dfs(u);
if (size[u]>size[son[x]]) son[x]=u;
size[x]+=size[u];
}
}
void dfs(int x,int top){
dfn[x]=++id,tp[x]=top;
if (son[x]) dfs(son[x],top);
for (auto u:e[x]){
if (u==son[x]) continue;
dfs(u,u);
}
}
void add(int k,int l,int r,int L,int R,int x){
if (l==L && r==R){
tag[k].push_back(x);
return;
}
int mid=(L+R)>>1;
if (r<=mid) add(k<<1,l,r,L,mid,x);
else if (l>mid) add(k<<1|1,l,r,mid+1,R,x);
else{
add(k<<1,l,mid,L,mid,x);
add(k<<1|1,mid+1,r,mid+1,R,x);
}
}
void del(int k,int L,int R,int x){
if (!vis[k]){
for (auto x:tag[k]){
if (!c[x]){
cnt++;
ans[now].push_back(x);
c[x]=1;
}
}
vis[k]=1;
}
if (L==R) return;
int mid=(L+R)>>1;
if (x<=mid) del(k<<1,L,mid,x);
else del(k<<1|1,mid+1,R,x);
}
void solve(int x,int y,int tag){
while (tp[x]!=tp[y]){
if (deep[tp[x]]<deep[tp[y]]) std::swap(x,y);
add(1,dfn[tp[x]],dfn[x],1,n,tag);
x=f[tp[x]];
}
if (deep[x]>deep[y]) std::swap(x,y);
if (x!=y) add(1,dfn[x]+1,dfn[y],1,n,tag);
}
}T[2];
int main(){
scanf("%d",&n);
for (int i=2;i<=n;i++){
scanf("%d",&T[0].f[i]);
T[0].e[T[0].f[i]].push_back(i);
}
for (int i=2;i<=n;i++){
scanf("%d",&T[1].f[i]);
T[1].e[T[1].f[i]].push_back(i);
}
T[0].deep[1]=1,T[1].deep[1]=1;
T[0].dfs(1),T[0].dfs(1,1);
T[1].dfs(1),T[1].dfs(1,1);
for (int i=2;i<=n;i++){
T[1].solve(i,T[0].f[i],i);
T[0].solve(i,T[1].f[i],i);
}
now=1;
scanf("%d",&x);
printf("Blue
%d
",x);
ans[0].push_back(x+1);
T[1].c[x+1]=1;
do{
cnt=0;
ans[now].clear();
for (auto j:ans[now^1]) T[now^1].del(1,1,n,T[now^1].dfn[j]);
if (!cnt) break;
std::sort(ans[now].begin(),ans[now].end());
puts(now?"Red":"Blue");
for (auto j:ans[now]) printf("%d ",j-1);
puts("");
now=now^1;
}while(1);
}