[LOJ3280] JOISC2020 首都城市

问题描述

在 JOI 的国度有 N 个小镇,从 1 到 N 编号,并由 N−1 条双向道路连接。第 i 条道路连接了 Ai 和 Bi 这两个编号的小镇。

这个国家的国王现将整个国家分为 K 个城市,从 1 到 K 编号,每个城市都有附属的小镇,其中编号为 j 的小镇属于编号为 Cj 的城市。每个城市至少有一个附属小镇。

国王还要选定一个首都。首都的条件是该城市的任意小镇都只能通过属于该城市的小镇到达。

但是现在可能不存在这样的选址,所以国王还需要将一些城市进行合并。对于合并城市 x 和 y ,指的是将所有属于 y 的小镇划归给 x 城。

你需要求出最少的合并次数。

输入格式

输入第一行两个整数 N,K,为小镇和城市的数量。

接下来的 N−1 行,每行两个整数 Ai,Bi,描述了 N−1 条道路。

再接下来的 N 行,每行一个整数 Cj,表示编号为 j 的小镇属于编号为 Cj 的城市。

输出格式

输出一行一个整数为最少的合并次数。

样例输入

6 3
2 1
3 5
6 2
3 4
2 3
1
3
1
2
3
2

样例输出

1

解析

我们先单独考虑某一种颜色。如果在首都中要包括这种颜色的城镇,那么两个该种颜色的点之间的所有点都必须合并到一起。将其转化为图上的关系,设颜色 i 向颜色 j 连边表示将颜色 i 的城镇合并到颜色 j 中。因此,我们不妨将每一种颜色的虚树建出来,然后就可以方便地在原树上寻找相同颜色之间的点了。

然而,这样连边是 (n^2) 的。考虑到每次连边的对象都是树上的一段路径,我们可以用树链剖分优化连边。连边的方式与线段树优化连边类似。为了实现从颜色向颜色连边,我们让每一种颜色向线段树的叶子节点(实际是原树上的节点)连边。对于每一棵虚树,我们用树链剖分找到虚树上两点之间对应的路径,然后从路径对应的区间向虚树的颜色连边即可。

