Poj1741-Tree(树分治)

题意:找树上有多少对距离小于K的对数
解析:树分治模板题,见注释

代码

#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<vector>
using namespace std;
typedef __int64 LL;
const int INF=1e9+7;
const int maxn=100005;
int N,K,ans;
struct edge{ int v,w,next; }E[2*maxn]; //
int head[maxn],eid;
void init()
{
    ans=eid=0;
    for(int i=0;i<=N;i++) head[i]=-1;
}
void AddEdge(int u,int v,int w)
{
    E[++eid].v=v; E[eid].w=w;
    E[eid].next=head[u]; head[u]=eid;
}
struct TBS
{
    int root,did,deep[maxn],dist[maxn]; //root选取的根
    int tot,s[maxn],cen[maxn]; //s树的大小,cen左右两边的节点最大值
    bool used[maxn]; //是否已被访问
    void init(int n) //初始化
    {
        tot=n;
        cen[0]=INF;
        memset(used,false,sizeof(used));
    }
    void GetRoot(int u,int fa)//找到重心
    {
        s[u]=1; //树的大小
        cen[u]=0;
        for(int i=head[u];i!=-1;i=E[i].next)
        {
            int v=E[i].v;
            if(v==fa||used[v]) continue;
            GetRoot(v,u);
            s[u]+=s[v];
            cen[u]=max(cen[u],s[v]); 
        }
        cen[u]=max(cen[u],tot-s[u]); 
        if(cen[u]<cen[root]) root=u; //更新重心
    }
    void GetDeep(int u,int fa)
    {
        deep[++did]=dist[u]; //保存下来
        for(int i=head[u];i!=-1;i=E[i].next)
        {
            int v=E[i].v,w=E[i].w;
            if(v==fa||used[v]) continue;
            dist[v]=dist[u]+w; //距离所选重心的距离
            GetDeep(v,u);
        }
    }
    int Cal(int u,int w)
    {
        dist[u]=w; did=0;
        GetDeep(u,0);
        sort(deep+1,deep+did+1);
        int ret=0;
        for(int l=1,r=did;l<r;)
        {
            if(deep[l]+deep[r]<=K){ ret+=r-l; l++; } //计算小于K的点对
            else r--;
        }
        return ret;
    }
    void Work(int u)
    {
        used[u]=true;
        ans+=Cal(u,0);
        for(int i=head[u];i!=-1;i=E[i].next)
        {
            int v=E[i].v,w=E[i].w;
            if(used[v]) continue;
            ans-=Cal(v,w); //去重
            tot=s[v];
            root=0;
            GetRoot(v,u); //找下一个重心
            Work(root);
        }
    }
}tbs;
int main()
{
    while(scanf("%d%d",&N,&K)!=EOF)
    {
        if(!N&&!K) break;
        init();
        int u,v,w;
        for(int i=1;i<N;i++)
        {
            scanf("%d%d%d",&u,&v,&w);
            AddEdge(u,v,w);
            AddEdge(v,u,w);
        }
        tbs.init(N);
        tbs.GetRoot(1,0);
        tbs.Work(tbs.root);
        printf("%d
",ans);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/wust-ouyangli/p/5793438.html