[湖南集训]谈笑风生(主席树)

设 T 为一棵有根树,我们做如下的定义:

• 设 a 和 b 为 T 中的两个不同节点。如果 a 是 b 的祖先,那么称“a 比 b 不知道高明到哪里去了”。

• 设 a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数 x,那么称“a 与 b 谈笑风生”。

给定一棵 n 个节点的有根树 T,节点的编号为 1 ∼ n,根节点为 1 号节点。你需要回答 q 个询问,询问给定两个整数 p 和 k,问有多少个有序三元组 (a; b; c) 满足:

  1. a、 b 和 c 为 T 中三个不同的点,且 a 为 p 号节点;

  2. a 和 b 都比 c 不知道高明到哪里去了;

  3. a 和 b 谈笑风生。这里谈笑风生中的常数为给定的 k

Solution

其实看标题就知道这题的解法了。


发现我们求的是min(deep[p]-1,k)*(size[p]-1)加上以p为根的子树里所有深度在deep[p]+1~deep[p]+k的点的size和。

这个东西怎么求?

按照dfs序建立主席树,下标为节点深度。

Code

#include<iostream>
#include<cstdio>
#define N 300003
using namespace std;
typedef long long ll; 
int size[N],deep[N],mad,head[N],tot,top,ji,dfn[N],hui[N],L[N*22],R[N*22],re[N],n,q,x,y,T[N];
ll tr[N*22];
struct ds{
    int n,to;
}e[N<<1];
inline void add(int u,int v){
    e[++tot].n=head[u];
    e[tot].to=v;
    head[u]=tot;
}
void dfs(int u,int fa){
    size[u]=1;deep[u]=deep[fa]+1;
    mad=max(mad,deep[u]);
    dfn[u]=++top;
    re[top]=u;
    for(int i=head[u];i;i=e[i].n){
        int v=e[i].to;
        if(v==fa)continue;
        dfs(v,u);
        size[u]+=size[v];
    }
    hui[u]=top;
}
int build(int l,int r){
    int p=++ji;
    if(l==r)return p;
    int mid=(l+r)>>1;
    L[p]=build(l,mid);R[p]=build(mid+1,r);
    return p;
}
int update(int pre,int l,int r,int x,ll y){
    int p=++ji;
    L[p]=L[pre];R[p]=R[pre];tr[p]=tr[pre]+y;
    if(l==r)return p;
    int mid=(l+r)>>1;
    if(mid>=x)L[p]=update(L[pre],l,mid,x,y);
    else R[p]=update(R[pre],mid+1,r,x,y);
    return p;
} 
int rd(){
    int x=0;char c=getchar();
    while(!isdigit(c))c=getchar();
    while(isdigit(c)){
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x;
}
ll query(int pre,int now,int l,int r,int LL,int RR){
    if(l>=LL&&r<=RR)return tr[now]-tr[pre];
    int mid=(l+r)>>1;
    ll ans=0;
    if(mid>=LL)ans+=query(L[pre],L[now],l,mid,LL,RR);
    if(mid<RR)ans+=query(R[pre],R[now],mid+1,r,LL,RR);
    return ans;
}
int main(){
    n=rd();q=rd(); int pu,k;
    for(int i=1;i<n;++i)x=rd(),y=rd(),add(x,y),add(y,x);
    dfs(1,0);
    T[0]=build(1,mad);
    for(int i=1;i<=n;++i)T[i]=update(T[i-1],1,mad,deep[re[i]],size[re[i]]-1); 
    while(q--){
        pu=rd();k=rd();ll num=0;
        if(deep[pu]!=mad)num=query(T[dfn[pu]],T[hui[pu]],1,mad,deep[pu]+1,min(deep[pu]+k,mad));
        printf("%lld
",((ll)min((ll)deep[pu]-1,(ll)k))*((ll)size[pu]-1)+num);
    } 
    return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/9600350.html