poj 1741 两点距离小于等于K(树DP)

http://blog.csdn.net/woshi250hua/article/details/7723400

求树上两点间距离小于等于k的点对数量

理一下思路:

经过点A且符合条件的点对数 = 到点A的距离之和符合条件的点对数 - 两点位于A的同一棵子树内(路径实际并不经过A)却被算入的点对数

步骤:

因为从哪个点开始都一样,所以每次找子树重心开始遍历

求出子节点到这个点的所有距离,排序搜索,得出总方案数

减掉同一棵子树内部符合条件的数量,得到经过这个点的方案数

以此遍历每个点

注意:

求重心时,因为每次子树数量都是不一样的,要动态更新

相向搜索(双指针):排序之后,一次扫描为O(n)

sort(dis,dis+tot);
        int left =0,right = tot-1;
        while(left<right)
        {
            if(dis[left]+dis[right]<=k)
            {
                ans-= right-left;
                left++;
            }
            else right--;
        }

AC代码:

#include <iostream>
#include <string>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <stack>
#include <queue>
#include <cctype>
#include <vector>
#include <iterator>
#include <set>
#include <map>
#include <sstream>
using namespace std;

/* Shorthand macros used by this solution. */
#define mem(a,b) memset(a,b,sizeof(a))
#define pf printf
#define sf scanf
#define spf sprintf
#define pb push_back
/* Debug marker: prints "!" followed by a newline. The literal had been split
   across two physical lines, which does not compile. */
#define debug printf("!\n")
/* Array capacity. Parenthesized so that expressions such as MAXN*2 use the
   whole value instead of binding only to the trailing 5. */
#define MAXN (20000+5)
/* Larger of two values. Arguments and result are parenthesized so the macro
   composes with surrounding operators; each argument may still be evaluated
   twice, so avoid arguments with side effects. */
#define MAX(a,b) ((a)>(b)?(a):(b))
/* Prints a bare newline (literal repaired the same way as in debug). */
#define blank pf("\n")
#define LL long long
#define ALL(x) x.begin(),x.end()
#define INS(x) inserter(x,x.begin())
#define pqueue priority_queue
#define INF 0x3f3f3f3f

/* Child indices for an implicit binary tree rooted at index rt. */
#define ls (rt<<1)
#define rs (rt<<1|1)

// Problem parameters. n and k are read by main(); m is not referenced anywhere
// in the visible code.
int n;
int m;
int k;

// Adjacency-list bookkeeping.
int ptr = 1;     // next free slot in tree[]; add() hands slots out in order
int head[MAXN];  // head[v]: index of the first stored edge of v (main() resets to -1)
int vis[MAXN];   // vis[v] is set to 1 by getcnt2() once v has served as a root

// Scratch state shared by the recursive helpers.
int num;         // smallest "largest remaining piece" found so far by getroot()
int ans;         // running pair count, printed by main()
int tot;         // number of valid entries currently in dis[]
int rt;          // vertex selected by the most recent getroot() pass
int sum;         // component size value that getroot() measures against
int son[MAXN];   // subtree sizes as computed by getroot()
int dis[MAXN];   // distances collected by getdis()
int mu[MAXN];    // largest child size seen by getroot(); written there, never read here

// One directed edge of the adjacency list.
struct node
{
    int y;     // vertex this edge leads to
    int val;   // edge weight
    int next;  // index of the next edge with the same source; -1 ends the list
};

// Edge pool. main() stores every undirected edge twice (once per direction),
// hence twice MAXN slots.
node tree[(MAXN) << 1];

void add(int fa,int son,int val)
{
    tree[ptr].y = son;
    tree[ptr].val = val;
    tree[ptr].next = head[fa];
    head[fa] = ptr++;
}

// Depth-first pass over the unvisited vertices reachable from root (entered
// from fa). Fills son[] with subtree sizes and records in rt/num the vertex
// whose largest remaining piece, measured against the global sum, is the
// smallest seen so far. The caller must initialise num and sum beforehand.
void getroot(int root,int fa)
{
    int heaviest = 0;
    son[root] = 1;

    for(int e=head[root]; e!=-1; e=tree[e].next)
    {
        const int child = tree[e].y;
        if(!vis[child] && child != fa)
        {
            getroot(child,root);

            const int sz = son[child];
            son[root] += sz;
            if(sz > heaviest) heaviest = sz;
            if(sz > mu[root]) mu[root] = sz;  // mu[] is only ever written, never reset
        }
    }

    // The piece on the far side of root counts as well.
    const int above = sum - son[root];
    if(above > heaviest) heaviest = above;

    // Strict comparison: the first vertex to reach the minimum keeps the title.
    if(heaviest < num)
    {
        rt = root;
        num = heaviest;
    }
}

