P4178 Tree

P4178 Tree

题目描述

给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。

输入格式

第一行输入一个整数 n,表示节点个数。

第二行到第 n 行每行输入三个整数 u,v,w ,表示 u 与 v 有一条边,边权是 w。

第 n+1 行一个整数 k 。

输出格式

一行一个整数,表示答案。

输入输出样例

输入 #1

7
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10

输出 #1

5

说明/提示

数据规模与约定

对于全部的测试点,保证:

  • 1≤n≤(4×10^4)

  • 1≤u,v≤n

  • 0≤w≤(10^3)

  • 0≤k≤(2×10^4)

    ​ 经典的点分治题目。

    ​ 我们先处理出这个联通块内所有点到重心的距离,将所有距离排个序,用双指针,一个从小往大找,另一个从大往下找,这样可以快速的统计出小于等于k的路径,注意要减去同一颗子树内的答案。

    #include <iostream>
    #include <cstdio>
    #include <cctype>
    #include <algorithm>
    
    using namespace std;
    
    inline long long read() {
        long long s = 0, f = 1; char ch;
        while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
        for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
        return s * f;
    }
    
    const int N = 4e4 + 5, inf = 1e9;
    int n, x, y, z, k, cnt, num, ans, root, totsize;
    int a[N], siz[N], dis[N], vis[N], max_siz[N], head[N];
    struct edge { int to, nxt, val; } e[N << 1];
    
    void add(int x, int y, int z) {
        e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y; e[cnt].val = z;
    }
    
    void init() {
        n = read();
        for(int i = 1;i <= n - 1; i++) {
            x = read(); y = read(); z = read();
            add(x, y, z); add(y, x, z);
        }
        k = read();
    }
    
    void get_root(int x, int fa) {
        siz[x] = 1;
        for(int i = head[x]; i; i = e[i].nxt) {
            int y = e[i].to; if(y == fa || vis[y]) continue;
            get_root(y, x); siz[x] += siz[y];
            max_siz[x] = max(max_siz[x], siz[y]);
        }
        max_siz[x] = max(max_siz[x], totsize - siz[x]);
        if(max_siz[x] < max_siz[root]) root = x;
    }
    
    void calc_dis(int x, int fa) {
        a[++num] = dis[x];
        for(int i = head[x]; i; i = e[i].nxt) {
            int y = e[i].to; if(y == fa || vis[y]) continue;
            dis[y] = dis[x] + e[i].val;
            calc_dis(y, x);
        }
    }
    
    void solve(int x) {
        vis[x] = 1; dis[x] = num = 0; 
        calc_dis(x, 0);
        sort(a + 1, a + num + 1);
        int r = num;
        for(int i = 1;i <= num; i++) {
            while(a[i] + a[r] > k && r >= i) r--;
            if(r < i) break;
            ans += (r - i);
        }
        for(int i = head[x]; i; i = e[i].nxt) {
            int y = e[i].to; if(vis[y]) continue;
            dis[y] = e[i].val; num = 0;
            calc_dis(y, x);
            sort(a + 1, a + num + 1);
            r = num;
            for(int j = 1;j <= num; j++) {
                while(a[j] + a[r] > k && r >= j) r--;
                if(r < j) break;
                ans -= (r - j);
            }
            max_siz[root = 0] = inf; totsize = siz[y];
            get_root(y, 0); solve(root);
        }
    }
    
    void work() {
        max_siz[root = 0] = inf; totsize = n;
        get_root(1, 0); solve(root);
        printf("%d", ans);
    }
    
    int main() {
    
        init();
        work();
    
        return 0;
    }
    
原文地址:https://www.cnblogs.com/czhui666/p/13599748.html