hdu4918 Query on the subtree

  树分治,设当前树的分治中心为x,其子树分治中心为y,则设father[y]=x,分治下去则可以得到一颗重心树,而且树的深度是logn。

  询问操作(x,d),只需要查询重心树上x到重心树根节点上的节点的累加和。假设当前节点是y,那么节点y可以贡献的答案是那些以y为分治中心且到y距离为d-dis(x,y)的节点的总和。当然这样可能会出现重复的情况,重复情况只会出现在包含x的那颗子树上,因此减掉即可。修改操作类似。复杂度O(nlognlogn)

代码

#include<cstdio>
#include<cstring>
#define N 200010
#define LL long long
using namespace std;
int dp,pre[N],p[N],tt[N],vis[N],father[N],s[N],tmp,m;
int n,a,b,i,w[N],L,cnt,tot,len[N],Len[N],start[N],Start[N],v[N];
int deep[N],ss[N][21],fa[N];
int c[N*50];
int min(int a,int b)
{
    if (a<b) return a;return b;
}
int lowbit(int x)
{
    return x&(-x);
}
void cc(int x,int w,int y)
{
    while (x<=L)
    {
        c[y+x]+=w;
        x+=lowbit(x);
    }
}
LL sum(int x,int y)
{
    LL ans=0;
    while (x>0)
    {
        ans+=c[y+x];
        x-=lowbit(x);
    }
    return ans;
}
void link(int x,int y)
{
    dp++;pre[dp]=p[x];p[x]=dp;tt[dp]=y;
}
void gao(int x)
{
    int i;
    i=p[x];
    while (i)
    {
        if (tt[i]!=fa[x])
        {
            fa[tt[i]]=x;
            deep[tt[i]]=deep[x]+1;
            gao(tt[i]);
        }
        i=pre[i];
    }
}
int lca(int x,int y)
{
    if(deep[x]>deep[y])x^=y^=x^=y;
    int i;
    for(i=19;i>=0;i--)
    {
        if(deep[y]-deep[x]>=(1<<i))
        {
            y=ss[y][i];
        }
    }
    if(x==y)return x;
    for(i=19;i>=0;i--)
    {
        if(ss[x][i]!=ss[y][i])
        {
            x=ss[x][i];
            y=ss[y][i];
        }
    }
    return fa[x];
}
void getroot(int x,int fa,int sum)
{
    int i,flag=0;
    i=p[x];s[x]=1;
    while (i)
    {
        if ((!vis[tt[i]])&&(tt[i]!=fa))
        {
            getroot(tt[i],x,sum);
            s[x]+=s[tt[i]];
            if (s[tt[i]]>sum/2) flag=1;
        }
        i=pre[i];
    }
    if (sum-s[x]>sum/2) flag=1;
    if (!flag) tmp=x;
}
void dfs(int x,int fa,int dis)
{
    int i;
    i=p[x];
    if (dis>cnt) cnt=dis;
    v[dis]+=w[x];
    while (i)
    {
        if ((!vis[tt[i]])&&(tt[i]!=fa))
            dfs(tt[i],x,dis+1);
        i=pre[i];    
    }
}
void clear()
{
    int i;
    for (i=1;i<=cnt;i++)
    v[i]=0;cnt=0;
}
int work(int x,int fa,int sum)
{
    int i,root,t;
    getroot(x,0,sum);
    root=tmp;
    father[root]=fa;
    i=p[root];
    vis[root]=1;
    while (i)
    {
        if (!vis[tt[i]])
        {
            if (s[root]>s[tt[i]])
            t=work(tt[i],root,s[tt[i]]);
            else
            t=work(tt[i],root,sum-s[root]);
            //------dist(root,point in subtree t)-------- 
            
            dfs(tt[i],0,2);
            Len[t]=cnt;
            Start[t]=tot;
            for (int j=1;j<=cnt;j++)
            {
                L=cnt;
                cc(j,v[j],Start[t]);
            }
            tot+=cnt;
            clear();
            
        }
        i=pre[i];
    }
    vis[root]=0;
    
//--------dist(root,all point)----------
    
    dfs(root,0,1);
    len[root]=cnt;
    start[root]=tot;
    for (i=1;i<=cnt;i++)
    {
        L=cnt;
        cc(i,v[i],start[root]);
    }
    tot+=cnt;
    clear();
    
    return root;
}
LL query(int x,int d)
{
    int y=0,z=x,t;
    LL ans=0;
    while (x)
    {
        t=lca(x,z);
        t=deep[x]+deep[z]-2*deep[t];
        L=len[x];
        ans+=sum(min(L,d-t+1),start[x]);
        
        if (y)
        {
            L=Len[y];
            ans-=sum(min(L,d-t+1),Start[y]);
        }
        y=x;
        x=father[x];
    }
    return ans;
}
void change(int x,int w)
{
    int y=0,z=x,t;
    while (x)
    {
        t=lca(x,z);
        t=deep[x]+deep[z]-2*deep[t];
        L=len[x];
        cc(t+1,w,start[x]);
        
        if (y)
        {
            L=Len[y];
            cc(t+1,w,Start[y]);
        }    
        y=x;
        x=father[x];
    }
}
int main()
{
    while (scanf("%d%d",&n,&m)!=EOF)
    {
    dp=0;memset(p,0,sizeof(p));
    for (i=1;i<=tot;i++)
    c[i]=0;tot=0;
    
    for (i=1;i<=n;i++)
        scanf("%d",&w[i]);
    for (i=1;i<n;i++)
    {
        scanf("%d%d",&a,&b);
        link(a,b);
        link(b,a);
    }
    gao(1);
    for(i=1;i<=n;i++)
        ss[i][0]=fa[i];
    for(int h=1;h<20;h++)
    {
        for(i=1;i<=n;i++)
        {
            ss[i][h]=ss[ss[i][h-1]][h-1];
        }
    }
    work(1,0,n);
    
    for (i=1;i<=m;i++)
    {
        getchar();
        char ch;
        scanf("%c%d%d",&ch,&a,&b);    
        if (ch=='?')
        printf("%I64d
",query(a,b));
        else
        {
            change(a,b-w[a]);
            w[a]=b;
        }    
    }

    }
}

  

原文地址:https://www.cnblogs.com/fzmh/p/4703571.html