POJ 1741 [点分治][树上路径问题]

/*
不要低头,不要放弃,不要气馁,不要慌张
题意:
给一棵有n个节点的树,每条边都有一个正权值,求一共有多少个点对使得它们之间路的权值和小于给定的k.
思路:
《分治算法在树的路径问题中的应用》
*/

#include<stdio.h>
#include<vector>
#include<string.h>
#include<algorithm>
#define N 10050
using namespace std;
struct edge{
    int id;
    long long w;
    bool im;
    edge *next;
};
struct st{
    st(){}
    st(int a,long long b,int c){
        id=a;dis=b;iid=c;
    }
    int id,iid;
    long long dis;
};
vector<st>mv;
edge edges[N*2];
edge *adj[N];
long long ans,kk;
int ednum;
inline void addedge(int a,int b,long long c){
    edge *tmp=&edges[ednum++];
    tmp->id=b;
    tmp->w=c;
    tmp->im=1;
    tmp->next=adj[a];
    adj[a]=tmp;
}
int zong,next_val,next;
bool vis[N];
int siz[N],fa[N];
long long dis[N];
void dfs(int pos,int dep){
    vis[pos]=1;
    siz[pos]=1;
    for(edge *it=adj[pos];it;it=it->next){
        if(it->im&&!vis[it->id]){
            dfs(it->id,dep+1);
            siz[pos]+=siz[it->id];
        }
    }
}
void dfs2(int pos,int dep){
    int my_next=-1;
    vis[pos]=1;
    for(edge *it=adj[pos];it;it=it->next){
        if(it->im&&!vis[it->id]){
            my_next=max(my_next,siz[it->id]);
        }
    }
    my_next=max(my_next,zong-siz[pos]);
    if(next_val>my_next){
        next=pos;
        next_val=my_next;
    }
    for(edge *it=adj[pos];it;it=it->next){
        if(it->im&&!vis[it->id]){
            dfs2(it->id,dep+1);
        }
    }
}
bool cmp1(st a,st b){
    if(a.dis!=b.dis)return a.dis<b.dis;
    else return a.iid<b.iid;
}
bool cmp2(st a,st b){
    if(a.id!=b.id)return a.id<b.id;
    else if(a.dis!=b.dis)return a.dis<b.dis;
    else return a.iid<b.iid;
}
inline void del(int a,int b){
    for(edge *it=adj[a];it;it=it->next){
        if(it->id==b){
            it->im=0;
            return;
        }
    }
}
void dfs3(int pos,int dep){
    vis[pos]=1;
    if(!dep)dis[pos]=0;
    for(edge *it=adj[pos];it;it=it->next){
        if(it->im&&!vis[it->id]){
            if(!dep)fa[it->id]=it->id;
            else fa[it->id]=fa[pos];
            dis[it->id]=dis[pos]+it->w;
            mv.push_back(st(fa[it->id],dis[it->id],it->id));
            dfs3(it->id,dep+1);
        }
    }
}
void solve(int pos){
    mv.clear();
    memset(vis,0,sizeof(vis));
    dfs(pos,0);
    zong=siz[pos];
    if(zong<=1)return;
    memset(vis,0,sizeof(vis));
    next_val=9999999;
    dfs2(pos,0);
    memset(vis,0,sizeof(vis));
    dfs3(next,0);
    int n=mv.size();
    sort(mv.begin(),mv.end(),cmp1);
    for(int i=0;i<n;i++){
        if(mv[i].dis>kk)break;
        ans++;
        int l=i+1,r=n-1;
        while(l<=r){
            int mid=(l+r)>>1;
            if(mv[i].dis+mv[mid].dis<=kk)l=mid+1;
            else r=mid-1;
        }
        ans+=r-i;
    }
    sort(mv.begin(),mv.end(),cmp2);
    int st=0;
    for(int i=0;i<n;i++){
        if(mv[i].id!=mv[st].id){
            for(int j=st;j<i;j++){
                int l=j+1,r=i-1;
                while(l<=r){
                    int mid=(l+r)>>1;
                    if(mv[j].dis+mv[mid].dis<=kk)l=mid+1;
                    else r=mid-1;
                }
                ans-=r-j;
            }
            st=i;
        }
    }
    for(int j=st;j<n;j++){
        int l=j+1,r=n-1;
        while(l<=r){
            int mid=(l+r)>>1;
            if(mv[j].dis+mv[mid].dis<=kk)l=mid+1;
            else r=mid-1;
        }
        ans-=r-j;
    }
    vector<int>mmv;
    for(edge *it=adj[next];it;it=it->next){
        if(it->im&&siz[it->id]>1){
            mmv.push_back(it->id);
            it->im=0;
            del(it->id,next);
        }
    }
    for(int i=0;i<mmv.size();i++)solve(mmv[i]);
}
int main()
{
    int n;
    while(scanf("%d%lld",&n,&kk)!=EOF){
        ans=0;
        if(!n)break;
        memset(adj,NULL,sizeof(adj));
        ednum=0;
        for(int i=1;i<n;i++){
            int a,b;
            long long c;
            scanf("%d%d%lld",&a,&b,&c);
            addedge(a,b,c);
            addedge(b,a,c);
        }
        solve(1);
        printf("%lld
",ans);
    }
}
原文地址:https://www.cnblogs.com/tun117/p/5956621.html