POJ 1741 Tree【树分治】

第一次接触树分治,看了论文又照挑战上抄的代码,也就理解到这个层次了。。
以后做题中再慢慢体会学习。


题目链接:

http://poj.org/problem?id=1741

题意:

给定树和树边的权重,求有多少对顶点之间的边的权重之和小于等于K。

分析:

树分治。
直接枚举不可,我们将树划分成若干子树。
那么两个顶点有两种情况:

  1. u,v属于同一子树的顶点对
  2. u,v属于不同子树的顶点对

第一种情况,对子树递归即可求得。
第二种情况,从u到v的路径必然经过了顶点s,只要先求出每个顶点到s的距离再做统计即可。(注意在第二种情况中减去第一种重复计算的部分)
当树退化成链的形式时,递归的深度则退化为O(n),所以选择每次都找到树的重心作为分隔顶点。重心就是删掉此结点后得到的最大子树的顶点数最少的顶点,删除重心后得到的所有子树顶点数必然不超过n/2
查找重心的时候,假设根为v,先在v的子树中找到一个顶点,使删除该顶点后的最大子树的顶点数最少,然后考虑删除v的情况,获得最大子树的顶点数。
两者选择最小的一个,此时选中的顶点即为重心。
递归的每一层都做了排序O(nlogn),递归深度O(logn),总体时间复杂度O(nlog2n)

代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define sa(a) scanf("%d", &a)
#define mem(a,b) memset(a, b, sizeof(a))
using namespace std;
const int maxn = 1e4 + 5, oo = 0x3f3f3f3f;
int cnt[maxn], vis[maxn];
int K, ans;
struct EDGE{int to; int length;int next;};
int head[maxn];
EDGE edge[maxn * 2];
typedef pair<int, int>pii;
int tot = 0;
void addedge(int u, int v, int l)
{
    edge[tot].to = v;
    edge[tot].length = l;
    edge[tot].next = head[u];
    head[u] = tot++;
    edge[tot].to = u;
    edge[tot].length = l;
    edge[tot].next = head[v];
    head[v] = tot++;
}
int countsubtree(int v, int p)
{
    int ans = 1;
    for(int i = head[v]; i != -1; i = edge[i].next){
        int w = edge[i].to;
        if(w == p||vis[w]) continue;
        ans += countsubtree(w, v);
    }
    return cnt[v] = ans;
}
pii findc(int v, int p, int t)
{
    pii res = pii(oo, 0);
    int s = 1, m = 1;
   for(int i = head[v]; i != -1; i = edge[i].next){
        int w = edge[i].to;
        if(w == p || vis[w]) continue;
        res = min(res, findc(w, v, t));
        m = max(m, cnt[w]);
        s += cnt[w];
    }
    m = max(m, t - s);
    res = min(res, pii(m, v));
    return res;
}
void findpath(int v, int p, int d, vector<int>&ds)
{
    ds.push_back(d);
    for(int i = head[v]; i != -1; i = edge[i].next){
        int w = edge[i].to;
        if(w == p || vis[w]) continue;
        findpath(w, v, d +edge[i].length, ds);
    }
}
int count_pair(vector<int>&ds)
{
    int res = 0;
    sort(ds.begin(), ds.end());
    int j = ds.size() - 1;
    int i = 0;
    while(i < j){
        while(j > i &&ds[i] + ds[j] > K) j--;
        res += j - i;
        i++;
    }
    return res;
/*
    int j = ds.size();
    for(int i = 0; i < ds.size(); i++){
        while(j > 0 && ds[i] + ds[j - 1] > K) j--;
        res += j - (j > i?1:0);
    }
    return res / 2;*/
}
void solve(int v)
{
    vector<int>ds;
     countsubtree(v, -1);
     int s = findc(v, -1, cnt[v]).second;
     vis[s] =true;
     //(1)
     for(int i = head[s]; i != -1; i = edge[i].next){
        int w = edge[i].to;
         if(vis[w]) continue;
         solve(w);
     }
     //(2)
     ds.push_back(0);
    for(int i = head[s]; i != -1; i = edge[i].next){
        int w = edge[i].to;
        if(vis[w]) continue;
        vector<int>ts;
        findpath(w, s, edge[i].length, ts);
        ans -=  count_pair(ts);
        ds.insert(ds.end(), ts.begin(), ts.end());
     }
      vis[s] = false;
     ans += count_pair(ds);
}
void init()
{
    tot = 0;
    ans = 0;
    mem(head, -1);
    mem(vis, 0);
    mem(cnt, 0);
}
int main (void)
{
    int n;
    while(scanf("%d%d", &n, &K)== 2 && n + K){
         int u, v, l;
         init();
         for(int i = 0; i < n - 1; i++){
            sa(u),sa(v),sa(l);
            addedge(u, v, l);
         }
        solve(1);
        printf("%d
", ans);
    }
     return 0;
}
原文地址:https://www.cnblogs.com/Tuesdayzz/p/5758647.html