【Comet OJ

大大大数据结构题. 

我们发现,如果 $k=2$,答案就是树的直径.   

而 $k>2$ 时,相当于选择 $k$ 个叶子,使得这些叶子的并最大.   

我们有一个显然的贪心:$k+1$ 的答案一定是在 $k$ 的答案上加一个叶子.           

如果不考虑修改,这其实就是长链剖分.  

即 $k$ 时的答案就是大小为前 $k$ 大的链之和.  

而由于有修改,我们就需要通过 $LCT$ 中的 Access 操作来动态维护这个长链剖分的状态.  

然后要讨论一下什么时候需要换根.   

code: 

#include <cstdio> 
#include <cstring>  
#include <string>  
#include <algorithm>

#define ll long long   
#define N 200007  
#define INF 1e14  

using namespace std;

namespace IO 
{   
    inline void setIO(string s) 
    {
        string in=s+".in"; 
        string out=s+".out"; 
        freopen(in.c_str(),"r",stdin);  
        // freopen(out.c_str(),"w",stdout);  
    }
};  

int RT;   
namespace seg
{          
    #define lson s[x].ls 
    #define rson s[x].rs 
    struct data 
    {    
        int ls,rs,sum;  
        ll sum2;  
    }s[N*30];             
    int tot;   
    inline int newnode() { return ++tot; }  

    void update(int &x,ll l,ll r,ll p,int v) 
    {
        if(!x) 
            x=newnode();  
        s[x].sum2+=p*v;   
        s[x].sum+=v;   
        if(l==r)  
            return;   
        ll mid=(l+r)>>1;  
        if(p<=mid)   
            update(lson,l,mid,p,v);  
        else 
            update(rson,mid+1,r,p,v);   
    }          

    ll get_kth(int x,ll l,ll r,int k) 
    {      
    	if(!x||!s[x].sum2) 
    		return 0; 
        if(l==r)     
            return min(s[x].sum2,(ll)l*k);      
        ll mid=(l+r)>>1;     
        if(s[rson].sum<k)        
            return s[rson].sum2+get_kth(lson,l,mid,k-s[rson].sum);      
        else   
            return get_kth(rson,mid+1,r,k);    
    }

    #undef lson 
    #undef rson 
};   

int rt; 
namespace LCT 
{             
    #define lson s[x].ch[0] 
    #define rson s[x].ch[1] 
    struct data 
    {      
        int ch[2],f,rev,L,R;      
        ll val,sum;    
    }s[N];   
    int sta[N];          
    inline int get(int x) { return s[s[x].f].ch[1]==x; }  
    inline int isr(int x) { return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x; }          

    inline void mark(int x) 
    {
        swap(s[x].L,s[x].R);  
        swap(lson,rson); 
        s[x].rev^=1;      
    }

    inline void pushdown(int x) 
    {
        if(s[x].rev) 
        {
            if(lson)   
                mark(lson);  
            if(rson)  
                mark(rson);   
            s[x].rev^=1;  
        }
    }

    inline void pushup(int x) 
    {     
        s[x].sum=s[lson].sum+s[rson].sum+s[x].val;     
        s[x].L=s[x].R=x;  
        if(lson)  
            s[x].L=s[lson].L;   
        if(rson) 
            s[x].R=s[rson].R;   
    } 

    inline void rotate(int x) 
    {
        int old=s[x].f,fold=s[old].f,which=get(x);  
        if(!isr(old)) 
            s[fold].ch[s[fold].ch[1]==old]=x;   
        s[old].ch[which]=s[x].ch[which^1];     
        if(s[old].ch[which])  
            s[s[old].ch[which]].f=old;     
        s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;  
        pushup(old),pushup(x);  
    }

