洛谷P3085 [USACO13OPEN]阴和阳Yin and Yang(点分治,树上差分)

洛谷题目传送门

闲话

偶然翻到一道没有题解的淀粉质,想证明一下自己是真的弱

然而ZSYC(字符串组合)早就切了

然后证明成功了,WA到怀疑人生,只好借着ZSY的代码拍,拍了几万组就出来了。。。

思路

是人都能想到的:路径统计,点分治跑不了了。

然而这个统计有些麻烦。。。

首先别看错题,是中间的一个点到两个端点的两条路径都要满足黑白相等。(因为蒟蒻就看错了)

显然,我们每次要统计经过重心的路径,但是这个中点不一定会在重心。于是,必须要更一般化地统计了。

容易想到的是差分。记(d_x)(x)到重心路径上黑白个数的差,可以改一下边权变成(1)(-1),更好实现。

这时候,我们已经可以做到统计黑白个数相等的路径了,一对差分值为(d)(-d)的路径就可以产生贡献。

但是,题目要两条子路径都黑白相等啊!再想想,假设有两个点(x,y,d_x=-d_y),那么是不是只有在(x)的祖先中存在(z)(d_z=-d_y),或者在(y)的祖先中存在(z)(d_x=-d_z),路径(x-y)才能贡献答案?

这也就是说,我们选择的中点(z)(d)值,需要和(d_x)(d_y)相等才行。

于是,我们根据每个点是否有祖先的(d)等于它的(d),分成两类。显然,没有的点只能和有的点算贡献,而有的点可以和两种点都算贡献。开两个桶区别统计一下就好了。

具体实现的话,蒟蒻直接依次枚举重心的子树,先统计一个子树的答案,再把子树的每个点丢进桶里。这样就不用容斥去除不合法情况了。但是要特判一个端点在重心上的路径的贡献。

如何快速判断当前点的属于哪一类?开一个数组(b)统计祖先中每个(d)值出现的次数就可以了((d)可能是负数,要把所有下标加(n))。注意回溯的时候要减掉。

注意答案要开long long,注意清空数组。

代码算是很短的了。

#include<cstdio>
#include<cstring>
#define RG register
#define R RG int
#define G c=getchar()
const int N=1e5+1,M=N<<1;
int n,rt,mx,mn,he[N],ne[M],to[M],c[M],s[N],m[N],b[M],f[M],g[M];
bool vis[N];
long long ans;
inline void min(R&x,R y){if(x>y)x=y;}
inline void max(R&x,R y){if(x<y)x=y;}
inline int in(){
	RG char G;
	while(c<'-')G;
	R x=c&15;G;
	while(c>'-')x=x*10+(c&15),G;
	return x;
}
void dfs(R x){//求重心
	vis[x]=1;s[x]=1;m[x]=0;
	for(R y,i=he[x];i;i=ne[i])
		if(!vis[y=to[i]])
			dfs(y),s[x]+=s[y],max(m[x],s[y]);
	max(m[x],n-s[x]);
	if(m[rt]>m[x])rt=x;
	vis[x]=0;
}
void upd(R x,R d){//统计答案,f没有g有
	min(mn,d);max(mx,d);
	ans+=g[M-d];//分类贡献
	if(b[d])ans+=f[M-d];
	if(d==N)ans+=b[d]>1;//特判
	vis[x]=1;++b[d];
	for(R i=he[x];i;i=ne[i])
		if(!vis[to[i]])upd(to[i],d+c[i]);
	vis[x]=0;--b[d];//回溯清零
}
void mdf(R x,R d){//修改桶
	++(b[d]?g[d]:f[d]);
	vis[x]=1;++b[d];
	for(R i=he[x];i;i=ne[i])
		if(!vis[to[i]])mdf(to[i],d+c[i]);
	vis[x]=0;--b[d];
}
void div(R x){//分治
	rt=0;dfs(x);x=rt;
	vis[x]=1;b[mn=mx=N]=1;//此处注意初始化
	R t=n,y,i;
	for(i=he[x];i;i=ne[i])
		if(!vis[y=to[i]])
			upd(y,N+c[i]),mdf(y,N+c[i]);
	memset(f+mn,0,(mx-mn+1)<<2);//注意清空
	memset(g+mn,0,(mx-mn+1)<<2);
	for(i=he[x];i;i=ne[i])
		if(!vis[y=to[i]])
			n=s[x]>s[y]?s[y]:t-s[x],div(y);
}
int main(){
	m[0]=1e9;n=in();
	for(R a,b,p=0,i=1;i<n;++i){
		a=in();b=in();
		ne[++p]=he[a];to[he[a]=p]=b;
		ne[++p]=he[b];to[he[b]=p]=a;
		c[p]=c[p-1]=in()?1:-1;//处理边权
	}
	div(1);
	printf("%lld
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/flashhu/p/9391864.html