SPOJ1825 Free tour II 树分治

题意:带边权树上有白点和黑点,问你最多不经过k个黑点使得路径最长(注意,路径有负数)

解题思路:基于树的点分治。数的路径问题,具体看09QZC论文,特别注意 当根为黑时的情况

解题代码:

  1 // File Name: spoj1825.cpp
  2 // Author: darkdream
  3 // Created Time: 2014年10月05日 星期日 20时20分33秒
  4 
  5 #include<vector>
  6 #include<list>
  7 #include<map>
  8 #include<set>
  9 #include<deque>
 10 #include<stack>
 11 #include<bitset>
 12 #include<algorithm>
 13 #include<functional>
 14 #include<numeric>
 15 #include<utility>
 16 #include<sstream>
 17 #include<iostream>
 18 #include<iomanip>
 19 #include<cstdio>
 20 #include<cmath>
 21 #include<cstdlib>
 22 #include<cstring>
 23 #include<ctime>
 24 #define LL long long 
 25 #define maxn 200015
 26 using namespace std;
 27 struct node{
 28     int ne;
 29     int w;
 30     node(int _ne,int _w)
 31     {
 32         ne = _ne ; 
 33         w = _w;
 34     }
 35 };
 36 int n ,K, m ; 
 37 int col[maxn];
 38 int vis[maxn];
 39 vector <node> mp[maxn];
 40 int sum[maxn];
 41 int mx[maxn];
 42 int cnum[maxn];
 43 void getsize(int k,int la)
 44 {
 45     sum[k] = 1; 
 46     mx[k] = 0;
 47     int num = mp[k].size();
 48     int tt = 0 ;
 49     for(int i = 0 ;i < num;i ++)
 50     {
 51        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
 52        {
 53            getsize(mp[k][i].ne,k);
 54            mx[k] = max(sum[mp[k][i].ne],mx[k]);
 55            sum[k] += sum[mp[k][i].ne];
 56        }
 57     }
 58 }
 59 int root;
 60 int mxv; 
 61 int getroot(int k,int la ,int tans)
 62 {
 63      int tt = max(tans - sum[k],mx[k]);
 64      if(tt < mxv)
 65      {
 66         mxv = tt;
 67         root = k ; 
 68      }
 69      int num = mp[k].size();
 70      for(int i = 0 ;i < num ;i ++)
 71      {
 72        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
 73        {
 74            getroot(mp[k][i].ne,k,tans);
 75        }
 76      }
 77 }
 78 LL ans = 0 ;
 79 LL dp[maxn];
 80 LL tdp[maxn];
 81 bool cmp(node a, node b)
 82 {
 83     return cnum[a.ne] < cnum[b.ne];
 84 }
 85 void getdep(int k ,int la,int tc,LL dep)
 86 { 
 87     int st = (col[k] == 1?1:0) ;
 88     tdp[tc+st] = max(tdp[tc+st],dep); //这个点是G点的时候
 89     int num = mp[k].size();
 90     for(int i = 0 ;i < num ;i ++)
 91     {
 92         if(!vis[mp[k][i].ne] && mp[k][i].ne != la )
 93         {
 94             getdep(mp[k][i].ne,k,tc + st,dep + mp[k][i].w);
 95         }
 96     }
 97 }
 98 void getcnum(int k ,int la)
 99 {
100     if(col[k])
101         cnum[k] = 1; 
102     else cnum[k] = 0 ; 
103     int tt = 0 ;
104     int num = mp[k].size();
105     for(int i = 0 ;i < num;i ++)
106     {
107        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
108        {
109            getcnum(mp[k][i].ne,k);
110           if(cnum[mp[k][i].ne] > tt)
111               tt = cnum[mp[k][i].ne];
112        }
113     }
114     cnum[k] += tt;
115 }
116 void solve(int k)
117 {
118     getsize(k,0);
119     mxv = 1e9;
120     getroot(k,0,sum[k]);
121     k = root;
122     
123     getcnum(k,0);    
124     //printf("*****%d %d
",k,cnum[k]);    
125     int num = mp[k].size();
126     memset(dp,0,(cnum[k]+3)*sizeof(LL));
127     int tk ;
128     int st = 0 ;
129     if(col[k])
130     {
131         tk = K + 1;
132         st = 1;
133     }
134     else tk = K ;
135     int la =0 ; 
136     //int size = min(cnum[k],K);
137     sort(mp[k].begin(),mp[k].end(),cmp);
138     for(int i = 0 ;i < num ;i ++)
139     {
140         if(vis[mp[k][i].ne])
141             continue;
142         
143         memset(tdp,0,(cnum[mp[k][i].ne]+3)*sizeof(tdp[0]));
144         if(col[k])
145            getdep(mp[k][i].ne,k,1,mp[k][i].w);        
146         else 
147            getdep(mp[k][i].ne,k,0,mp[k][i].w);        
148     //    printf("**********%d
",tk);
149         
150         
151         int tt = min(cnum[mp[k][i].ne]+st,K);    
152 //        printf("%d %d
",cnum[mp[k][i].ne]+st,K);
153         for(int j = st ;j <= tt;j ++)
154         {
155            if(tk - j <= la)
156            {
157             if(tdp[j] + dp[tk-j]> ans)
158             {
159                 ans = tdp[j] + dp[tk-j];
160             }
161            }else{
162              if(tdp[j] + dp[la]> ans)
163              {
164                 ans = tdp[j] + dp[la];
165              }
166            }
167         } 
168         dp[0] = max(dp[0],tdp[0]);
169         //printf("%d %d
",n,cnum[mp[k][i].ne]);
170         /*if(tdp[tt+st+1] != 0)
171         {
172           printf("&&&&&&&&&&&&&&
");
173         }*/
174         for(int j = 1 ;j <= tt+st; j ++)
175         {
176             dp[j] = max(dp[j],tdp[j]);
177             dp[j] = max(dp[j],dp[j-1]);
178         }
179    //     for(int j = 0;j <= K;j ++)
180     //        printf("%lld ",dp[j]);
181     //    puts("");
182         la = tt + st;
183     }
184     //puts("**********8");
185     vis[k] = 1;
186     for(int i = 0;i < num;i ++)
187     {
188         if(!vis[mp[k][i].ne])
189             solve(mp[k][i].ne);
190     }
191     return; 
192 }
193 int main(){
194    //freopen("out","r",stdin);    
195    //freopen("output.txt","w",stdin);
196    while(scanf("%d %d %d",&n,&K,&m) != EOF){
197     int temp ; 
198     memset(vis,0,sizeof(vis));
199     memset(col,0,sizeof(col));
200     for(int i = 1;i <= n;i ++)
201         mp[i].clear();
202     for(int i = 1;i <= m;i ++)
203     {
204         scanf("%d",&temp);
205         col[temp]  = 1;  
206     }
207     for(int i = 1;i <= n - 1;i ++)
208     {
209         int a, b , w; 
210         scanf("%d %d %d",&a,&b,&w);
211         mp[a].push_back(node(b,w));
212         mp[b].push_back(node(a,w));
213     }
214     ans = 0; 
215     solve(1);
216     printf("%lld
",ans);
217    }
218     return 0;
219 }
View Code
没有梦想,何谈远方
原文地址:https://www.cnblogs.com/zyue/p/4013037.html