这样,我们得到了一张有向图。对于构成环的合并关系,我们可以将其直接合并,相当于缩点,将原图转化为一个DAG。考虑入度为0的点,这样的点一定是不需要额外合并的,可以作为首都。因此,答案就是DAG上最小的入度为0的点。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#define N 200002
using namespace std;
vector<int> v[N],c[N],a;
int head[N*8],ver[N*16],nxt[N*16],l;
int n,k,i,j,col[N],dep[N],son[N],size[N],fa[N],top[N],pos[N],in[N],out[N],cnt,tot;
int dfn[N*8],low[N*8],s[N*8],T,tim,sccno[N*8],num[N*8],deg[N*8];
int read()
{
	char c=getchar();
	int w=0;
	while(c<'0'||c>'9') c=getchar();
	while(c<='9'&&c>='0'){
		w=w*10+c-'0';
		c=getchar();
	}
	return w;
}
void insert(int x,int y)
{
	l++;
	ver[l]=y;
	nxt[l]=head[x];
	head[x]=l;
}
void dfs1(int x,int pre)
{
	fa[x]=pre;
	dep[x]=dep[pre]+1;
	size[x]=1;
	for(int i=head[x];i;i=nxt[i]){
		int y=ver[i];
		if(y!=pre){
			dfs1(y,x);
			size[x]+=size[y];
			if(size[y]>size[son[x]]) son[x]=y;
		}
	}
}
void dfs2(int x,int t)
{
	top[x]=t;
	in[x]=++cnt;
	pos[cnt]=x;
	if(son[x]) dfs2(son[x],t);
	for(int i=head[x];i;i=nxt[i]){
		int y=ver[i];
		if(y!=fa[x]&&y!=son[x]) dfs2(y,y);
	}
}
void build(int p,int l,int r)
{
	tot++;
    if(l==r){
        insert(p+k,col[pos[l]]);
        return;
    }
    int mid=(l+r)/2;
    insert(p+k,p*2+k);insert(p+k,p*2+1+k);
    build(p*2,l,mid);build(p*2+1,mid+1,r);
}
void link(int p,int l,int r,int ql,int qr,int c)
{
    if(ql<=l&&r<=qr){
        insert(c,p+k);
        return;
    }
    int mid=(l+r)/2;
    if(ql<=mid) link(p*2,l,mid,ql,qr,c);
    if(qr>mid) link(p*2+1,mid+1,r,ql,qr,c);
}
int LCA(int u,int v)
{
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		u=fa[top[u]];
	}
	if(dep[u]>dep[v]) swap(u,v);
	return u;
}
void split(int c,int u,int v)
{
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        link(1,1,n,in[top[u]],in[u],c);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    link(1,1,n,in[u],in[v],c);
}
int my_comp(const int &x,const int &y)
{
	return in[x]<in[y];
}
void dfs(int c,int x)
{
    for(int i=0;i<v[x].size();i++){
        split(c,x,v[x][i]);
        dfs(c,v[x][i]);
    }
}
void Tarjan(int x)
{
	dfn[x]=low[x]=++tim;
	s[++T]=x;
	for(int i=head[x];i;i=nxt[i]){
		int y=ver[i];
		if(!dfn[y]){
			Tarjan(y);
			low[x]=min(low[x],low[y]);
		}
		else if(!sccno[y]) low[x]=min(low[x],dfn[y]);
	}
	if(dfn[x]==low[x]){
		cnt++;
		while(1){
			int y=s[T--];
			sccno[y]=cnt;
			if(y<=k) num[cnt]++;
			if(y==x) break;
		}
	}
}
int main()
{
	n=read();k=read();
	tot=k;
	for(i=1;i<n;i++){
		int u=read(),v=read();
		insert(u,v);
		insert(v,u);
	}
	for(i=1;i<=n;i++){
		col[i]=read();
		c[col[i]].push_back(i);
	}
	dfs1(1,0);dfs2(1,1);
    memset(head,0,sizeof(head));l=0;
    build(1,1,n);
	for(i=1;i<=k;i++){
        if(!c[i].size()) continue;
        a.clear();
		sort(c[i].begin(),c[i].end(),my_comp);
		s[1]=T=1;
        a.push_back(1);
		for(j=0;j<c[i].size();j++){
            if(c[i][j]==1) continue;
			int x=c[i][j],lca=LCA(x,s[T]);
			if(lca==s[T]){
				s[++T]=x;
                a.push_back(x);
				continue;
			}
			while(T>1&&in[s[T-1]]>=in[lca]) v[s[T-1]].push_back(s[T]),T--;
			if(lca!=s[T]) v[lca].push_back(s[T]),s[T]=lca,a.push_back(lca);
			s[++T]=x;
            a.push_back(x);
		}
        while(T>1) v[s[T-1]].push_back(s[T]),T--;
        sort(a.begin(),a.end(),my_comp);
        if(v[1].size()==1&&col[1]!=i) dfs(i,a[1]);
        else dfs(i,1);
        for(j=0;j<a.size();j++) v[a[j]].clear();
	}
    T=cnt=0;
    for(i=1;i<=tot;i++){
        if(!dfn[i]) Tarjan(i);
    }
    for(i=1;i<=tot;i++){
        for(j=head[i];j;j=nxt[j]){
            if(sccno[i]!=sccno[ver[j]]) deg[sccno[i]]++;
        }
    }
    int ans=1<<30;
    for(i=1;i<=cnt;i++){
        if(deg[i]==0&&num[i]!=0) ans=min(ans,num[i]-1);
    }
    printf("%d
",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/LSlzf/p/12997390.html