2014百度之星复赛解题报告:Query on the tree

Query on the tree

时间限制:1s  内存限制: 65536K

问题描述
度度熊最近沉迷在和树有关的游戏了,他一直认为树是最神奇的数据结构。一天他遇到这样一个问题:
有一棵树,树的每个点有点权,每次有三种操作:
1. Query x 表示查询以x为根的子树的权值和。
2. Change x y 表示把x点的权值改为y。
3. Root x 表示把x变为根。
现在度度熊想请更聪明的你帮助解决这个问题。

输入
第一行为数据组数T                              
每组数据第一行为N ,表示树的节点数。
后面 行每行有两个数 ,表示 之间有一条边 。初始时树是以1号节点为根节点。
之后的一行为 个数表示这 个点的点权。
然后为整数Q为操作次数。
之后的Q行为描述中的三种操作。

输出
对于第k组输入数据,第一行输出Case #k接下来对于每个”Queryx”操作,输出以x为根的子数和。

样例输入
2
5
1 2
1 3
3 4
3 5
1 2 3 4 5
5
Query 1
Change 3 10
Query 1
Root 4
Query 3
8
1 2
1 3
3 4
4 5
5 6
5 7
4 8
1 2 3 4 5 6 7 8
5
Query 1
Query 3
Root 5
Query 3
Query 1
样例输出
Case #1:
15
22
18
Case #2:
36
33
6
3


解题报告:
树上的查询一共有三种操作,如果只是考虑前两种操作QueryChange,则需要一种高效的数据结构来支持子树和的查询和更新。
         采取类似于LCA在线算法的方式,先DFS得到树上顶点序列,通过顶点序列可以知道每颗子树的范围。利用树状数组的方式来记录每个顶点的权值的变更。那么QueryChange的时间复杂度都为O(lgn)
         当引入第三种操作-变更根节点后,并不需要调整树的结构,而只是要在查询操作的时候做些处理。
         如果Query xx为当前root节点,则直接输出当前所有节点的权值和,可以用一个变量SumOfAllTree来记录整颗树所有节点的权值和,查询为O(1)的复杂度
         如果x在原树上不为当前root节点的祖先,即lca(x,root) != x,那么直接输出x节点所在的子树和SubTree(x)
         否则可找出rootx这条路径上x的儿子节点y,那么在当前root的条件下x对应的子树的和为SumOfAllTree-SubTree(y)
         lca和求y的时间复杂度都可以做到O(lgn)。因此加上预处理后整体的时间复杂度为O(nlgn + Qlgn)

解题代码:

#include "iostream"
#include "cstring"
#include "cstdio"
#include "vector"
#define F first
#define S second
#define PB push_back
#define MP make_pair
using namespace std;
const int N  = 10010;
const int D = 20;
vector<int>e[N];
int go[N][D],depth[N],l[N],r[N];
int time_stamp;
void dfs(int u,int p)
{
        depth[u]=p==-1?0:depth[p]+1;
        go[u][0]=p;
        l[u]=++time_stamp;
        for(int i=0;go[u][i]!=-1;i++){
                go[u][i+1]=go[go[u][i]][i];
        }
        for(int i=0;i<e[u].size();i++){
                int v=e[u][i];
                if(v!=p){
                        dfs(v,u);
                }
        }
        r[u]=time_stamp;
}
int jump(int u,int d)
{
        for(int i=D-1;i>=0;i--){
                if(d>=(1<<i)){
                        u=go[u][i];
                        d-=1<<i;
                }
        }
        return u;
}
int lca(int u,int v)
{
        if(depth[u]<depth[v]){
                swap(u,v);
        }
        u=jump(u,depth[u]-depth[v]);
        for(int i=D-1;i>=0;i--){
                if(go[u][i]!=go[v][i]){
                        u=go[u][i];
                        v=go[v][i];
                }
        }
        return u==v?u:go[u][0];
}
int lowbit(int x)
{
        return x&(-x);
}
char com[20];
int n;
int val[N],sumroot[N],sumroad[N];
void init(int n)
{
        time_stamp=0;
        memset(sumroot,0,sizeof(sumroot));
        memset(sumroad,0,sizeof(sumroad));
        memset(go, -1, sizeof(go));
    for(int i=1;i<=n;i++){
                e[i].clear();
        }
}
void update(int a[],int x,int v)
{
        while(x<=n){
                a[x]+=v;
                x+=lowbit(x);
        }
}
int get(int a[],int x)
{
        int sum=0;
        while(x>0){
                sum+=a[x];
                x-=lowbit(x);
        }
        return sum;
}
int getsum(int a[],int l,int r)
{
        return get(a,r)-get(a,l-1);
}
int main()
{
    int KK = 1;
        int T,Q,x,y,root,allval;
        scanf("%d",&T);
        while(T--){
                scanf("%d",&n);
                init(n);
                root=1;
                allval=0;
                for(int i=1;i<n;i++){
                        scanf("%d%d",&x,&y);
                        e[x].PB(y);
                        e[y].PB(x);
                }
                dfs(1,-1);
                for(int i=1;i<=n;i++){
                        scanf("%d",&val[i]);
                        update(sumroad,l[i],val[i]);
                        update(sumroad,r[i]+1,-val[i]);
                        update(sumroot,l[i],val[i]);
                        allval+=val[i];
                }
        printf("Case #%d:
", KK++);
                scanf("%d",&Q);
                while(Q--){
                        scanf("%s",com);
                        if(com[0]=='Q'){
                                scanf("%d",&x);
                                if(x==root){
                                        printf("%d
",allval);
                                }else if(lca(x,root)!=x){
                                        printf("%d
",getsum(sumroot,l[x],r[x]));
                                }else{
                                        int tmp=jump(root,depth[root]-depth[x]-1);
                                        printf("%d
",allval-getsum(sumroot,l[tmp],r[tmp]));
                                }
                        }else if(com[0]=='C'){
                                scanf("%d%d",&x,&y);
                                update(sumroad,l[x],y-val[x]);
                                update(sumroad,r[x]+1,val[x]-y);
                                update(sumroot,l[x],y-val[x]);
                                allval+=y-val[x];
                                val[x]=y;
                        }else{
                                scanf("%d",&root);
                        }
                }
        }
        return 0;
}


原文地址:https://www.cnblogs.com/hosealeo/p/4190491.html