POJ1741 Tree(点分治)

嘟嘟嘟


没错,这一道最经典的点分治模板题。
题意:求树上两点间距离(leqslant k)的点对个数。


点分治这东西我好早就听说了,然后一两个月前也学了一下,不过只是刷了个模板,没往深处学。
对于这道题,就说说大概的步骤吧。
1.找重心:一遍(dfs)即可。
2.求出每一个子树中的点到重心的距离。并且记录这个点属于哪一棵子树。
3.把上述的点存下来,按距离从小到大排序。
4.统计答案。采用双指针,(i)从头开始,(j)从尾开始。这样每一个(i)对答案的贡献是(j - i - num[point_i])(num[point_i])表示的是(i)所在的子树有多少个。(为了减去属于相同子树的贡献)
5.递归到每一个子树中统计答案。


20.10.18更新:
上了大学打算打acm,然后又学了一遍点分治。
在统计答案时有一个更简单的做法:在每一次从重心开始往每一个子树dfs得到距离序列s后,我们把s用双指针扫一下,求出距离小于等于(k)的点的对数(x)。然后把(x)从答案中减去。最后再加上整个重心的距离序列即可。
这是利用类似容斥的思想,从而避免了记录上述(num[i]),还减少了代码难度。


先给出方法一的代码:

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 4e4 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, k;
struct Edge
{
  int nxt, to, w;
}e[maxn << 1];
int head[maxn], ecnt = -1;
void addEdge(int x, int y, int w)
{
  e[++ecnt] = (Edge){head[x], y, w};
  head[x] = ecnt;
}

bool out[maxn];
int Siz, siz[maxn], Max[maxn];
void dfs1(int now, int _f, int &cg)
{
  siz[now] = 1; Max[now] = -1;
  for(int i = head[now], v; i != -1; i = e[i].nxt)
    {
      if(!out[v = e[i].to] && v != _f)
	{
	  dfs1(v, now, cg);
	  siz[now] += siz[v];
	  Max[now] = max(Max[now], siz[v]);
	}
    }
  Max[now] = max(Max[now], Siz - siz[now]);
  if(!cg || Max[now] < Max[cg]) cg = now;
}

struct Node
{
  int dis, bel;
  bool operator < (const Node& oth)const
  {
    return dis < oth.dis;
  }
}a[maxn];
int cnt = 0;
void dfs2(int now, int _f, int dis, int x, int cg)
{
  siz[now] = 1;
  a[++cnt] = (Node){dis, x};
  for(int i = head[now], v; i != -1; i = e[i].nxt)
    {
      if(!out[v = e[i].to] && v != _f)
	{
	  dfs2(v, now, dis + e[i].w, now == cg ? v : x, cg);
	  siz[now] += siz[v];
	}
    }
}

int num[maxn], ans = 0;
void solve(int now)
{
  int cg = 0; cnt = 0;
  dfs1(now, 0, cg);
  dfs2(cg, 0, 0, 0, cg);
  sort(a + 1, a + cnt + 1);
  for(int i = head[cg]; i != -1; i = e[i].nxt)
    if(!out[e[i].to]) num[e[i].to] = 0;
  for(int i = 1; i <= cnt; ++i) num[a[i].bel]++;
  for(int i = 1, j = cnt; i <= j; ++i)
    {
      num[a[i].bel]--;
      while(a[i].dis + a[j].dis > k && i <= j) num[a[j--].bel]--;
      if(i > j) break;
      ans += j - i - num[a[i].bel];
    }
  out[cg] = 1;
  for(int i = head[cg], v; i != -1; i = e[i].nxt)
    if(!out[v = e[i].to]) Siz = siz[v], solve(v);
}

int main()
{
  Mem(head, -1);
  n = read();
  for(int i = 1; i < n; ++i)
    {
      int x = read(), y = read(), w = read();
      addEdge(x, y, w); addEdge(y, x, w);
    }
  k = read();
  ans = 0; Siz = n;
  solve(1);
  write(ans), enter;
  return 0;
}

然后是方法2的主要代码
int a[maxn], cnt = 0;
In void dfs2(int now, int _f, int d)
{
	if(d > K) return;
	a[++cnt] = d;
	forE(i, now, v) if(v != _f && !out[v]) dfs2(v, now, d + e[i].w);
}

int ans;
In int calc(int* a, int cnt)
{
	int ret = 0;
	for(int i = 1, j = cnt; i < j; ++ i)
	{
		while(i < j && a[i] + a[j] > K) --j;
		ret += j - i;
	}
	return ret;
}

int st[maxn], top = 0;
In void solve(int now)
{
	cg = 0; st[top = 1] = 0;
	dfs1(now, 0, cg);		//求重心cg 
	forE(i, cg, x)
	{
		if(out[x]) continue;
		cnt = 0, dfs2(x, cg, e[i].w);
		sort(a + 1, a + cnt + 1);
		ans -= calc(a, cnt);
		for(int j = 1; j <= cnt; ++j) st[++top] = a[j];
	}
	sort(st + 1, st + top + 1);
	ans += calc(st, top);
	out[cg] = 1;
	forE(i, cg, x) if(!out[x]) Siz = siz[x], solve(x);
}
原文地址:https://www.cnblogs.com/mrclr/p/10032664.html