P4220 [WC2018]通道 虚树+边分治

题意:

戳这里

分析:

  • 前置芝士: 边分治 虚树 (O(nlog)) 预处理 (O(1) lca)

题意让我们求三颗树上 (sum dis(x,y)) 最大是多少,这种树上距离问题大部分可以通过树分治解决

我们首先考虑最常见的点分治,发现这个题似乎不能用点分治解决,因为点分治的核心思想是通过容斥来在 (log) 次递归中计算出想要的答案,但这个题要求的 (max) 似乎不支持容斥的计算

那么我们考虑边分治,常见步骤大概就是:

  1. 多叉树转二叉树
  2. 找到中心的边
  3. 统计经过该边的所有答案
  4. 递归分治左右两个连通块

tip: 边分治时一定要,转化成二叉树,不然菊花图直接暴毙 (O(n^2))

所以有了思路之后,我们对于第一棵树进行边分治,每次找到中心边 (u o v),之后我们可以将式子转化

(ans=dis_1(x,u)+dis_1(v,y)+dis_2(x,y)+dis_3(x,y)+w(u,v))

其中 (w(u,v)) 是常量我们暂且不考虑,接着推柿子

(ans=dis_1(x,u)+dis_1(v,y)+dep_2(x)+dep_2(y)+dep_3(x)+dep_3(y)-2 imes(dep_2(lca(x,y))+dep_3(lca(x,y))))

(ans=(dis_1(x,u)+dep_2(x)+dep_3(x))+(dis_1(y,v)+dep_2(y)+dep_3(y))-2 imes(dep_2(lca(x,y))+dep_3(lca(x,y))))

我们令 (f(x)=dis_1(x,u/v)+dep_2(x)+dep_3(x)) 那么

(ans=f(x)+f(y)--2 imes(dep_2(lca(x,y))+dep_3(lca(x,y))))

由于边分治我们建出的是一颗二叉树,所以 (x)(y) 来自两个不同的连通块,我们可以对两个连通块分别进行染色,由于我们可以预处理并快速的求出 (f) 数组,那么我们只需要考虑如何快速计算 (dep_2(lca))(dep_3(lca))我们求**点集中的点的 (lca) ** 这个操作很 虚树 ,那么我们考虑建出虚树,然后枚举 (lca) ,那么 (dep_2(lca)) 也就成了常量,我们的任务变成了:

在虚树上找一对点 ((x,y)) 满足他们的 (lca)(z) ,且他们的颜色不同,同时尽可能最大化 (ans = f(x)+f(y)-2 imes dep_3(lca(x,y))+k) 其中 (k) 是一个常量包含 (w(u,v),2 imes dep_2(lca))

由于这个式子长得很树上两点间距离同时我们要尽可能最大化这个式子,所以我们要求的东西其实就变成了一个特殊的树的直径

然后这里需要一个结论来帮助我们求这个直径:

在一颗边权均为正的树上,存在两个点集

对于一个点集 (s) 它的直径两端是 (a,b)

对于另一个点集 (t) 它的直径两端是 (c,d)

那么分别以 (s)(t) 点集中的点为直径的两端,这条直径一定是 ((a,c),(a,d),(b,c),(b,d)) 中的一条

放到这个题上,我们在第三棵树上给每一个点连一个虚点边权为 (dis_1(x,u/v)+dep_2(x)) 然后我们求出树的直径就是答案,由于边权均为正,所以可以利用上面的结论

具体做法就是,我们开一个结构体,记 (dp(u,0)) 表示 (u) 子树中白点形成的点集 的直径两端, (dp(u,1)) 表示 (u) 的子树中黑点形成的点集 的直径两端,然后每次进行点集的合并,求出跨点集的直径两端

复杂度 (O(nlog ^2)) 边分治一个 (log) 虚树也有一个 (log) 至于求直径的 (DP) 可以和虚树一起做,所以只有两只 (log)

tip:

  1. (st) 表的预处理和查询的右端点不一样/kk
  2. 多叉树转二叉树后,新建的节点也有 (siz) 找中心边是也要考虑这些点的大小

