bzoj3730 震波 解题报告 (动态点分治)

bzoj3730 震波

由于一个小错误, 花了我一个上午的时间.....


题意

一棵 (n) 个点的树 ($ 1 le n le 10^5$), 每个点有一个点权 (w[i]) ((1 le w[i] le 10^4)).

(m) 个询问, 询问有两种,

  1. 将点 (x) 的权值改为 (y).
  2. 询问与点 (x) 距离不超过 (k) 的点的权值和.

思路

若不用修改, 且只有一个询问, 我们可以点分治, 对每个点 (u) 找出所有满足 (dist(u,v) le k - dist(u,x)), 且与 (x) 不在同一子树的点 (v), 计算它们的权值和, 找到答案.

假设依旧不用修改, 但有多个询问, 考虑一下有什么数据是可以反复使用的.

发现由于树的结构与权值都不变, 那么与任一点 (u) 相距 (t) 的点的权值和是不变的, ((u in [1,n], t in [0,n-1])).

所以, 我们可以对每个点 (u) 开一个 (vector), 设为 (v1), 存储(u) 的点分树子树中与它距离为 (t) 的点权和, 并对 (v1) 建立树状数组维护前缀和.


询问时, 在**点分树**上依次往上跳,

当前点 (不是询问点) 为 (u), 用倍增求 (lca) 的方法找出 (u) 在点分树上的父亲 (ft[u]) 与**询问点 (x) **的距离 (len) ,在 (v1[ft[u]]) 上查询 (k-len) 的前缀和.

但仔细思考一下, 发现这样会重复计算与 (x) 在同一子树内的点的权值,

所以, 我们需要再开一个 (vector), 设为 (v2),

(v2[u][t]) 表示在 (u)在点分树的子孙中(ft[u]) 的距离为 (t) 的点的点权和, 类似地, 也对它建立树状数组维护前缀和.

那么, 我们只需要在询问时在点分树上逐层往上跳,

对于每一层的点 (u) , (res+= sum(v1[ft[u]],k-len) - sum(v2[u],k-len)) ((sum) 表示在树状数组上查询到前缀和).

现在, 考虑本题的实际情况 : 要修改, 多个询问.


其实考虑到了如何处理多个询问后, 修改也就不难了.


和询问类似, 修改时也是在点分树上逐层往上跳, 对每个当前点 (u) 算出 (len = dist(ft[u],x)),

然后 $ add(v1[ft[u]], len, y-val[x]), add(v2[u], len, y-val[x])$ ( (add) 表示树状数组的修改操作, (y) 表示要修改为的权值).


这道题就解完了.


还有一点要注意的是, 修改和查询时, 都不要忘记处理(x) 本身.


代码

