最近公共祖先(least common ancestors,LCA)

摘要:

  本文主要介绍了解决LCA(最近公共祖先问题)的两种算法,分别是离线Tarjan算法和在线算法,着重展示了在具体题目中的应用细节。


  最近公共祖先是指对于一棵有根树T的两个结点u和v,它们的LCA(T,u,v)表示一个结点x,满足x是u和v的公共祖先且x深度尽可能的大(也即最近)。

  求最近公共祖先有两种方法:一种是离线求解算法,也就是将询问全部存起来,处理完之后一次回答所有询问;另一种方法就是在线求解算法,对于每次询问,动态地回答。

离线算法Tarjan

  Tarjan算法就是利用深度优先搜索的框架,对于新搜索到的一个结点,首先创建由这个结点构成的集合,再对当前结点的每一个子树进行搜索,每搜索完一棵子树,则可以确定这棵子树之内的LCA问题都已经解决。其他的LCA问题肯定都在这个子树之外,这时把子树所形成集合与当前结点的集合合并,并将当前结点设为这个集合的祖先。

  之后继续搜索下一棵子树,直到当前结点的所有子树搜索完,这时把当前结点也设为已经被检查过的,同时可以处理有关当前结点的LCA询问,如果有一个从当前结点到v的询问,且v已经被检查过,则由于进行的是深度优先搜索,当前结点和v的LCA一定还没有被检查过,然而这个最近公共祖先的包含v的子树一定已经搜索过了,那么这个最近公共祖先一定是v所在集合的祖先。

算法实现如下:

 1 const int maxn = 10005;//结点数
 2 bool vis[maxn];
 3 int tree[maxn][maxn], ans[maxn][maxn], fa[maxn];
 4 //tree[u][0]表示结点u有几个孩子,分别是
 5 //ans[u][v]表示u和v的LCA
 6 //fa[i]表示i的祖先
 7 set<int> query[maxn];//保存关于u结点的询问
 8 
 9 int n;
10 void init() {
11     for(int i = 0; i <= n; i++) {
12         fa[i] = i;
13         vis[i] = 0;
14     }
15 }
16 
17 int Find(int x) {//查找一个集合的祖先
18     return fa[x] == x ? x : fa[x] = Find(fa[x]);
19 }
20 
21 void Union(int x, int y) {//合并两个集合,将x并入y中
22     int fx = Find(x);
23     int fy = Find(y);
24     fa[fx] = fy;
25 }
26 
27 void dfs(int u) {//dfs遍历树
28     vis[u] = 1;
29     for(int i = 1; i <= tree[u][0]; i++) {
30         int v = tree[u][i];
31         if(vis[v])  continue;
32         dfs(v);
33         Union(v, u);//当遍历完一棵子树的时候,就将子树和父亲合并
34     }
35     for(set<int>::iterator it = query[u].begin();
36     it != query[u].end(); it++) {//当处理完u结点的子树时,就可以回答部分有关u的询问
37         int v = *it;
38         if(vis[v]) {             //如果v被访问过,就可以回答这个询问
39             ans[u][v] = Find(v); //u and v 的LCA就是v所在集合的公共祖先
40             query[v].erase(query[u].find(u));//将v中的此询问删掉
41         }
42     }
43 }

  Tarjan算法可以解决LCA查询要求实现知道全部查询提问,如果LCA要求即问即答,就需要使用在线算法。

在线算法

  在线算法需要对该树进行预处理,生成三个序列:欧拉序列、深度序列、遍历结点第一次出现的时间序列,然后通过RMQ(区间最值查询)来O(1)地回答问题。

巧妙的是只用对树进行一次深度优先遍历,就可以得到这三个序列了。

  结点第一次出现的时间:就是深度优先遍历的过程中第一次遍历到这个结点的时间,该序列的长度是n,记为pos数组,即pos[u] = 3,表示u结点是第三个遍历到的。

  欧拉序列:按照深度优先遍历,依次经过的结点按照遍历顺序全部记录下来,包括回溯的过程,也就是一个点可能被记录多次。该序列的长度由深搜的过程决定,记为t数组。

  深度序列:该序列的长度和欧拉序列的长度一致,记录的是欧拉序列中对应结点的深度,记为dep。

  有了这三个序列,假设我们需要查询LCA(T,u,v),通过pos[u]和pos[v]可以知道u和v结点在t数组和dep数组中是第几个,利用深度优先遍历的过程,可以知道在dep[pos[u]]~dep[pos[v]]中深度最小的结点就是LCA(T,u,v)了。

