LUOGU P4178 Tree

题目描述

给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K
输入输出格式
输入格式:

N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k

输出格式:

一行,有多少对点之间的距离小于等于k

输入输出样例
输入样例#1:

7
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10

输出样例#1:

5

解题思路

先%一发GhostCai TQL!!!
点分治+树状数组。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<queue>

using namespace std;
const int MAXN = 40005;

inline int rd() {
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch))  {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}

int f[MAXN],n,K,sum,mx,siz[MAXN],ans;
int head[MAXN],cnt,stk[MAXN],top,rt;
int to[MAXN<<1],nxt[MAXN<<1],val[MAXN<<1];
queue<int> Q;
bool vis[MAXN];

inline void add(int bg,int ed,int w){
    to[++cnt]=ed,nxt[cnt]=head[bg],val[cnt]=w,head[bg]=cnt;
}

inline void update(int x,int w){
    for(;x<=K;x+=x&-x) f[x]+=w;
}

inline int query(int x){
    int ret=0;
    for(;x;x-=x&-x) ret+=f[x];
    return ret;
}

void dfs(int x,int fa){
    siz[x]=1;
    for(register int i=head[x];i;i=nxt[i]){
        int u=to[i];if(vis[u] || u==fa) continue;
        dfs(u,x);siz[x]+=siz[u];
    }
}

void getrt(int x,int fa){
    int k=0;
    for(register int i=head[x];i;i=nxt[i]){
        int u=to[i];if(vis[u] || u==fa) continue;
        getrt(u,x);k=max(k,siz[u]);
    }
    k=max(k,sum-siz[x]);
    if(k<mx) {mx=k;rt=x;}
}

void getdis(int x,int fa,int dis){
    if(dis<=K){
        Q.push(dis);stk[++top]=dis;
        ans+=query(K-dis);
    }
    else return;
    for(register int i=head[x];i;i=nxt[i]){
        int u=to[i];if(vis[u] || u==fa) continue;
        getdis(u,x,dis+val[i]);
    }
}

inline void getans(int x){
    vis[x]=1;
    for(register int i=head[x];i;i=nxt[i]){
        int u=to[i];if(vis[u]) continue;
        getdis(u,x,val[i]);ans+=Q.size();
        while(!Q.empty()){
            int now=Q.front();Q.pop();
            update(now,1);
        }
    }
    for(register int i=1;i<=top;i++) update(stk[i],-1);top=0;
    for(register int i=head[x];i;i=nxt[i]){
        int u=to[i];if(vis[u]) continue;
        dfs(u,x);mx=sum=siz[u];rt=0;getrt(u,x);
//      cout<<u<<" "<<rt<<endl;
        getans(rt);
    }
}

int main() {
    n=rd();
    for(register int i=1;i<n;i++) {
        int x=rd(),y=rd(),z=rd();
        add(x,y,z);add(y,x,z);
    }K=rd();sum=mx=n;
    dfs(1,0);getrt(1,0);getans(rt);
    cout<<ans<<endl;
    return 0;
}
 
原文地址:https://www.cnblogs.com/sdfzsyq/p/9676891.html