树分治 poj 1741

n k 

n个节点的一棵树 k是距离

求树上有几对点距离<=k;

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<vector>

using namespace std;
#define MAXN 100100

int head[MAXN];
bool vis[MAXN];
struct edg
{
    int to,next,w;
}x[MAXN];
int cnt,ans,n,k,root;
int mi,num;
int mx[MAXN],size[MAXN],dis[MAXN]; //分别 是去掉这个点的子树最大节点数 以这个节点为根的子树包括的节点树(包括自己) 就一个存距离的数组 

void add(int u,int v,int w)
{
    x[cnt].next=head[u];
    x[cnt].w=w;
    x[cnt].to=v;
    head[u]=cnt++;
}
void dfssize(int u,int fa) //树DP差不多 包括的节点
{
    size[u]=1;
    mx[u]=0;
    for(int i=head[u];i!=-1;i=x[i].next)
    {
        int v=x[i].to;
        if(v!=fa&&!vis[v])
        {
            dfssize(v,u);
            size[u]+=size[v];
            if(size[v]>mx[u]) //这边子树更新一下  下面还有更新 
                mx[u]=size[v];
        }
    }
}
void dfsroot(int r,int u,int fa)
{
    if(size[r]-size[u]>mx[u]) //根节点另外一边
        mx[u]=size[r]-size[u];
    if(mx[u]<mi) //找重心
    {
        mi=mx[u];
        root=u;
    }
    for(int i=head[u];i!=-1;i=x[i].next)
    {
        int v=x[i].to;
        if(v!=fa&&!vis[v])
            dfsroot(r,v,u);
    }
}
void dfsdis(int u,int d,int fa)//存一下距离
{
    dis[num++]=d;
    for(int i=head[u];i!=-1;i=x[i].next)
    {
        int v=x[i].to;
        if(v!=fa&&!vis[v])
            dfsdis(v,d+x[i].w,u);
    }
}
int calc(int u,int d)
{
    int ret=0;
    num=0;
    dfsdis(u,d,0);
    sort(dis,dis+num);
    int i=0,j=num-1;
    while(i<j) //自己模拟一下  算出有多少对
    {
        while(dis[i]+dis[j]>k&&i<j)
            j--;
        ret+=j-i;
        i++;
    }
    return ret;
}
void dfs(int u)
{
    mi=n;
    dfssize(u,0);
    dfsroot(u,u,0);
    ans+=calc(root,0);  //这边算出来的还要减去  在它的一颗子树中(2个节点)如果满足 上面条件 其实并没什么用  但是多算了 
    vis[root]=1;
    for(int i=head[root];i!=-1;i=x[i].next)
    {
        int v=x[i].to;
        if(!vis[v])
        {
            ans-=calc(v,x[i].w); //所以这边要减掉
            dfs(v);
        }
    }
}
int main()
{
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        if(n==k&&k==0)
            break;
        memset(head,-1,sizeof(head));
        memset(vis,0,sizeof(vis));
        ans=cnt=0;
        for(int i=1;i<n;i++)
        {
            int a,b,w;
            scanf("%d%d%d",&a,&b,&w);
            add(a,b,w);
            add(b,a,w);
        }
        dfs(1);
        printf("%d
",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/cherryMJY/p/6151572.html