算法实现如下:

 1 const int maxn = 1006;
 2 
 3 int tot, ans[maxn][maxn], link[maxn][maxn];//link存树的结构
 4 int dep[maxn * 4], pos[maxn], t[maxn * 4];
 5 int dp[maxn * 4][12];//存储区间最值
 6 bool v[maxn];
 7 
 8 void dfs(int u, int dfn) {
 9     if(!v[u]) {
10         v[u] = 1;
11         pos[u] = tot;
12     }
13     dep[tot] = dfn; //深度序列
14     t[tot++] = u;   //欧拉序列
15     for(int i = 1; i <= link[u][0]; i++) {
16         dfs(link[u][i], dfn + 1);
17         
18         dep[tot] = dfn;
19         t[tot++] = u;
20     }
21     return;
22 }
23 
24 void init() {
25     for(int j = 0; (1 << j) <= tot; j++) {
26         for(int i = 0; i + (1 << j) <= tot; i++) {
27             if(j == 0)
28                 dp[i][j] = i;
29             else {
30                 if(dep[dp[i][j - 1]] < dep[dp[i + (1 << (j - 1))][j - 1]])
31                     dp[i][j] = dp[i][j - 1];
32                 else
33                     dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
34             }
35         }
36     }
37 }
38 
39 int RMQ(int p1, int p2) {
40     int k = log2(p2 - p1 + 1);
41     if( (1<<k) < p2 - p1 + 1) k++;
42     if(dep[dp[p1][k]] < dep[dp[p2 - (1 << k) + 1][k]])
43         return t[dp[p1][k]];
44     else
45         return t[dp[p2 - (1 << k) + 1][k]];
46 }
47 
48 int lca(int v1, int v2) {
49     if(pos[v1] < pos[v2])
50         return RMQ(pos[v1], pos[v2]);
51     else
52         return RMQ(pos[v2], pos[v1]);
53 }

  下面看一道例题:HDU 2586 How far away ?

题意

  输出一棵有根数,问任意两个结点间的距离

解题思路

  首先问一次计算一次不是什么好的办法,我们可以将每个结点到根结点的距离预处理出来,然后找到两个结点的最近公共祖先,然后答案就是dis[u] + dis[v] - 2 * dis[lca(u, v)]。因为是即问即答,所以采用在线的方法。

  注意RMQ中k的计算方式有所不同,采用之前的方法计算会发生数组访问越界。

代码如下:

  1 #include <cstdio>
  2 #include <vector>
  3 #include <cmath>
  4 #include <cstring>
  5 
  6 using namespace std;
  7 
  8 const int maxn = 41010;
  9 struct E{
 10     int v, ne, d;
 11     E(){}
 12     E(int _v, int _n, int _d): v(_v), ne(_n), d(_d){}
 13 }e[maxn * 2];
 14 
 15 int t[maxn * 2], dep[maxn * 2], pos[maxn], dis[maxn];
 16 int head[maxn], esize;
 17 bool vis[maxn];
 18 int dp[maxn * 2][30];
 19 int n, m, tot;
 20 
 21 void init() {
 22     esize = tot = 0;
 23     memset(vis, 0, sizeof(vis));
 24     memset(dis, 0, sizeof(dis));
 25     memset(head, -1, sizeof(head));
 26 }
 27 
 28 void add(int u, int v, int d) {
 29     e[esize] = E(v, head[u], d);
 30     head[u] = esize++;
 31 }
 32 
 33 void dfs(int u, int de) {
 34     if(!vis[u]) {
 35         vis[u] = 1;
 36         pos[u] = tot;
 37     }
 38     dep[tot] = de;
 39     t[tot++] = u;
 40 
 41     for(int i = head[u]; i != -1; i = e[i].ne) {
 42         int v = e[i].v;
 43         int d = e[i].d;
 44         if(vis[v]) continue;
 45         dis[v] = dis[u] + d;
 46         dfs(v, de + 1);
 47 
 48         dep[tot] = de;
 49         t[tot++] = u;
 50     }
 51     return;
 52 }
 53 
 54 void cdep() {
 55     for(int j = 0; (1 << j) < tot; j++) {
 56         for(int i = 1; i + (1 << j) < tot; i++) {
 57             if(j == 0)
 58                 dp[i][j] = i;
 59             else {
 60                 if(dep[dp[i][j - 1]] < dep[dp[i + (1 << (j - 1))][j - 1]])
 61                     dp[i][j] = dp[i][j - 1];
 62                 else
 63                     dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
 64             }
 65         }
 66     }
 67 }
 68 
 69 int RMQ(int u, int v) {
 70     int k = 0;
 71     k = log2(v - u + 1);
 72     if((1 << k) < v - u + 1) k++;
 73     /*int len = v - u + 1, k = 0;
 74     k = log(len * 1.0)/log(2.0);*/
 75     if(dep[dp[u][k]] < dep[dp[v - (1 << k) + 1][k]])
 76         return t[dp[u][k]];
 77     else
 78         return t[dp[v - (1 << k) + 1][k]];
 79 }
 80 
 81 int lca(int u, int v) {
 82     if(pos[u] < pos[v])
 83         return RMQ(pos[u], pos[v]);
 84     else
 85         return RMQ(pos[v], pos[u]);
 86 }
 87 
 88 int main()
 89 {
 90     int T;
 91     scanf("%d", &T);
 92     while(T--) {
 93         scanf("%d%d", &n, &m);
 94         init();
 95         for(int i = 1; i < n; i++) {
 96             int u, v, w;
 97             scanf("%d%d%d", &u, &v, &w);
 98             add(u, v, w);
 99             add(v, u, w);
100         }
101         dfs(1, 0);
102         cdep();
103 
104         int u, v;
105         while(m--) {
106             scanf("%d%d", &u, &v);
107             printf("%d
", dis[u] + dis[v] - 2 * dis[lca(u, v)]);
108         }
109     }
110     return 0;
111 }
原文地址:https://www.cnblogs.com/wenzhixin/p/9751157.html