点分治

点分治

核心思路是先加入这个点的贡献,然后删除树上这个点,每次删的点是该子树的重心

给定一个有N个点(编号0,1,…,N-1)的树,每条边都有一个权值(不超过1000)。

树上两个节点x与y之间的路径长度就是路径上各条边的权值之和。

求长度不超过K的路径有多少条。

输入格式

输入包含多组测试用例。

每组测试用例的第一行包含两个整数N和K。

接下来N-1行,每行包含三个整数u,v,l,表示节点u与v之间存在一条边,且边的权值为l。

当输入用例N=0,K=0时,表示输入终止,且该用例无需处理。

输出格式

每个测试用例输出一个结果。

每个结果占一行。

数据范围

1N1e4
1K5×1e6

1l1e3

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 10010, M = N * 2;

int n, m;
int h[N], e[M], w[M], ne[M], idx;
bool st[N];
int p[N], q[N];//分别表示整个树和当前树

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

int get_size(int u, int fa)  // 求子树大小
{
    if (st[u]) return 0;
    int res = 1;
    for (int i = h[u]; ~i; i = ne[i])
        if (e[i] != fa)
            res += get_size(e[i], u);
    return res;
}

int get_wc(int u, int fa, int tot, int& wc)  // 求重心
{
    if (st[u]) return 0;
    int sum = 1, ms = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;
        int t = get_wc(j, u, tot, wc);
        ms = max(ms, t);
        sum += t;
    }
    ms = max(ms, tot - sum);
    if (ms <= tot / 2) wc = u;
    return sum;
}

void get_dist(int u, int fa, int dist, int& qt)
{
    if (st[u]) return;
    q[qt ++ ] = dist;
    for (int i = h[u]; ~i; i = ne[i])
        if (e[i] != fa)
            get_dist(e[i], u, dist + w[i], qt);
}

int get(int a[], int k)
{
    sort(a, a + k);
    int res = 0;
    for (int i = k - 1, j = -1; i >= 0; i -- )
    {
        while (j + 1 < i && a[j + 1] + a[i] <= m) j ++ ;
        j = min(j, i - 1);
        res += j + 1;
    }
    return res;
}

int calc(int u)
{
    if (st[u]) return 0;
    int res = 0;
    get_wc(u, -1, get_size(u, -1), u);
    st[u] = true;  // 删除重心

    int pt = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i], qt = 0;
        get_dist(j, -1, w[i], qt);
        res -= get(q, qt);
        for (int k = 0; k < qt; k ++ )
        {
            if (q[k] <= m) res ++ ;
            p[pt ++ ] = q[k];
        }
    }
    res += get(p, pt);

    for (int i = h[u]; ~i; i = ne[i]) res += calc(e[i]);
    return res;
}

int main()
{
    while (scanf("%d%d", &n, &m), n || m)
    {
        memset(st, 0, sizeof st);
        memset(h, -1, sizeof h);
        idx = 0;
        for (int i = 0; i < n - 1; i ++ )
        {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            add(a, b, c), add(b, a, c);
        }

        printf("%d
", calc(0));
    }

    return 0;
}
原文地址:https://www.cnblogs.com/cyq123/p/14002496.html