caioj 2064 & POJ 1741 & CH 0x40数据结构进阶(0x45 点分治)例题1:树

传送门
这真是一道毒瘤入门题 ,一连做了两天,poj的数据竟然卡memsetmemset.QwQ
心力憔悴~~~~~~

思路:

很明显,对于u,v之间的合法路径一定满足以下条件之一:

  1. 经过根节点.(u或v为根节点也算)
  2. 不经过根节点.

废话!
我们可以只处理第1种情况,对于第2种情况直接分治.

有一个细节需要注意:就是路径上边不能重复走,或者说点不能在根节点的同一颗子树中.
b[x]b[x]表示x隶属于根节点的哪个子树,dis[x]xdis[x]表示x与根节点的距离
那么u,v之间的路径合法,必须满足:

  1. b[u]b[v]b[u] e b[v]
  2. dis[u]+dis[v]kdis[u]+dis[v]le k

我们把子树内所有点塞到一个数组cc中,并按disdis排序.
我们用两个指针l,rl,r扫描数组.
可以发现为了满足条件2,l,r.l在往右移时,r只能往左移.
为了满足条件1.我们定义一个数组cnt,cnt[x]cnt,cnt[x]表示在l+1l+1~rr中隶属x的节点数.
则本次操作对答案的贡献为rlcnt[c[l]]r-l-cnt[c[l]]

复杂度计算:

如果子树有nnnn个点,那么一次处理的复杂度为O(nnlog2nn)O(nn*log_2 nn).
由于递归,所以复杂度为O(Tnlog2n)O(T*n*log_2n)(T为递归层数)

关于层数的优化:

为了尽可能减小递归层数,我们只能在每次处理时找到子树中的重心.

树的重心是啥?
树的重心为满足最大子树节点数最小的点.(有点拗口)

这样我们就可以在O(nlog2n)O(nlog^2 n)的复杂度下求解了.

温馨提示: POJ数据会卡memset,快读快写让代码跑得飞快.

细节有点多,仔细看代码吧.

#include<cstdio>
#include<cstring>
#include<algorithm>
#define g g()
#define mes(x,y) memset(x,y,sizeof(x));
using namespace std;
const int N=10010;
struct edge{int y,next,d;}a[N<<1];int len,last[N];
void ins(int x,int y,int d){a[++len].y=y;a[len].next=last[x];a[len].d=d;last[x]=len;}
int ans,n,m;//m为题目的k
bool v[N],w[N];//v表示是否被遍历过(在子树中),w表示是否对答案产生贡献
int dis[N],size[N],b[N],c[N],tot,cnt[N],l,r;
//dis表示深度,size表示子树大小,b[i]表示i在那棵子树,c装子树内节点的深度,cnt[i]表示c[l+1],c[l+2]~~c[r]中属于i的子树的节点数
int temp,pos;//记录重心信息:最大子树的节点数,标号

void dfs_find(int S,int x)//求重心——S为(以pos为根的)总结点数(详见calc)
{
	v[x]=size[x]=1;
	int maxx=0;//最大子树大小 
	for(int k=last[x];k;k=a[k].next)
	{
		int y=a[k].y;
		if(v[y]||w[y])continue;
		dfs_find(S,y);
		size[x]+=size[y];
		maxx=max(maxx,size[y]);
	}
	maxx=max(maxx,S-size[x]);
	if(maxx<temp)temp=maxx,pos=x;
	v[x]=0;
}

void dfs(int x)//重求距离
{
	v[x]=1;
	for(int k=last[x];k;k=a[k].next)
	{
		int y=a[k].y;
		if(v[y]||w[y])continue;
		dis[y]=dis[x]+a[k].d;
		cnt[b[c[++tot]=y]=b[x]]++;//注意理解
		dfs(y);
	}
	v[x]=0;
}

bool cmp(int i,int j){return dis[i]<dis[j];}

void calc(int S,int x)
{
	temp=S;dfs_find(S,x);
	dis[pos]=0;
	w[c[tot=1]=b[pos]=pos]=1;cnt[pos]=1;
	for(int k=last[pos];k;k=a[k].next)//不要把x和pos弄混
	{
		int y=a[k].y;
		if(w[y])continue;
		dis[y]=a[k].d;
		cnt[c[++tot]=b[y]=y]=1;
		dfs(y);
	}
	sort(c+1,c+tot+1,cmp);
	l=1;r=tot;
	cnt[b[c[1]]]--;
	while(l<r)
	{
		while(dis[c[l]]+dis[c[r]]>m)cnt[b[c[r--]]]--;
		ans+=r-l-cnt[b[c[l]]];
		cnt[b[c[++l]]]--;
	}
	for(int k=last[pos];k;k=a[k].next)
	{
		int y=a[k].y;
		if(w[y])continue;
		calc(size[y],y);
	}
}

//快读、快写大法好
const int ss=1<<20;
char buf[ss],*p1=buf,*p2=buf;
inline char g{return p1==p2&&(p2=(p1=buf)+fread(buf,1,ss,stdin),p1==p2)?EOF:*p1++;}
void qr(int &x)
{
	char c=g;bool v=(x=0);
	while(!( ('0'<=c&&c<='9') || c=='-' ))c=g;
	if(c=='-')v=1,c=g;
	while('0'<=c&&c<='9')x=x*10+c-'0',c=g;
	if(v)x=-x;
}
void write(int x)
{
	if(x/10)write(x/10);
	putchar(x%10+'0');
}
void pri(int x){write(x);puts("");}

int main()
{
	while(qr(n),qr(m),n&&m)
	{
		mes(last,0);len=0;//注意初始化
		for(int i=1;i<n;i++)
		{
			int x,y,z;qr(x);qr(y);qr(z);
			ins(x,y,z);ins(y,x,z);
		}
		ans=0;mes(w,0);//注意初始化
		calc(n,1);
		pri(ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/zsyzlzy/p/12373895.html