P3899 [湖南集训]谈笑风生

P3899 [湖南集训]谈笑风生

题目大意

n个节点的树,q次查询,每次查询给出a,k求三元组的数量(a,b,c),(a,b,c)的定义为:a、b均为c的祖先且距离<=k

离线,启发式合并线段树,长链剖分当然都能过这题

这里讲讲主席树的做法

dfs序建树

a为b的祖先时 查询a子树内深度<=dep[a]+k的节点的子树和

b为a祖先时 乘法原理就好

My complete code: 

#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const LL maxn=3e5+5;
const LL inf=1e18;
inline LL read(){
    LL x=0,f=1; char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-') f=-1; c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+c-'0'; c=getchar();
    }
    return x*f;
}
struct node{
    LL to,next;
}dis[maxn<<1];
LL num,cnt,n,q,nod; 
LL head[maxn],dfn[maxn],low[maxn],root[maxn],a[maxn],b[maxn],dep[maxn],size[maxn];
LL rt[maxn<<7],lt[maxn<<7],sum[maxn<<7],date[maxn<<7];
inline void add(LL u,LL v){
    dis[++num]=(node){v,head[u]}; head[u]=num;
}
void dfs(LL u,LL fa){
    dfn[u]=++num;
    a[num]=u;
    b[num]=dep[u];
    size[u]=1;
    for(LL i=head[u];i;i=dis[i].next){
        LL v=dis[i].to;
        if(v==fa)
            continue;
        dep[v]=dep[u]+1;
        dfs(v,u);
        size[u]+=size[v];
    }
    low[u]=num;
}
void update(LL &now,LL pre,LL l,LL r,LL c,LL u){
    now=++nod;
    date[now]=date[pre]+1;
    LL mid=(l+r)>>1;
    if(l==r){
        sum[now]=sum[pre]+size[u];
        return;
    }
    if(c<=mid){
        rt[now]=rt[pre];
        update(lt[now],lt[pre],l,mid,c,u);
    }else{
        lt[now]=lt[pre];
        update(rt[now],rt[pre],mid+1,r,c,u);
    }
    sum[now]=sum[lt[now]]+sum[rt[now]];
}
LL query(LL pre,LL next,LL l,LL r,LL c){
    LL mid=(l+r)>>1;
    if(l==r)
        return sum[next]-sum[pre];
    if(c-1<=mid)
        return query(lt[pre],lt[next],l,mid,c);
    else 
        return query(rt[pre],rt[next],mid+1,r,c)+(sum[lt[next]]-sum[lt[pre]]);
}
int main(){
    n=read(); q=read();
    for(LL i=1;i<n;++i){
        LL u=read(),v=read();
        add(u,v); add(v,u);
    }
    num=0;
    dfs(1,0);
    cnt=n;
    b[++cnt]=inf;
    sort(b+1,b+1+cnt);
    cnt=unique(b+1,b+1+cnt)-b-1;
    for(LL i=1;i<=n;++i)
        --size[i];
    for(LL i=1;i<=n;++i){
        LL u=a[i];
        LL k=lower_bound(b+1,b+1+cnt,dep[u])-b;
        update(root[i],root[i-1],1,cnt,k,u);
    }
    while(q--){
        LL p=read(),k=read();
        LL l=dfn[p],r=low[p];
        LL now=upper_bound(b+1,b+1+cnt,dep[p]+k)-b;
        printf("%lld
",query(root[l],root[r],1,cnt,now)+(size[p])*min(dep[p],k));
    }
    return 0;
}

  

原文地址:https://www.cnblogs.com/y2823774827y/p/10090556.html