洛谷P2680 运输计划 树上差分 LCA 倍增 tarjan

洛谷P2680 运输计划
树上差分 LCA 倍增 tarjan
题意 给出若干条路径 你可以把其中的一条边变为 0
求变为 0 后的最短路径

首先发现答案满足单调性
那么就可以二分这个答案
首先我们 用tarjan 或者 倍增等算法 预处理 出两点之间的路径距离 ,以及lca
然后我们将两点之间的距离排序
枚举答案 mid

然后判断就是将 所有路径距离 大于 mid 的边 +1
然后最后如果 取边贡献次数为 num(num 表示 路径长 > mid 的个数)
在这些边中选一条最长的边 mx
如果 最长路径 - mx > mid 则说明 mid 不合法 要 增大 这个 mid
然后 每一轮 加 边的贡献 都是 O(n) 显然会T
考虑优化 因为首先是修改 然后在是查询
所以我们可以用差分的思想

d[ i ] 表示i这个点 连向他的父亲的那条边的贡献
每次加一条路径 u v
相当于 d[ u ]++ d[ v ]++ d[lca(u,v)]--
d[ i ] 其实表示的 是 d[ i ] - (d[ ls(i) ] + d[rs(i)] )
然后查询 一条边的贡献其实就他的子树 的和
时间复杂度 nlogn

  1 #include <cstdio>
  2 #include <cstdlib>
  3 #include <cmath>
  4 #include <cstring>
  5 #include <algorithm>
  6 #include <iostream> 
  7 #include <iomanip> 
  8 using namespace std ; 
  9 
 10 const int maxn = 300011 ; 
 11 struct node{
 12     int from,to,pre,val ; 
 13 }e[2*maxn];
 14 struct data{
 15     int u,v,lca,dist,id ; 
 16 }q[maxn];
 17 int n,Q,cnt,TI,l,r,mid,mx,t ;
 18 int dist[maxn],head[maxn],IN[maxn],OUT[maxn],f[maxn][21],d[maxn] ;  
 19 
 20 inline int read() 
 21 {
 22     int x = 0, f = 1 ; 
 23     char ch = getchar() ; 
 24     while(ch<'0'||ch>'9') { if(ch=='-') f = -1 ; ch = getchar() ; } 
 25     while(ch>='0'&&ch<='9') { x = x * 10+ch-48 ; ch = getchar() ; }
 26     return x * f ; 
 27 }
 28 
 29 inline bool cmp(data a,data b) 
 30 {
 31     return a.dist > b.dist ; 
 32 }
 33 
 34 inline void add(int x,int y,int v) 
 35 {
 36     e[++cnt].to = y ; 
 37     e[cnt].from = x ; 
 38     e[cnt].pre = head[ x ] ; 
 39     e[cnt].val = v ; 
 40     
 41     head[x] = cnt ;  
 42 }
 43 
 44 inline void dfs(int u,int fa) 
 45 {
 46     int v ; 
 47     IN[ u ] = ++TI ; 
 48     f[ u ][ 0 ] = fa ; 
 49     for(int i=head[ u ];i;i = e[ i ].pre) 
 50     {
 51         v = e[ i ].to ; 
 52         if(v!=fa) 
 53             dist[ v ] = dist[ u ] + e[ i ].val,dfs(v,u); 
 54     }
 55     OUT[ u ] = ++TI ; 
 56     
 57 }
 58 
 59 inline void pre() 
 60 {
 61     dist[ 1 ] = 0 ; 
 62     dfs( 1,-1 ) ; 
 63     f[ 1 ][ 0 ] = 1 ; 
 64     for(int j=1;j<=20;j++) 
 65         for(int i=1;i<=n;i++) 
 66             f[ i ][ j ] = f[ f[ i ][ j-1 ] ][ j-1 ] ; 
 67 }
 68 
 69 inline bool is_root(int u,int v )   //  u 是  v 的祖先  
 70 {
 71     if( IN[ u ] <= IN[ v ]&& IN[ v ] <= OUT[ u ] ) return 1 ; 
 72     return 0 ; 
 73 }
 74  
 75 inline int getlca(int u,int v) 
 76 {
 77     if(is_root(u,v)) return u ; 
 78     if(is_root(v,u)) return v ; 
 79     for(int i=20;i>=0;i--) 
 80         if(!is_root(f[ u ][ i ],v)) u = f[ u ][ i ] ; 
 81     return f[ u ][ 0 ] ; 
 82 }
 83 
 84 inline void sum(int u,int fa) 
 85 {
 86     int v ; 
 87     for(int i=head[ u ];i;i =e[i].pre) 
 88     {
 89         v = e[ i ].to ; 
 90         if(v != fa) 
 91             sum(v,u),d[ u ]+=d[ v ] ; 
 92     }
 93 }
 94 
 95 inline bool check(int mid) 
 96 {
 97     for(int i=0;i<=n;i++) d[ i ] = 0 ; 
 98     int num = 0 ; 
 99     for(int i=1;i<=Q;i++) 
100     {
101         if(q[ i ].dist <= mid ) break ; 
102         d[ q[ i ].u ]++ ; 
103         d[ q[ i ].v ]++ ; 
104         d[ q[ i ].lca ]-=2 ; 
105         num++ ; 
106     } 
107     sum( 1,-1 ) ; 
108     int u,v ; 
109     mx = 0 ; 
110     for(int i=1;i<=cnt;i+=2) 
111     {
112         u = e[ i ].from ; 
113         v = e[ i ].to ;  
114         if(dist[ u ] > dist[ v ]) { t = u ; u = v ; v = t ; }   
115         if( d[ v ]==num ) 
116             if( e[ i ].val > mx ) mx = e[ i ].val ; 
117     }
118     if( q[ 1 ].dist - mx <=mid ) 
119         return 1 ; 
120     else 
121         return 0 ; 
122 }
123 
124 int main() 
125 {
126     int x,y,v ; 
127     n = read() ;  Q = read() ;  
128     for(int i=1;i<n;i++) 
129     {
130         x = read() ; y = read() ; v = read() ; 
131         add(x,y,v) ; add(y,x,v) ;  
132     }
133     pre() ; 
134     for(int i=1;i<=Q;i++) 
135     {
136         q[ i ].u = read() ; q[ i ].v = read() ; 
137         q[ i ].lca = getlca( q[ i ].u ,q[ i ].v ) ; 
138         q[ i ].dist = dist[ q[ i ].u ] + dist[ q[ i ].v ] - 2*dist[ q[ i ].lca ] ; 
139         q[ i ].id = i ; 
140     }
141     sort(q+1,q+Q+1,cmp) ; 
142     
143     l = 0 ; 
144     r = q[ 1 ].dist ;  
145     
146     while( l < r ) 
147     {
148         mid = ( l + r ) >>1 ; 
149         if(check(mid)) 
150             r = mid ; 
151         else 
152             l = mid + 1 ;  
153     }
154     printf("%d
",r) ; 
155     return 0 ; 
156 } 
原文地址:https://www.cnblogs.com/third2333/p/7110720.html