// Appends dist (the distance accumulated on the way to root) to dis[], then
// recurses into every unvisited neighbour other than fa whose accumulated
// distance stays within k. Branches that would exceed k are not explored.
void getdis(int root,int fa,int dist)
{
    dis[tot] = dist;
    ++tot;

    for(int e=head[root]; e!=-1; e=tree[e].next)
    {
        const int to = tree[e].y;
        if(vis[to] || to == fa) continue;

        const int reach = dist + tree[e].val;
        if(reach <= k)
            getdis(to,root,reach);
    }
}

// Sorts dis[0..tot) and adds to ans the number of index pairs (i < j) whose
// two distances sum to at most k, using two cursors that move towards each
// other.
void getcnt()
{
    sort(dis,dis+tot);

    for(int lo=0,hi=tot-1; lo<hi; )
    {
        if(dis[lo]+dis[hi] > k)
        {
            --hi;            // largest remaining value is too big for any partner
        }
        else
        {
            ans += hi-lo;    // dis[lo] pairs with every entry in (lo, hi]
            ++lo;
        }
    }
}

// Marks root as visited, then for each still-unvisited neighbour collects the
// distances inside that neighbour's side (measured from root, so they start at
// the connecting edge's weight) and subtracts from ans the pairs with sum <= k
// that lie entirely on that one side.
void getcnt2(int root)
{
    vis[root] = 1;

    for(int e=head[root]; e!=-1; e=tree[e].next)
    {
        const int child = tree[e].y;
        if(!vis[child])
        {
            tot = 0;
            getdis(child,root,tree[e].val);
            sort(dis,dis+tot);

            for(int lo=0,hi=tot-1; lo<hi; )
            {
                if(dis[lo]+dis[hi] > k)
                {
                    --hi;
                }
                else
                {
                    ans -= hi-lo;
                    ++lo;
                }
            }
        }
    }
}

// Processes the component containing root: picks a split vertex with
// getroot(), counts pairs around it (getcnt() adds, getcnt2() subtracts the
// same-side pairs and marks the vertex visited), then recurses into each
// remaining neighbour of that vertex.
void solve(int root,int fa)
{
    sum = son[root];
    num = sum;
    getroot(root,fa);

    tot = 0;
    getdis(rt,0,0);
    getcnt();
    getcnt2(rt);

    // rt is a global that nested solve() calls overwrite. The walk below only
    // follows edge indices, and each call passes whatever rt holds at that time.
    for(int e=head[rt]; e!=-1; e=tree[e].next)
    {
        const int next_root = tree[e].y;
        if(!vis[next_root] && next_root != fa)
            solve(next_root,rt);
    }
}


// Reads test cases "n k" followed by n-1 weighted edges "x y z" and prints one
// pair count per case. Input ends at a line whose n+k is zero, or at EOF.
int main()
{
    // The scanf result must take part in the condition: with the comma
    // operator it was discarded, so EOF left n and k unchanged and the loop
    // never terminated.
    while(scanf("%d%d",&n,&k) == 2 && n+k != 0)
    {
        // Reset all per-case state.
        memset(tree,0,sizeof(tree));
        memset(head,-1,sizeof(head));  // -1 marks an empty adjacency list
        memset(vis,0,sizeof(vis));
        memset(dis,0,sizeof(dis));
        ans = 0;
        ptr = 1;
        son[1] = n;                    // solve() reads son[root] as the component size

        for(int i=1;i<n;i++)
        {
            int x,y,z;
            // Stop on truncated input instead of using indeterminate values.
            if(scanf("%d%d%d",&x,&y,&z) != 3) return 0;
            add(x,y,z);
            add(y,x,z);
        }

        solve(1,0);
        printf("%d\n",ans);
    }
    return 0;
}
/*
5 3
1 2 1
2 3 1
3 4 1
4 5 1
*/
原文地址:https://www.cnblogs.com/qlky/p/5783975.html