poj 1741

点分治入门题

首先发现是树上点对的问题,那么首先想到上点分治

然后发现题目要求是求出树上点对之间距离小于等于k的对数,那么我们很自然地进行分类:

对于一棵有根树,树上的路径只有两种:一种经过根节点,另一种不经过根节点

对于经过根节点的路径,我们可以通过计算出每个点的根节点的距离,然后相加就能求出点对间距离

对于不经过根节点的路径,我们可以递归到子节点去算,去找子节点对应的子树来算即可

但是这里有两个问题:第一,如何快速算出以一个点为根的合法点对数量?

我们知道,可以在线性时间内求出每个点到根节点的距离,但如果我们逐个枚举点对的话,时间就会退化成平方级别

这显然不够优秀

所以我们将每个点到根节点距离排序,然后用两个指针,初始分别指向头和尾,如果两个指针指到的之和是合法的,那么这两个指针间的部分都是合法的(具体看代码),扫一遍即可

第二:这样做的结果是正确的吗?

我们看到,如果查到的一个点对在同一棵子树内,那么在计算以这个点为根和以这个点的子节点为根的时候,这个点对都会被计算一次!

这显然是不对的

因此我们在枚举每个子树时需要先去掉这一部分,然后再计算

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
const int inf=0x3f3f3f3f;
struct Edge
{
    int next;
    int to;
    int val;
}edge[200005];
int head[100005];
int rt,s;
int n,k;
int maxp[100005];
int siz[100005];
bool vis[100005];
int dis[100005];
int used[100005];
int ans=0;
int cnt=1;
void init()
{
    memset(head,-1,sizeof(head));
    memset(vis,0,sizeof(vis));
    ans=0;
    cnt=1;
}
void add(int l,int r,int w)
{
    edge[cnt].next=head[l];
    edge[cnt].to=r;
    edge[cnt].val=w;
    head[l]=cnt++;
}
void get_rt(int x,int fa)
{
    siz[x]=1,maxp[x]=0;
    for(int i=head[x];i!=-1;i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to]||to==fa)continue;
        get_rt(to,x);
        siz[x]+=siz[to];
        maxp[x]=max(maxp[x],siz[to]);
    }
    maxp[x]=max(maxp[x],s-siz[x]);
    if(maxp[x]<maxp[rt])rt=x;
}
void get_dis(int x,int fa)
{
    used[++used[0]]=dis[x];
    for(int i=head[x];i!=-1;i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to]||to==fa)continue;
        dis[to]=dis[x]+edge[i].val;
        get_dis(to,x);
    }
}
int calc(int x,int val)
{
    dis[x]=val;
    used[0]=0;
    get_dis(x,0);
    sort(used+1,used+used[0]+1);
    int l=1,r=used[0];
    int ret=0;
    while(l<r)
    {
        if(used[l]+used[r]<=k)ret+=r-l,l++;
        else r--;
    }
    return ret;
}
void solve(int x)
{
    vis[x]=1;
    ans+=calc(x,0);
    for(int i=head[x];i!=-1;i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to])continue;
        ans-=calc(to,edge[i].val);
        rt=0,s=siz[to],maxp[rt]=inf;
        get_rt(to,0);
        solve(rt);
    }
}
int main()
{
    while(1)
    {
        scanf("%d%d",&n,&k);
        init();
        if(!n&&!k)return 0;
        for(int i=1;i<n;i++)
        {
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            add(x,y,z),add(y,x,z);
        }
        maxp[rt]=s=n;
        get_rt(1,0);
        solve(rt);
        printf("%d
",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/zhangleo/p/10784793.html