#include<bits/stdc++.h>
#define uint unsigned int
#define pb push_back
#define sz size
using namespace std;
const int _=1e6+7;
const int __=1e6+7;
const int L=20;
const int inf=0x3f3f3f3f;
bool be;
int n,m,val[_],dis[_],dep[_],f[_][L+7],sz[_],rt,minx=inf,dpt,q[_],top,ft[_];
int lst[_],nxt[__],to[__],tot;
bool vis[_];
vector<int> v1[_],v2[_];   // v1: to self  v2: to father
bool en;
void add(int x,int y){ nxt[++tot]=lst[x]; to[tot]=y; lst[x]=tot; }
void dbing(int u,int fa){
  dep[u]=dep[fa]+1;
  f[u][0]=fa;
  for(int i=1;i<=L;i++)
    f[u][i]=f[f[u][i-1]][i-1];
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa) continue;
    dbing(v,u);
  }
}
void add(int u,int x,int v,int id){
  if(id==1) for(int i=x;i<(int)v1[u].sz();i+=i&(-i)) v1[u][i]+=v;
  else for(int i=x;i<(int)v2[u].sz();i+=i&(-i)) v2[u][i]+=v;
}
int summ(int u,int x,int id){
  int res=0;
  if(id==1) for(int i=min(x,(int)v1[u].sz()-1);i>0;i-=i&(-i)) res+=v1[u][i];
  else for(int i=min(x,(int)v2[u].sz()-1);i>0;i-=i&(-i)) res+=v2[u][i];
  return res;
}
void g_rt(int u,int fa,int sum){
  int maxn=0; sz[u]=1; dis[u]=dis[fa]+1;
  dpt=max(dpt,dis[u]); q[++top]=u;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    g_rt(v,u,sum);
    sz[u]+=sz[v];
    maxn=max(maxn,sz[v]);
  }
  maxn=max(maxn,sum-sz[u]);
  if(maxn<minx){ minx=maxn; rt=u; }
}
void cnt(int u,int fa,int rt){
  dis[u]=dis[fa]+1; dpt=max(dpt,dis[u]); q[++top]=u;
  for(int i=lst[u];i;i=nxt[i])
    if(!vis[to[i]]&&to[i]!=fa) cnt(to[i],u,rt);
}
void calc(int u){
  dpt=top=0;
  for(int i=lst[u];i;i=nxt[i])
    if(!vis[to[i]]) cnt(to[i],0,u);
  v1[u].resize(dpt+2);
  for(int i=1;i<=top;i++) add(u,dis[q[i]]+1,val[q[i]],1);
  add(u,1,val[u],1);
}
void init(int u,int lrt,int sum){
  minx=inf; dpt=top=0;
  g_rt(u,0,sz[u]<sz[lrt] ?sz[u] :sum-sz[lrt]);
  sum=sz[u];
  v2[rt].resize(dpt+2);
  for(int i=1;i<=top;i++) add(rt,dis[q[i]]+1,val[q[i]],2);
  ft[rt]=lrt; vis[rt]=1; u=rt;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(vis[v]) continue;
    init(v,u,sum);
  }
  calc(u);
  vis[u]=0;
}
int Lca(int x,int y){
  if(dep[x]<dep[y]) swap(x,y);
  for(int i=L;i>=0;i--)
    if(dep[f[x][i]]>=dep[y])
      x=f[x][i];
  if(x==y) return x;
  for(int i=L;i>=0;i--)
    if(f[x][i]!=f[y][i]){
      x=f[x][i];
      y=f[y][i];
    }
  return f[x][0];
}
int dist(int x,int y){ return dep[x]+dep[y]-2*dep[Lca(x,y)]; }
int query(int x,int k){
  int len,fa=ft[x],u=x,res=summ(u,k+1,1);
  while(fa){
    len=k-dist(fa,x);
    res+=summ(fa,len+1,1)-summ(u,len+1,2);
    u=ft[u]; fa=ft[u];
  }
  return res;
}
void modify(int x,int v){
  int len,fa=ft[x],u=x;
  add(u,1,v-val[x],1);
  while(fa){
    len=dist(fa,x);
    add(fa,len+1,v-val[x],1);
    add(u,len+1,v-val[x],2);
    //printf("u: %d fa: %d len: %d %d 
",u,fa,len,v-val[x]);
    u=ft[u]; fa=ft[u];
  }
  val[x]=v;
}
void run(){
  int ty,x,y,lst=0;
  for(int i=1;i<=m;i++){
    scanf("%d%d%d",&ty,&x,&y);
    x^=lst; y^=lst;
    if(!ty){ lst=query(x,y); printf("%d
",lst); }
    else modify(x,y);
  }
}
int main(){
  //freopen("x.in","r",stdin);
  //freopen("new.out","w",stdout);
  cin>>n>>m;
  for(int i=1;i<=n;i++) scanf("%d",&val[i]);
  int x,y;
  for(int i=1;i<n;i++){
    scanf("%d%d",&x,&y);
    add(x,y);
    add(y,x);
  }
  dbing(1,0);
  init(1,0,n);
  run();
  return 0;
}
原文地址:https://www.cnblogs.com/BruceW/p/12123726.html