代码:

毒瘤出题人,这个题码量和 猪国杀有的一拼了

#include<bits/stdc++.h>
#define pii pair<long long,long long>
#define mk(x,y) make_pair(x,y)
#define lc rt<<1
#define rc rt<<1|1
#define pb push_back
#define fir first
#define sec second
#define int long long
using namespace std;

namespace zzc
{
	inline long long read()
	{
		long long x=0,f=1;char ch=getchar();
		while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
		while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
		return x*f;
	}
    
    const long long maxn = 5e5+5;
    long long n,ans,top,tmp;
    long long lg[maxn],dis[maxn],st[maxn],a[maxn],col[maxn];
    bool vis[maxn];
	struct edge
    {
        long long to,nxt,val;
    };
    
    struct tree23
    {
        long long cnt,idx;
		long long head[maxn],st[maxn][20],dfn[maxn],len[maxn],dep[maxn];
		edge e[maxn];
		
		long long Min(long long x,long long y)
		{
			return dep[x]<dep[y]?x:y;
		}
		
		void add(long long u,long long v,long long w)
		{
			e[++cnt].to=v;
			e[cnt].nxt=head[u];
			e[cnt].val=w;
			head[u]=cnt;
		}
		
		void dfs(long long u,long long ff)
		{
			st[++idx][0]=u;dfn[u]=idx;dep[u]=dep[ff]+1;
			for(long long i=head[u];i;i=e[i].nxt)
			{
				long long v=e[i].to;
				if(v==ff) continue;
				len[v]=len[u]+e[i].val;
				dfs(v,u);
				st[++idx][0]=u;
			}
		}
		
		void init()
		{
			long long a,b,c;
			for(long long i=1;i<n;i++)
			{
				a=read();b=read();c=read();
				add(a,b,c);add(b,a,c);
			}
			dfs(1,0);
			for(long long j=1;j<=19;j++)
			{
				for(long long i=1;i+(1<<j)-1<=idx;i++)
				{
					st[i][j]=Min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
				}
			}
		}
		
		long long lca(long long x,long long y)
		{
			if(dfn[x]>dfn[y]) swap(x,y);
			long long k=lg[dfn[y]-dfn[x]+1];
			return Min(st[dfn[x]][k],st[dfn[y]-(1<<k)+1][k]);
		}
		
		long long get_dis(long long x,long long y)
		{
			return len[x]+len[y]-2*len[lca(x,y)];
        }
    }t2,t3;
    
	long long calc(long long x,long long y)
    {
        if(!x||!y) return 0;
        return dis[x]+dis[y]+t2.len[x]+t2.len[y]+t3.get_dis(x,y);
    }
    
    struct node
    {
        long long dis,x,y;
        node(){x=0;y=0;dis=0;}
        node(const long long &_x,const long long &_y){x=_x;y=_y;dis=calc(_x,_y);}
        node(const long long &_x,const long long &_y,const long long &_dis){x=_x;y=_y;dis=_dis;}
        
    }dp[maxn][2];
	bool operator <(node a,node b){return a.dis<b.dis;}
    node operator + (node a,node b)
    {
         if(!b.x) return a;
         if(!a.x) return b;
         node res=max(a,b);
         res=max(res,max(max(max(node(a.x,b.y),node(a.y,b.x)),node(a.x,b.x)),node(a.y,b.y)));
         return res;
    }

    bool cmp(long long x,long long y)
    {
        return t2.dfn[x]<t2.dfn[y];
    }

