[虚树模板] 洛谷P2495 消耗战

[update]

补充:还有一种结构叫笛卡尔树,可以用单调栈在一次从左到右的扫描中建出来(下面的写法中父节点的 val 不小于子节点的 val),大致是这样建的:

// Builds a Cartesian tree over val[1..n] in one left-to-right pass, using
// stk[1..top] as a stack of indices whose values decrease from bottom to top.
// A stacked node is popped when its value is <= the incoming value, so a
// parent's value is never smaller than its children's.
// Writes ls[] (left child), rs[] (right child), fa[] (parent) and RT (root).
// NOTE(review): relies on ls[i] being 0 on entry for every i — confirm at the caller.
inline void build_d() {
    top = 0;
    stk[++top] = 1;
    for(int cur = 2; cur <= n; ++cur) {
        // Everything popped here lies to the left of cur; the last node
        // popped is the topmost of them and becomes cur's left child.
        for(; top != 0 && val[stk[top]] <= val[cur]; --top) {
            ls[cur] = stk[top];
        }
        if(ls[cur] != 0) {
            fa[ls[cur]] = cur;
        }
        // Whatever remains on the stack has a larger value: cur hangs below
        // it as the (new) right child.
        if(top != 0) {
            rs[stk[top]] = cur;
            fa[cur] = stk[top];
        }
        top++;
        stk[top] = cur;
    }
    // The bottom of the stack was never popped, hence it is the root.
    RT = stk[1];
}

题意:给定树上k个点,求切断这些点到根路径的最小代价。∑k <= 5e5

解:虚树。

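构建虚树大概是这样的:先把关键点按DFS序排序,用一个栈维护虚树上从根往下的一条链。每加入一个点x,设x与栈顶的lca为y:只要栈中第二个元素的DFS序不小于y的DFS序,就从第二个元素向栈顶连边并弹出栈顶;之后若栈顶不是y,就从y向栈顶连边,并把栈顶替换成y;最后把x入栈。所有点加入后,把栈中剩下的点依次连边弹出,栈底就是虚树的根。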

代码如下:

 1 inline bool cmp(const int &a, const int &b) {
 2     return pos[a] < pos[b];
 3 }
 4 
 5 inline void build_t() {
 6     std::sort(imp + 1, imp + k + 1, cmp);
 7     TP = top = 0;
 8     stk[++top] = imp[1];
 9     use[imp[1]] = Time;
10     E[imp[1]] = 0;
11     for(int i = 2; i <= k; i++) {
12         int x = imp[i], y = lca(x, stk[top]);
13         if(use[x] != Time) {
14             use[x] = Time;
15             E[x] = 0;
16         }
17         if(use[y] != Time) {
18             use[y] = Time;
19             E[y] = 0;
20         }
21         while(top > 1 && pos[y] <= pos[stk[top - 1]]) {
22             ADD(stk[top - 1], stk[top]);
23             top--;
24         }
25         if(stk[top] != y) {
26             ADD(y, stk[top]);
27             stk[top] = y;
28         }
29         stk[++top] = x;
30     }
31     while(top > 1) {
32         ADD(stk[top - 1], stk[top]);
33         top--;
34     }
35     RT = stk[top];
36     return;
37 }

然后本题建虚树跑DP就行了。注意两点:虚树上每条边的边权取原树上对应路径的最小边权;虚树根节点的DP初值取原树中从1号点到虚树根的路径上的最小边权(虚树的根不一定是1号点)。

  1 #include <cstdio>
  2 #include <algorithm>
  3 
  4 typedef long long LL;
  5 const int N = 250010;
  6 const LL INF = 1e18;
  7 
  8 struct Edge {
  9     int nex, v;
 10     LL len;
 11 }edge[N << 1], EDGE[N << 1]; int tp, TP;
 12 
 13 int e[N], siz[N], stk[N], top, Time, n, fa[N][20], k, RT, num, pos[N], pw[N], now[N], imp[N], E[N], d[N], use[N];
 14 LL f[N], small[N][20];
 15 
 16 inline void ADD(int x, int y, LL z) {
 17     TP++;
 18     EDGE[TP].v = y;
 19     EDGE[TP].len = z;
 20     EDGE[TP].nex = E[x];
 21     E[x] = TP;
 22     return;
 23 }
 24 
 25 inline void add(int x, int y, LL z) {
 26     top++;
 27     edge[top].v = y;
 28     edge[top].len = z;
 29     edge[top].nex = e[x];
 30     e[x] = top;
 31     return;
 32 }
 33 
 34 void DFS_1(int x, int father) { // get fa small
 35     fa[x][0] = father;
 36     d[x] = d[father] + 1;
 37     pos[x] = ++num;
 38     for(int i = e[x]; i; i = edge[i].nex) {
 39         int y = edge[i].v;
 40         if(y == father) {
 41             continue;
 42         }
 43         small[y][0] = edge[i].len;
 44         DFS_1(y, x);
 45     }
 46     return;
 47 }
 48 
 49 inline int lca(int x, int y) {
 50     if(d[x] > d[y]) {
 51         std::swap(x, y);
 52     }
 53     int t = pw[n];
 54     while(t >= 0 && d[x] < d[y]) {
 55         if(d[fa[y][t]] >= d[x]) {
 56             y = fa[y][t];
 57         }
 58         t--;
 59     }
 60     if(x == y) {
 61         return x;
 62     }
 63     t = pw[n];
 64     while(t >= 0 && fa[x][0] != fa[y][0]) {
 65         if(fa[x][t] != fa[y][t]) {
 66             x = fa[x][t];
 67             y = fa[y][t];
 68         }
 69         t--;
 70     }
 71     return fa[x][0];
 72 }
 73 
 74 inline bool cmp(const int &a, const int &b) {
 75     return pos[a] < pos[b];
 76 }
 77 
 78 inline LL getMin(int x, int y) {
 79     LL ans = INF;
 80     int t = pw[d[y] - d[x]];
 81     while(t >= 0 && y != x) {
 82         if(d[fa[y][t]] >= d[x]) {
 83             ans = std::min(ans, small[y][t]);
 84             y = fa[y][t];
 85         }
 86         t--;
 87     }
 88     return ans;
 89 }
 90 
 91 inline void build_t() {
 92     std::sort(imp + 1, imp + k + 1, cmp);
 93     TP = top = 0;
 94     stk[++top] = imp[1];
 95     use[imp[1]] = Time;
 96     E[imp[1]] = 0;
 97     for(int i = 2; i <= k; i++) {
 98         int x = imp[i], y = lca(x, stk[top]);
 99         if(use[x] != Time) {
100             use[x] = Time;
101             E[x] = 0;
102         }
103         if(use[y] != Time) {
104             use[y] = Time;
105             E[y] = 0;
106         }
107         while(top > 1 && pos[y] <= pos[stk[top - 1]]) {
108             ADD(stk[top - 1], stk[top], getMin(stk[top - 1], stk[top]));
109             top--;
110         }
111         if(stk[top] != y) {
112             ADD(y, stk[top], getMin(y, stk[top]));
113             stk[top] = y;
114         }
115         stk[++top] = x;
116     }
117     while(top > 1) {
118         ADD(stk[top - 1], stk[top], getMin(stk[top - 1], stk[top]));
119         top--;
120     }
121     RT = stk[top];
122     return;
123 }
124 
125 void DFS(int x) {
126     siz[x] = (now[x] == Time);
127     LL temp = 0;
128     for(int i = E[x]; i; i = EDGE[i].nex) {
129         int y = EDGE[i].v;
130         f[y] = EDGE[i].len;
131         DFS(y);
132         siz[x] += siz[y];
133         if(siz[y]) {
134             temp += f[y];
135         }
136     }
137     if(now[x] != Time) {
138         f[x] = std::min(f[x], temp);
139     }
140     return;
141 }
142 
// Debug hook with an empty body; its only call site in main is commented out.
void out(int x) {
    (void)x;
}
146 
147 int main() {
148     scanf("%d", &n);
149     /*if(n > 100) {
150         return -1;
151     }*/
152     int x, y; LL z;
153     for(int i = 1; i < n; i++) {
154         scanf("%d%d%lld", &x, &y, &z);
155         add(x, y, z);
156         add(y, x, z);
157     }
158     // get lca min_edge
159     DFS_1(1, 0);
160     for(int i = 2; i <= n; i++) {
161         pw[i] = pw[i >> 1] + 1;
162     }
163     for(int j = 1; j <= pw[n]; j++) {
164         for(int i = 1; i <= n; i++) {
165             fa[i][j] = fa[fa[i][j - 1]][j - 1];
166             small[i][j] = std::min(small[i][j - 1], small[fa[i][j - 1]][j - 1]);
167         }
168     }
169 
170     int m;
171     scanf("%d", &m);
172     for(Time = 1; Time <= m; Time++) {
173         scanf("%d", &k);
174         //printf("
 k = %d 
", k);
175         for(int i = 1; i <= k; i++) {
176             scanf("%d", &imp[i]);
177             now[imp[i]] = Time;
178         }
179         //printf("input over 
");
180         build_t();
181         //out(RT);
182         f[RT] = getMin(1, RT);
183         DFS(RT);
184         printf("%lld
", f[RT]);
185     }
186 
187     return 0;
188 }
AC代码
原文地址:https://www.cnblogs.com/huyufeifei/p/10404454.html