poj1741 树上距离小于等于k的对数 点分治 入门题

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define  N 40005
#define  M 80005
#define LL long long
using namespace std;
const int INF = 0x3f3f3f3f;
int ver[M],edge[M],head[N],Next[M];
int n,m,tot,root;LL k;
void add(int x,int y,int w){
    ver[++tot]=y;edge[tot]=w;Next[tot]=head[x];head[x]=tot;
    ver[++tot]=x;edge[tot]=w;Next[tot]=head[y];head[y]=tot;
}
int sz[N],vis[N],mx,size;
LL d[N],q[N],l,r,ans=0;
//求出树的重心 因为找到重心后,递归子树不超过原来的一半,递归层数小于logn层
void getroot(int u,int fa){
    sz[u]=1;int num=0;
    for (int i=head[u];i;i=Next[i]){
        int v=ver[i];
        if (v==fa||vis[v])continue;
        ///继续深搜
        getroot(v,u);
        ///计算出子树的大小
        sz[u]+=sz[v];
        ///维护子树的最长的链
        num=max(num,sz[v]);
    }
    ///num代表的是子节点的最长链 size-sz[u]代表的是父亲链长
    num=max(num,size-sz[u]);
    if (num<mx)mx=num,root=u;
}
///计算某个节点所有子树中的节点的到这个节点的距离
void getdis(int u,int fa){
    q[++r]=d[u];
    for(int i=head[u];i;i=Next[i]){
        int v=ver[i];
        if(v==fa||vis[v])continue;
        d[v]=d[u]+edge[i];
        getdis(v,u);
    }
}
LL cal(int u,int val){
    r=0;
    d[u]=val;
    getdis(u,0);
    LL sum=0,l=1;
    ///把子树点到当前点的距离进行排序
    sort(q+1,q+1+r);
    cout<<u<<" "<<r<<endl;
    for(int i=1;i<=r;i++){
        cout<<q[i]<<" ";
    }
    cout<<endl;
    ///开一个左右指针,以左端点为基准移动,如果两个距离是大于k,肯定移动右指针,
    ///也就是对于每一个小的,去右边寻找最远能满足条件的,而中间的一定满足
    while(l<r){
        if(q[l]+q[r]<=k)sum+=r-l,++l;
        else --r;
    }
    return sum;
}
void dfs(int u){
    ///计算当前节点内部所有>=k的数目 但是会存在连个点是在同一联通块内部 答案就不对了
    ans+=cal(u,0);
    vis[u]=1;
    for (int i=head[u];i;i=Next[i]){
        int v=ver[i];
        if (vis[v])continue;
        ///这里我们减去内部在同一个联通块里面的答案 相当于剪掉重复的
        ans-=cal(v,edge[i]);
        ///在当前点内部继续找重心
        size=sz[v];
        mx=INF;
        getroot(v,0);
        ///然后找到子树的重心,进行深搜
        dfs(root);
    }
}
int main(){
   int u,v,e,k;
   while(~scanf("%d%d",&n,&k) && n+k) {
       ans=0;
       memset(vis,0,sizeof(vis));
       memset(head,0,sizeof(head));
       memset(Next,0, sizeof(Next));
       for (int i = 1; i <n ; i++) {
           scanf("%d%d%d", &u, &v, &e);
           add(u, v, e);
       }
       size = n;
       mx = INF;
       getroot(1, 0);
       dfs(root);
       printf("%lld
", ans);
   }
    return 0;
}

  

原文地址:https://www.cnblogs.com/bluefly-hrbust/p/11550658.html