POJ 1741 Tree(点分治)

点分治

因为树上的路径只有两种,经过根的和没有经过根的,所以可以以根进行分治计算.

  1. 找重心.
  2. 以重心为根,计算经过根的答案.
  3. 分治根的每颗子树.

POJ 1741

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#define ll long long
using namespace std;
const int N = 1e5+10;
const int INF = 0x3f3f3f3f;

struct E{
    int u,v,w,nxt;
    E(){}
    E(int u,int v,int w,int nxt):u(u),v(v),w(w),nxt(nxt){}
}e[N<<1];
int tot,head[N];
void add(int u,int v,int w){
    e[tot] = E(u,v,w,head[u]);    head[u] = tot++;
}

int n,m,rt,sum,cnt,k;
int ans;
int tmp[N]; // 到根长度为i的链
int siz[N],dis[N],maxp[N]; // 最大子树的大小
bool vis[N]; // 是否计算过以这个点所构成的子树

// 获得重心+处理siz
void getrt(int u,int f){
    siz[u] = 1, maxp[u] = 0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v = e[i].v;
        if(v == f || vis[v]) continue;
        getrt(v,u);
        siz[u] += siz[v];
        maxp[u] = max(siz[v],maxp[u]);
    }
    maxp[u] = max(sum-siz[u],maxp[u]);
    if(maxp[u]<maxp[rt])rt = u;
}
// 获得到根节点的距离,并将其存到tmp数组中
void getdis(int u,int f){
    tmp[++cnt] = dis[u];
    for(int i=head[u];~i;i=e[i].nxt){
        int v = e[i].v;
        if(v==f ||vis[v])   continue;
        dis[v] = dis[u]+e[i].w;
        getdis(v,u);
    }
}
int solve(int u,int dep){
    int res = 0;
    dis[u] = dep;
    cnt = 0;
    getdis(u,0);
    sort(tmp+1,tmp+cnt+1);
    int l=1,r=cnt;
    while(l<r){
        if(tmp[l]+tmp[r]<=k)    res+=r-l,l++;
        else r--;
    }
    return res;
}

void divide(int u){
    vis[u] = true; 
    ans += solve(u,0); // 先统计以为u根 经过u的路径的答案
    for(int i=head[u];~i;i=e[i].nxt){ // 每个儿子形成一颗子树
        int v = e[i].v;
        if(vis[v])  continue;
        ans -= solve(v,e[i].w); // 容斥掉从同一颗子树上的答案
        maxp[rt=0] = sum = siz[v]; // 初始化条件,总点数为当前子树大小
        getrt(v,0);      // 找重心
        getrt(rt,0);     // 更新siz
        divide(rt);      // 分治v所在的这个子树
    }
}

int main(){
    int u,v,w;
    while(scanf("%d%d",&n,&k)==2){
        if(n==0 && k==0)    break;
        memset(head,-1,sizeof head);
        memset(vis,0,sizeof vis);
        ans = 0;
        tot = 0;
        for(int i=1;i<n;++i){
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        maxp[0] = sum = n;
        getrt(1,0); // 获得重心
        getrt(rt,0);    // 以重心为根 更新siz
        divide(rt);
        printf("%d
",ans);
    }
    return 0;
}

学习链接

原文地址:https://www.cnblogs.com/xxrlz/p/11628272.html