[ JLOI 2014 ] 松鼠的新家

(\)

(Description)


给出一棵(N​)个节点的树,按顺序依次访问(N​)个节点(A_1,...,A_N​),即在树上走每一条(A_i​)(A_{i+1}​)到最短路。

注意,除了第一次直接从起点出发,以后每一次起点都是上一次的终点,且在这两次移动中视为该点只访问一次。

现要求每次访问过的点权值都加一,求最后树上每一个节点的权值。

  • (Nin [1,3 imes 10^5])

(\)

(Solution)


  • 树上差分板子。对于每个(i),在(A_i)(A_{i+1})处打(+1)标记,在(Lca(A_i,A_{i+1}))处打(-1)标记,在(fatherig(Lca(A_i,A_{i+1})ig))处打(-1)标记即可(()其中(father(i))表示(i)的直接父亲,根节点返回值为(0)())

  • 这种做法的合理性在于,我们需要将路径上经过的所有点都打上一次标记,而统计子树和以更新当前节点权值时,这个影响并不会上传到(fatherig(Lca(A_i,A_{i+1})ig))处,所以需要在(fatherig(Lca(A_i,A_{i+1})ig))处抵消掉所有影响。而只在(fatherig(Lca(A_i,A_{i+1})ig))处抵消影响时不够的,因为在两点(Lca)处本次增加的子树和会被累计两次,所以需要在(Lca)处先(-1),再在(fatherig(Lca(A_i,A_{i+1})ig))(-1)

  • 各点点权可以一遍(DFS)求出子树和后更新。注意题目要求每次直接访问的节点只算一次,所以最后在这些节点处权值(-1),注意第一个访问到的点并不需要(-1),还要注意这个操作一定要在(DFS)求子树和之后做

(\)

(Code)


#include<cmath>
#include<queue>
#include<cstdio>
#include<cctype>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 300010
#define R register
#define gc getchar
using namespace std;
 
inline int rd(){
  int x=0; bool f=0; char c=gc();
  while(!isdigit(c)){if(c=='-')f=1;c=gc();}
  while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc();}
  return f?-x:x;
}
 
int n,t,tot,d[N],hd[N],num[N],cnt[N],f[N][20];
struct edge{int to,nxt;}e[N<<1];
inline void add(int u,int v){
  e[++tot].to=v; e[tot].nxt=hd[u]; hd[u]=tot;
}
 
queue<int> q;
inline void bfs(){
  q.push(1); d[1]=1;
  while(!q.empty()){
    int u=q.front(); q.pop();
    for(R int i=hd[u],v;i;i=e[i].nxt)
      if(!d[v=e[i].to]){
        d[v]=d[u]+1; f[v][0]=u;
        for(R int i=1;i<=t;++i) f[v][i]=f[f[v][i-1]][i-1];
        q.push(v);
      }
  }
}
 
inline int lca(int u,int v){
  if(d[u]>d[v]) u^=v^=u^=v;
  for(R int i=t;~i;--i) if(d[f[v][i]]>=d[u]) v=f[v][i];
  if(u==v) return u;
  for(R int i=t;~i;--i) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
  return f[u][0];
}
 
inline void dfs(int u,int fa){
    for(R int i=hd[u],v;i;i=e[i].nxt)
        if((v=e[i].to)!=fa){
            dfs(v,u); cnt[u]+=cnt[v];
        }
}
 
int main(){
  t=log2(n=rd());
  for(R int i=1;i<=n;++i) num[i]=rd();
  for(R int i=1,u,v;i<n;++i){
    u=rd(); v=rd(); add(u,v); add(v,u);
  }
  bfs();
  for(R int i=1;i<n;++i){
    int l=lca(num[i],num[i+1]);
    ++cnt[num[i]]; ++cnt[num[i+1]];
    --cnt[l]; --cnt[f[l][0]];
  }
  dfs(1,0);
  for(R int i=2;i<=n;++i) --cnt[num[i]];
  for(R int i=1;i<=n;++i) printf("%d
",cnt[i]);
  return 0;
}
原文地址:https://www.cnblogs.com/SGCollin/p/9671683.html