LOJ #6119. 「2017 山东二轮集训 Day7」国王

Description

在某个神奇的大陆上,有一个国家,这片大陆的所有城市间的道路网可以看做是一棵树,每个城市要么是工业城市,要么是农业城市,这个国家的人认为一条路径是 exciting 的,当且仅当这条路径上的工业城市和农业城市数目相等。现在国王想把城市分给他的两个儿子,大儿子想知道,他选择一段标号连续的城市作为自己的领地,并把剩下的给弟弟,能够满足两端都是自己城市的 exciting 路径比两端都是弟弟的城市的 exciting 路径数目多的方案数。

Solution

我们分析一下:
要求的是满足 两端点全在 ([l,r]) 之间的路径-两端点全在 ([l,r]) 外的路径>0 的方案数 .....①
我们两个端点都在某个位置会不好算,如果只有一个端点在区间内就比较好算

求出至少一个有端点在 ([l,r]) 的方案数=两个端点都在 ([l,r]) 的方案数+有一个在内另一个在外的方案数,
我们发现如果另一个端点在 ([l,r]) 内,另一端点在外的情况在①式中相减之后抵消了,所以根本不需要考虑这种情况

所以只需要求出 (w[x]) 表示以 (x) 为其中一个端点的合法的路径的方案数,(sum_{i=l}^{r}w[i]) 就是至少有一个端点在 ([l,r]) 内的方案数,设为 (cnt)
设总合法的路径为 (tot)
我们维护两个单调指针,当 (cnt>tot-cnt) 时,移动指针 (l) 就行了,

至于 (w[x]) 的求法就是一个基本的点分治了,值得注意的是合并子树的统计方法仿佛在这题不能用,需要用容斥

#include<bits/stdc++.h>
using namespace std;
const int N=100005;
int n,a[N],son[N]={N},sz[N],head[N],nxt[N<<1],to[N<<1],num=0;
int sum,rt=0;bool vis[N];
inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
inline void getroot(int x,int last){
	sz[x]=1;son[x]=0;
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];if(u==last || vis[u])continue;
		getroot(u,x);
		sz[x]+=sz[u];son[x]=max(son[x],sz[u]);
	}
	son[x]=max(son[x],sum-sz[x]);
	if(son[x]<son[rt])rt=x;
}
int st[N],top=0,id[N],dis[N],w[N],t[N*2];
inline void dfs(int x,int last,int val){
	st[++top]=x;dis[x]=val;t[val+N]++;
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];if(u==last || vis[u])continue;
		dfs(u,x,val+a[u]);
	}
}
inline void calc(int r,int x,int op,int sta){
	top=0;dfs(x,x,sta+a[x]);
	for(int i=1;i<=top;i++)
		w[st[i]]+=op*t[N-dis[st[i]]+a[r]];
	for(int i=1;i<=top;i++)t[N+dis[st[i]]]--;
}
inline void solve(int x){
	vis[x]=1;calc(x,x,1,0);
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];if(vis[u])continue;
		calc(x,u,-1,a[x]);
		rt=0;sum=sz[u];getroot(u,x);solve(rt);
	}
}
int main(){
  freopen("pp.in","r",stdin);
  freopen("pp.out","w",stdout);
  int x,y;
  scanf("%d",&n); 
  for(int i=1;i<=n;i++)scanf("%d",&a[i]),a[i]?a[i]=1:a[i]=-1;
  for(int i=1;i<n;i++){
	  scanf("%d%d",&x,&y);
	  link(x,y);link(y,x);
  }
  rt=0;sum=n;getroot(1,1);
  solve(rt);
  long long ans=0,cnt=0,tot=0;
  for(int i=1;i<=n;i++)tot+=w[i];
  for(int i=1,l=1;i<=n;i++){
	  cnt+=w[i];
	  while(l<i && cnt>tot-cnt)cnt-=w[l++];
	  ans+=l-1;
  }
  cout<<ans<<endl;
  return 0;
}

原文地址:https://www.cnblogs.com/Yuzao/p/8505053.html