    inline void splay(int x) 
    { 
        int u=x,v=0,fa;  
        for(sta[++v]=u;!isr(u);u=s[u].f)   
            sta[++v]=s[u].f;     
        for(;v;--v)  
            pushdown(sta[v]);   
        for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))   
            if(s[fa].f!=u)    
                rotate(get(fa)==get(x)?fa:x);   
    }      

    inline void Access(int x,int y) 
    {
        for(;x;y=x,x=s[x].f)  
        {
            splay(x);      
            if(s[x].L==rt) 
            {    
            	if(s[rson].sum>s[lson].sum)             
            	{      
            		rt=s[x].R;         
            		mark(x),pushdown(x);  
            	}   
            }  
            if(s[y].sum>s[rson].sum) 
            {     
                seg::update(RT,0,INF,s[x].sum,-1);     
                seg::update(RT,0,INF,s[rson].sum,1);   
                seg::update(RT,0,INF,s[y].sum,-1);           
                rson=y,pushup(x);    
                seg::update(RT,0,INF,s[x].sum,1);     
            }
            else   
                break;  
        }
    }         

    #undef lson 
    #undef rson 
};    
 
int edges,an,n,cnt;    
int hd[N],to[N<<1],nex[N<<1],id1[N],id2[N],val[N];  
ll d1[N],d2[N];     

inline void add(int u,int v) 
{    
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;   
}  

void dfs1(int x,int ff) 
{     
    id1[x]=id2[x]=x;       
    d1[x]=val[x],d2[x]=0;        
    for(int i=hd[x];i;i=nex[i]) 
    {  
        int y=to[i];   
        if(y==ff) 
            continue;   
        dfs1(y,x);     
        if(d1[y]+val[x]>d1[x])   
        {    
            d2[x]=d1[x],id2[x]=id1[x];   
            d1[x]=d1[y]+val[x],id1[x]=id1[y];         
        }
        else if(d1[y]+val[x]>d2[x])    
        {
            d2[x]=d1[y],id2[x]=id1[y];     
        }
    }     
    if(d1[x]+d2[x]>d1[an]+d2[an])     
        an=x;               
}   

void dfs2(int x,int ff) 
{      
    ll maxv=0;  
    int be=0;   
    for(int i=hd[x];i;i=nex[i]) 
    {
        int y=to[i];   
        if(y==ff)  
            continue;   
        dfs2(y,x);     
        if(LCT::s[y].sum>maxv) 
            maxv=LCT::s[y].sum,be=y;       
    }             

    // 叶子    
    if(!be)   
        ++cnt;   

    LCT::s[x].f=ff;    
    LCT::s[x].val=val[x];     

    for(int i=hd[x];i;i=nex[i]) 
    {
        int y=to[i]; 
        if(y==ff)  
            continue;    
        if(y==be) 
            LCT::s[x].ch[1]=be;          
        else     
            seg::update(RT,0,INF,LCT::s[y].sum,1);    
    }      

    LCT::pushup(x);   
}   

int main() 
{ 
    // IO::setIO("input");   
    int i,j;           
    scanf("%d",&n);   
    for(i=1;i<n;++i) 
    {
        int x,y;  
        scanf("%d%d",&x,&y);      
        add(x,y),add(y,x);  
    }
    ll maxx=0; 
    for(i=1;i<=n;++i)   
        scanf("%d",&val[i]),maxx=max(maxx,(ll)val[i]);  
    dfs1(1,0);              
    dfs2(id1[an],0);   
    rt=id1[an];            
    seg::update(RT,0,INF,LCT::s[rt].sum,1);                         
    int Q; 
    scanf("%d",&Q);        
    for(i=1;i<=Q;++i) 
    {
        int op,x,y,k; 
        scanf("%d",&op);   
        if(op==0) 
        {     
            scanf("%d%d",&x,&y);             
    
            LCT::splay(x);
            seg::update(RT,0,INF,LCT::s[x].sum,-1);    
            seg::update(RT,0,INF,LCT::s[x].sum+(ll)y,1);                        
            LCT::s[x].val+=y;   
            maxx=max(maxx,LCT::s[x].val);    
            LCT::pushup(x);             // 更新当前重链               

            int fa=LCT::s[x].f;              
            LCT::Access(fa,x);            // 更新树  
        }  
        if(op==1) 
        {                      
            int k;      
            scanf("%d",&k); 
            if(k==1)   
                printf("%lld
",maxx);  
            else   
                printf("%lld
",seg::get_kth(RT,0,INF,k-1));        
        }
    }
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/12447566.html