[POJ3417]Network(LCA,树上差分)

Network

Description

Yixght is a manager of the company called SzqNetwork(SN). Now she's very worried because she has just received a bad news which denotes that DxtNetwork(DN), the SN's business rival, intents to attack the network of SN. More unfortunately, the original network of SN is so weak that we can just treat it as a tree. Formally, there are N nodes in SN's network, N-1 bidirectional channels to connect the nodes, and there always exists a route from any node to another. In order to protect the network from the attack, Yixght builds _M_new bidirectional channels between some of the nodes.

As the DN's best hacker, you can exactly destory two channels, one in the original network and the other among the M new channels. Now your higher-up wants to know how many ways you can divide the network of SN into at least two parts.

Input

The first line of the input file contains two integers: N (1 ≤ N ≤ 100 000), M (1 ≤ M ≤ 100 000) — the number of the nodes and the number of the new channels.

Following N-1 lines represent the channels in the original network of SN, each pair (a,b) denote that there is a channel between node a and node b.

Following M lines represent the new channels in the network, each pair (a,b) denote that a new channel between node a and node b is added to the network of SN.

Output

Output a single integer — the number of ways to divide the network into at least two parts.

Sample Input

4 1
1 2
2 3
1 4
3 4

Sample Output

3

树上乱搞题
最开始看了没有思路,第二天再看还是没有思路GG... 自己想到的是统计自己子树中有多少条附加边的一端,但是没细想,感觉情况很多,不好讨论。

事实上我们考虑,每多加一条非树边,在不重的情况下,树上都会多出一个环。

考虑断掉某条树边:
1.它可能没进入环,这时很显然再随意断开一条非树边即可。对答案的贡献为m。
2.它进入了一个环,这时我们再断开环中的那条非树边,成为一种方案。
3.它进入了两个环,树上部分成为两部分,但因为这条边处于两个环中,所以还有两条非树边从上半部分连到下半部分,无法使图成为两块。
4.它进入了多条环时与进入了两个环时情况类似,无法使图成为两块。

那这时思路就很显然了,对于每条非树边,我们把环上的每条树边入环个数++,但是我们发现这样做的话时间复杂度比较高,因为这样做需要遍历路径上的每一条边。这时考虑能不能把边转移到点上来记录,巨佬这时很容易想到运用树上差分的思想来解决(我还是太菜了,完全没想到)。对于每条非树边((u,v)):

a[u]++;a[v]++;a[lca(u,v)]-=2;

最后对于每个点按照上面的方法处理即可。
注意:不需要考虑根节点

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
using namespace std;
int read()
{
	int x=0,w=1;char ch=getchar();
	while(ch>'9'||ch<'0') {if(ch=='-')w=-1;ch=getchar();}
	while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return x*w;
}
const int N=100010;
int n,m,x,y,cnt,ans;
int head[N],deep[N],f[N][20],a[N];
struct node{
	int to,next;
}edge[2*N];
void add(int x,int y)
{
	cnt++;
	edge[cnt].to=y;
	edge[cnt].next=head[x];
	head[x]=cnt;
}
void init()
{
	for(int i=1;i<=19;i++)
		for(int j=1;j<=n;j++)
			f[j][i]=f[f[j][i-1]][i-1];
}
void dfs(int k,int fa)
{
	for(int i=head[k];i;i=edge[i].next)
	{
		int v=edge[i].to;
		if(v==fa) continue;
		deep[v]=deep[k]+1;f[v][0]=k;
		dfs(v,k);
	}
}
int LCA(int x,int y)
{
	if(deep[x]<deep[y]) swap(x,y);
	for(int i=19;i>=0;i--)
	{
		if(deep[f[x][i]]>=deep[y]) x=f[x][i];
	}
	if(x==y) return x;
	for(int i=19;i>=0;i--)
	{
		if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}
void get(int k,int fa)
{
	for(int i=head[k];i;i=edge[i].next)
	{
		int v=edge[i].to;
		if(v==fa) continue;
		get(v,k);
		a[k]+=a[v];
	}
}
int main()
{
	n=read();m=read();
	for(int i=1;i<n;i++)
	{
		x=read();y=read();
		add(x,y);add(y,x);
	}
	dfs(1,0);init();
	for(int i=1;i<=m;i++)
	{
		x=read();y=read();
		int lca=LCA(x,y);
		a[x]++;a[y]++;a[lca]-=2;
	}
	get(1,0);
	for(int i=2;i<=n;i++)
	{
		if(a[i]==0) ans+=m;
		else if(a[i]==1) ans++;
	}
	cout<<ans;
}
原文地址:https://www.cnblogs.com/lsgjcya/p/9247167.html