2018青岛网络预选赛 B."Red Black Tree"(LCA+二分答案)

传送门

•参考资料

  [1]:ACM-ICPC 2018 青岛赛区网络预赛 B. Red Black Tree (LCA、二分)

•题意 

  给出一棵树,根节点为1。

  每条边有一个权值,树上有红色结点 m 个,其花费为 0 ,其余为黑色;

  每个黑色结点的花费为其到最近红色祖先的经过的路径权值之和。

  有 q 次询问,每次给出一个点集;

  问将树上任意一个结点涂成红色结点后,点集中所有点的花费的最大值的最小是多少。

•题解

  相关变量解释:

    sum : 每次询问中询问的点集个数

    a[  ]  : 存储每次询问到的点集

    costR[i] : 结点 i 距其最近红色祖先的花费

  预处理每个点到根的距离cost、到最近红色祖先的距离 costR 和 ST 表。

  对于每次询问,将a[ ] 按 costR 从大到小排序,在 0~costR[a[0]] 范围内二分答案;

  对所有大于答案的点求它们的公共祖先(利用ST表可以O(1)求两点的公共祖先),将其涂红;

  之后计算每个大于答案的点的新花费是否小于答案。

•Code

  1 #include<iostream>
  2 #include<vector>
  3 #include<cstdio>
  4 #include<cmath>
  5 #include<algorithm>
  6 #include<cstring>
  7 using namespace std;
  8 #define pb push_back
  9 #define ll long long
 10 #define mem(a,b) (memset(a,b,sizeof a))
 11 const int maxn=1e5+50;
 12 
 13 int n,m,q;
 14 //===============Restore Graph============
 15 struct Node
 16 {
 17     int to;
 18     ll w;
 19     Node(int to,int w):to(to),w(w){}
 20 };
 21 vector<Node >G[maxn];
 22 void addEdge(int u,int v,int w)
 23 {
 24     G[u].pb(Node(v,w));
 25     G[v].pb(Node(u,w));
 26 }
 27 //=========================================
 28 int vs[2*maxn];//欧拉序列,范围区间为 [1,total]
 29 int depth[2*maxn];//欧拉序列对应的深度序列
 30 int pos[maxn];//pos[i] : 结点 i 再欧拉序列中第一次出现的位置
 31 ll cost[maxn];//cost[i] : 结点 i 距根据点的距离
 32 ll costR[maxn];//costR[i] : 结点 i 距最近红色祖先结点的距离,初始化为 -1
 33 int total;//欧拉序列的大小
 34 void dfs(int u,int f,int dep,ll dis)
 35 {
 36     vs[++total]=u;
 37     depth[total]=dep;
 38     pos[u]=total;
 39     cost[u]=dis;
 40     for(int i=0;i < G[u].size();++i)
 41     {
 42         Node e=G[u][i];
 43         if (e.to == f)
 44             continue;
 45         costR[e.to]=(costR[e.to] == 0 ? 0:costR[u]+e.w);
 46         dfs(e.to,u,dep+1,dis+e.w);
 47         vs[++total]=u;
 48         depth[total]=dep;
 49     }
 50 }
 51 //==================RMQ======================
 52 struct Node2
 53 {
 54     int mm[2 * maxn];
 55     int dp[2 * maxn][20];
 56     void ST()
 57     {
 58         int n=total;
 59         mm[0] = -1;
 60         for (int i = 1; i <= n; i++)
 61         {
 62             mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1];
 63             dp[i][0]=i;
 64         }
 65         for (int j=1;j <= mm[n];j++)
 66             for (int i=1;i+(1<<j)-1 <= n;i++)
 67                 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]])
 68                     dp[i][j]=dp[i][j-1];
 69                 else
 70                     dp[i][j]=dp[i+(1<<(j-1))][j-1];
 71     }
 72     int Lca(int u, int v)
 73     {
 74         u=pos[u],v=pos[v];
 75         if (u > v)
 76             swap(u, v);
 77         int k = mm[v-u+1];
 78         if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]])
 79             return vs[dp[u][k]];
 80         return vs[dp[v-(1<<k)+1][k]];
 81     }
 82 }_rmq;
 83 //==========================================
 84 int a[maxn];
 85 int sum;
 86 bool cmp(int a, int b)
 87 {
 88     return costR[a] > costR[b];
 89 }
 90 bool Check(ll x)
 91 {
 92     if(costR[a[0]] <= x)
 93         return true;
 94     int lca=a[0];
 95     for(int i=1;i < sum;i++)
 96     {
 97         if(costR[a[i]] <= x)
 98             break;
 99         lca=_rmq.Lca(lca,a[i]);
100     }
101     for(int i = 0;i < sum;i++)
102     {
103         if(costR[a[i]] <= x)
104             return true;
105         if(cost[a[i]]-cost[lca] > x)
106             return false;
107     }
108     return true;
109 }
110 void Solve()
111 {
112     dfs(1,-1,0,0);
113     _rmq.ST();
114     while(q--)
115     {
116         scanf("%d",&sum);
117         for (int i=0;i < sum; i++)
118             scanf("%d",&a[i]);
119         sort(a,a+sum,cmp);
120         ll l=0,r=costR[a[0]];
121         while(l < r)
122         {
123             ll mid=(l+r)/2;
124             if(Check(mid))
125                 r=mid;
126             else
127                 l=mid + 1;
128         }
129         printf("%lld
",l);
130     }
131 }
132 void init()
133 {
134     mem(costR,-1);
135     total=0;
136     for(int i=0;i < maxn;++i)
137         G[i].clear();
138 }
139 int main()
140 {
141     int t;
142     scanf("%d", &t);
143     while(t--)
144     {
145         init();
146         scanf("%d%d%d",&n,&m,&q);
147         while(m--)
148         {
149             int red;
150             scanf("%d",&red);
151             costR[red]=0;
152         }
153         costR[1]=0;
154         for(int i=1;i<n;i++)
155         {
156             int u,v,w;
157             scanf("%d%d%d",&u,&v,&w);
158             addEdge(u,v,w);
159         }
160         Solve();
161     }
162     return 0;
163 }
View Code

•出现的问题

  1、用 vector 存储图比用 链式前向星存储图要慢

    (1)vector : 

    (2)链式前向星:

  2、平常一直在用的RMQ会超时

 1 //=====================RMQ===================
 2 struct Node1
 3 {
 4     int dp[20][2*maxn];
 5     void Preset()
 6     {
 7         for(int i=0;i < 2*maxn;++i)
 8             dp[0][i]=i;
 9     }
10     void ST()
11     {
12         int k=log(total)/log(2);
13         for(int i=1;i <= k;++i)
14             for(int j=1;j <= (total-(1<<i)+1);++j)
15                 if(depth[dp[i-1][j]] > depth[dp[i-1][j+(1<<(i-1))]])
16                     dp[i][j]=dp[i-1][j+(1<<(i-1))];
17                 else
18                     dp[i][j]=dp[i-1][j];
19     }
20     int Lca(int u,int v)
21     {
22         u=pos[u],v=pos[v];
23         if(u > v)
24             swap(u,v);
25         int k=log(v-u+1)/log(2);
26         if(depth[dp[k][u]] > depth[dp[k][v-(1<<k)+1]])
27             return vs[dp[k][v-(1<<k)+1]];
28         return vs[dp[k][u]];
29     }
30 }_rmq;
31 //===========================================
TLE
 1 //==================RMQ======================
 2 struct Node2
 3 {
 4     int mm[2 * maxn];
 5     int dp[2 * maxn][20];
 6     void ST()
 7     {
 8         int n=total;
 9         mm[0] = -1;
10         for (int i = 1; i <= n; i++)
11         {
12             mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1];
13             dp[i][0]=i;
14         }
15         for (int j=1;j <= mm[n];j++)
16             for (int i=1;i+(1<<j)-1 <= n;i++)
17                 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]])
18                     dp[i][j]=dp[i][j-1];
19                 else
20                     dp[i][j]=dp[i+(1<<(j-1))][j-1];
21     }
22     int Lca(int u, int v)
23     {
24         u=pos[u],v=pos[v];
25         if (u > v)
26             swap(u, v);
27         int k = mm[v-u+1];
28         if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]])
29             return vs[dp[u][k]];
30         return vs[dp[v-(1<<k)+1][k]];
31     }
32 }_rmq;
33 //==========================================
AC

  3、cost[ ] 很有用,如果 Check( ) 中不加    

      if(cost[a[i]]-cost[lca] > x)
        return false;

    会返回 WA,具体为什么,明天再好好想想%%%%%%%%%

 