    void pushup(long long x,long long y)
    {
        ans=max(ans,calc(dp[x][0].x,dp[y][1].x)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][0].x,dp[y][1].y)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][0].y,dp[y][1].x)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][0].y,dp[y][1].y)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][1].x,dp[y][0].x)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][1].x,dp[y][0].y)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][1].y,dp[y][0].x)-t2.len[x]*2);
        ans=max(ans,calc(dp[x][1].y,dp[y][0].y)-t2.len[x]*2);
        dp[x][0]=dp[x][0]+dp[y][0];dp[x][1]=dp[x][1]+dp[y][1];
    }

    void DP()
    {
        sort(a+1,a+tmp+1,cmp);
        for(int i=1;i<=tmp;i++) vis[a[i]]=true;
        int old=tmp;top=0;
        for(int i=1;i<=tmp;i++)
        {
            dp[a[i]][col[a[i]]]=node(a[i],a[i],0);dp[a[i]][col[a[i]]^1]=node();
            int lca=t2.lca(a[i],st[top]);
            if(!vis[lca])
            {
                a[++old]=lca;
                vis[lca]=true;
                dp[lca][0]=dp[lca][1]=node();
            }
            while(t2.dfn[lca]<t2.dfn[st[top]])
            {
            	if(t2.dfn[lca]<=t2.dfn[st[top-1]])
            	{
            		pushup(st[top-1],st[top]);
            		top--;
				}
				else
				{
					pushup(lca,st[top]);
					st[top]=lca;
					break;
				}
			}
            st[++top]=a[i];
        }
        while(top>1) pushup(st[top-1],st[top]),top--;
        for(int i=1;i<=old;i++) vis[a[i]]=false;
    }
	
	struct tree1
	{
        int cnt=1,head[maxn],tot,rt,siz[maxn];
        bool vis[maxn<<2];
        edge e[maxn<<2];
        vector<pii >g[maxn];
        void add(int u,int v,long long w)
        {
            e[++cnt].to=v;
            e[cnt].nxt=head[u];
            head[u]=cnt;
            e[cnt].val=w;
        }
        void rebuild(int u,int ff)
		{
			int lst=u;
			for(auto v:g[u])
			{
				if(v.fir==ff) continue;
				add(v.fir,lst,v.sec);add(lst,v.fir,v.sec);
				add(lst,++tot,0);add(tot,lst,0);lst=tot;
				rebuild(v.fir,u);
			}
		}
        void get_rt(int u,int ff)
		{
			siz[u]=1;
			for(int i=head[u];i!=-1;i=e[i].nxt)
			{
				int v=e[i].to;
				if(v==ff||vis[i]) continue;
				get_rt(v,u);
				siz[u]+=siz[v];
				if(rt==-1||max(siz[v],tot-siz[v])<max(siz[e[rt].to],tot-siz[e[rt].to])) rt=i;
			}
		}
        void get_dis(int u,int ff,int op)
		{
			col[u]=op;
			if(u<=n) a[++tmp]=u;
			for(int i=head[u];i!=-1;i=e[i].nxt)
			{
				int v=e[i].to;
				if(v==ff||vis[i]) continue;
				dis[v]=dis[u]+e[i].val;
				get_dis(v,u,op);
			}
		}
		void solve(int i)
		{
			if(i==-1) return ;
			vis[i]=vis[i^1]=true;
			int u=e[i].to,v=e[i^1].to;
			tmp=0;dis[u]=0;dis[v]=e[i].val;
			get_dis(u,0,0);get_dis(v,0,1);
			DP();
			tot=siz[u];rt=-1;
			get_rt(u,0);
			solve(rt);
			tot=siz[v];rt=-1;
			get_rt(v,0);
			solve(rt);
		}
        void work()
		{
			memset(head,-1,sizeof(head));
			cnt=1;rt=-1;tot=n;
			rebuild(1,0);
			get_rt(1,0);
			solve(rt);
		}
    }t1;
	
    void work()
	{
		long long a,b,c;
	    lg[0]=-1;for(long long i=1;i<=500000;i++) lg[i]=lg[i>>1]+1;
		n=read();
		for(int i=1;i<n;i++)
		{
			a=read();b=read();c=read();
			t1.g[a].pb(mk(b,c));t1.g[b].pb(mk(a,c));
		}
		t2.init();
		t3.init();
		t1.work();
		printf("%lld
",ans);
	}

}

signed main()
{
	zzc::work();
	return 0;
}

原文地址:https://www.cnblogs.com/youth518/p/14241121.html