P4178 Tree

题目描述

给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K

输入输出格式

输入格式:

N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k

输出格式:

一行,有多少对点之间的距离小于等于k

输入输出样例

输入样例#1: 
7
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10
输出样例#1: 
5

说明

k20000 对于任意一条管道边权wi1000

代码

与点分治1树上直接统计不同的是这题用的指针扫描数组

根据排序数组具有单调性,用了l,r两个指针来维护距离,同时利用容斥剪掉不合法的方案(如图)

即对于每个根节点的子树ans-=calc(to,e[i].val)(两个不合法节点之间的距离相对于根节点为e[i].val

sort

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=40000+100;
int head[maxn];
int s[maxn],ms[maxn];
int vis[maxn];
int p[maxn];
int d[maxn],dis[maxn];
struct edge
{
    int to,next,val;
}e[maxn<<2];
int size=0;
int n,m;
int sum,rt;
int ans=0;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
void addedge(int u,int v,int w)
{
    e[++size].to=v;e[size].val=w;e[size].next=head[u];head[u]=size;
}
void tc(int u,int fa)
{
    s[u]=1;ms[u]=0;
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa||vis[to])continue;
        tc(to,u);
        s[u]+=s[to];
        ms[u]=max(ms[u],s[to]);
    }
    ms[u]=max(ms[u],sum-ms[u]);
    if(ms[u]<ms[rt])rt=u;
}
void dfs(int u,int fa)
{
    d[++d[0]]=dis[u];
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa||vis[to])continue;
        dis[to]=dis[u]+e[i].val;
        dfs(to,u);
    }
}
int calc(int u,int x)
{
    int l=1,r=0,res=0;
    d[0]=0,dis[u]=x;
    dfs(u,0);
    for(int i=1;i<=d[0];i++)if(d[i]<=m)p[++r]=d[i];
    sort(p+1,p+1+r); 
    while(l<=r)
    {
        if(p[l]+p[r]<=m)
        res+=r-l,++l;
        else --r;
    }
    return res;
}
void solve(int u)
{
    vis[u]=1;ans+=calc(u,0);
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(vis[to])continue;
        ans-=calc(to,e[i].val);
        sum=s[to],ms[rt=0]=inf;
        tc(to,u),solve(rt);
    }
}
int main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read(),w=read();
        addedge(u,v,w);
        addedge(v,u,w);
    }
    m=read();
    sum=n,ms[rt]=inf;
    tc(1,0);
    solve(rt);
    printf("%d",ans);
    return 0;
} 
View Code

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=40000+100;
int head[maxn];
int s[maxn],ms[maxn];
int vis[maxn];
int p[maxn],t[maxn];
int d[maxn],dis[maxn];
struct edge
{
    int to,next,val;
}e[maxn<<2];
int size=0;
int n,m;
int sum,rt;
int ans=0;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
void addedge(int u,int v,int w)
{
    e[++size].to=v;e[size].val=w;e[size].next=head[u];head[u]=size;
}
void tc(int u,int fa)
{
    s[u]=1;ms[u]=0;
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa||vis[to])continue;
        tc(to,u);
        s[u]+=s[to];
        ms[u]=max(ms[u],s[to]);
    }
    ms[u]=max(ms[u],sum-ms[u]);
    if(ms[u]<ms[rt])rt=u;
}
void dfs(int u,int fa)
{
    d[++d[0]]=dis[u];
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa||vis[to])continue;
        dis[to]=dis[u]+e[i].val;
        dfs(to,u);
    }
}
int calc(int u,int x)
{
    int l=1,r=0,res=0;
    d[0]=0,dis[u]=x;
    dfs(u,0);
    for(int i=1;i<=d[0];i++)if(d[i]<=m)++p[d[i]];
    for(int i=0;i<=m;i++)
    for(int j=1;j<=p[i];j++)t[++r]=i;
        while(l<=r)
        {
            if(t[l]+t[r]<=m)
            res+=r-l,++l;
            else --r;
        }
    for(int i=1;i<=d[0];i++)
    p[d[i]]=0;
    return res;
}
void solve(int u)
{
    vis[u]=1;ans+=calc(u,0);
    for(int i=head[u];i;i=e[i].next)
    {
        int to=e[i].to;
        if(vis[to])continue;
        ans-=calc(to,e[i].val);
        sum=s[to],ms[rt=0]=inf;
        tc(to,u),solve(rt);
    }
}
int main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read(),w=read();
        addedge(u,v,w);
        addedge(v,u,w);
    }
    m=read();
    sum=n,ms[rt]=inf;
    tc(1,0);
    solve(rt);
    printf("%d",ans);
    return 0;
} 
View Code
原文地址:https://www.cnblogs.com/DriverBen/p/10999295.html