分割线:2019.5.8

  中石油的这场重现赛又让我回想起了这道题留下的疑惑;

  现在再想想这道题,思路清晰了些许;

  一些不理解的地方瞬间顿悟了;

  ST表处理RMQ中,会多次求解 log2(x),这种算式是比较耗时的,我们预处理出所需的log2(x);

logTwo[i]=log2(i);

  如何预处理呢?

  首先想一下,三位数的二进制数的最大值为 111(2),四位数的二进制数的最小值为 1000(2)

  两者的关系是 (111)&(1000) = 0 , 而对于任意三位二进制数 x,y ,(x&y) != 0;

  有了这个关系后,就可以这么预处理了:

logTwo[0]=-1;
for(int i=1;i <= n;++i)
    logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1];

  这就是之前一直不理解的ST表加速的地方;

•Code

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define ll long long
  4 #define mem(a,b) memset(a,b,sizeof(a))
  5 #define INFll 0x3f3f3f3f3f3f3f3f
  6 const int maxn=1e5+50;
  7 
  8 int n,m,q;
  9 ll C[maxn];///C[i]:节点i到根节点1的花费
 10 ll CR[maxn];///CR[i]:节点i到其最近的红色祖先节点的花费
 11 int num;
 12 int head[maxn];
 13 struct Edge
 14 {
 15     int to;
 16     ll w;
 17     int next;
 18 }G[maxn<<1];
 19 void addEdge(int u,int v,ll w)
 20 {
 21     G[num]={v,w,head[u]};
 22     head[u]=num++;
 23 }
 24 struct LCA
 25 {
 26     int vs[maxn<<1];///欧拉序列
 27     int dep[maxn<<1];///欧拉序列中的节点对应的深度序列
 28     int pos[maxn<<1];///pos[i]:节点i在欧拉序列中第一次出现的位置
 29     int cnt;
 30     int logTwo[maxn<<1];///logTwo[i]:log2(i)
 31     int dp[maxn<<1][20];///dp[i][j]:[i,i+2^j-1]深度最小的点的下标(欧拉序列中的下标)
 32     void DFS(int u,int f,int depth,ll dist)
 33     {
 34         vs[++cnt]=u;
 35         dep[cnt]=depth;
 36         pos[u]=cnt;
 37         C[u]=dist;
 38         for(int i=head[u];~i;i=G[i].next)
 39         {
 40             int v=G[i].to;
 41             ll w=G[i].w;
 42             if(v == f)
 43                 continue;
 44             CR[v]=min(CR[v],CR[u]+w);
 45             DFS(v,u,depth+1,dist+w);
 46             vs[++cnt]=u;
 47             dep[cnt]=depth;
 48         }
 49     }
 50     void ST()
 51     {
 52         logTwo[0]=-1;
 53         for(int i=1;i <= cnt;++i)
 54         {
 55             dp[i][0]=i;
 56             ///:后的语句写错了,刚开始写成了logTwo[i],debug了好一会
 57             logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1];
 58         }
 59         for(int k=1;k <= logTwo[cnt];++k)
 60             for(int i=1;i+(1<<k)-1 <= cnt;++i)
 61                 if(dep[dp[i][k-1]] > dep[dp[i+(1<<(k-1))][k-1]])
 62                     dp[i][k]=dp[i+(1<<(k-1))][k-1];
 63                 else
 64                     dp[i][k]=dp[i][k-1];
 65     }
 66     void lcaInit(int root)
 67     {
 68         cnt=0;
 69         DFS(root,root,0,0);
 70         ST();
 71     }
 72     int lca(int u,int v)///返回节点u,v的LCA
 73     {
 74         u=pos[u];
 75         v=pos[v];
 76 
 77         if(u > v)
 78             swap(u,v);
 79 
 80         int k=logTwo[v-u+1];
 81         if(dep[dp[u][k]] > dep[dp[v-(1<<k)+1][k]])
 82             return vs[dp[v-(1<<k)+1][k]];
 83         else
 84             return vs[dp[u][k]];
 85     }
 86 }_lca;
 87 
 88 int qCnt;
 89 int query[maxn<<1];
 90 
 91 bool Check(ll mid)
 92 {
 93     int lca=0;///不满足条件的点的LCA
 94     for(int i=1;i <= qCnt;++i)
 95     {
 96         if(CR[query[i]] <= mid)
 97             continue;
 98         if(lca == 0)
 99             lca=query[i];
100         else/// > mid的点LCA
101             lca=_lca.lca(lca,query[i]);
102     }
103 
104     for(int i=1;i <= qCnt;++i)
105     {
106         if(CR[query[i]] <= mid)
107             continue;
108 
109         ///如果将lca点涂红后还不能使其 <= mid,返回false
110         if(C[query[i]]-C[lca] > mid)
111             return false;
112     }
113     return true;
114 }
115 void Solve()
116 {
117     _lca.lcaInit(1);
118 
119     for(int i=1;i <= q;++i)
120     {
121         scanf("%d",&qCnt);
122 
123         ll l=-1,r=0;
124         for(int j=1;j <= qCnt;++j)
125         {
126             scanf("%d",query+j);
127             r=max(r,CR[query[j]]);
128         }
129 
130         while(r-l > 1)
131         {
132             ll mid=l+((r-l)>>1);
133             if(Check(mid))
134                 r=mid;
135             else
136                 l=mid;
137         }
138         printf("%lld
",r);
139     }
140 }
141 void Init()
142 {
143     num=0;
144     mem(head,-1);
145     mem(CR,INFll);///初始化为最大值
146 }
147 int main()
148 {
149 //    freopen("C:\Users\hyacinthLJP\Desktop\in&&out\contest","r",stdin);
150     int test;
151     scanf("%d",&test);
152     while(test--)
153     {
154         Init();
155         scanf("%d%d%d",&n,&m,&q);
156         for(int i=1;i <= m;++i)
157         {
158             int red;
159             scanf("%d",&red);
160             CR[red]=0;
161         }
162         CR[1]=0;
163         for(int i=1;i < n;++i)
164         {
165             int u,v,w;
166             scanf("%d%d%d",&u,&v,&w);
167             addEdge(u,v,w);
168             addEdge(v,u,w);
169         }
170         Solve();
171     }
172     return 0;
173 }
View Code
原文地址:https://www.cnblogs.com/violet-acmer/p/9677889.html