[loj2470]有向图

参考ExtremeSpanningTrees,考虑优化整体二分时求$g_{i}in {w_{mid},w_{mid+1}}$的最优解

对于$m=n-1$的问题,不需要去网络流,可以直接树形dp

但为了保证复杂度,我们在整体二分中的复杂度只能是$o(点集大小)$,这样可能就比较麻烦

首先要建出虚树(保留其中lca的点),并预处理出每一个点到深度最小的祖先使得其中边的方向都相同,之后就可以判断相邻两点是否有大小关系

对于$m=n$的问题,可以先暴力枚举基环上的一点,之后按照$m=n-1$的情况去做

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 300005
  4 #define ll long long
  5 #define oo 1e15
  6 #define u e[p][k][i]
  7 #define y0 y00
  8 vector<int>vv,e[2][N];
  9 vector<pair<int,int> >ve;
 10 int n,m,x,y,x0,y0,d[N],w[N],dfn[N],dep[N],las[N],g[N],fa[N][21],a[N],v[N],st[N],vis[N],bl[N],ans[N];
 11 ll sum,f[N][2];
 12 bool cmp1(int x,int y){
 13     return dfn[x]<dfn[y];
 14 }
 15 bool cmp2(int x,int y){
 16     return bl[x]<bl[y];
 17 }
 18 int find(int k){
 19     if (k==fa[k][0])return k;
 20     return fa[k][0]=find(fa[k][0]);
 21 }
 22 int lca(int x,int y){
 23     if (dep[x]<dep[y])swap(x,y);
 24     for(int i=20;i>=0;i--)
 25         if (dep[fa[x][i]]>=dep[y])x=fa[x][i];
 26     if (x==y)return x;
 27     for(int i=20;i>=0;i--)
 28         if (fa[x][i]!=fa[y][i]){
 29             x=fa[x][i];
 30             y=fa[y][i];
 31         }
 32     return fa[x][0];
 33 }
 34 void dfs(int k,int f,int s){
 35     dfn[k]=++x;
 36     dep[k]=s;
 37     fa[k][0]=f;
 38     for(int i=1;i<=20;i++)fa[k][i]=fa[fa[k][i-1]][i-1];
 39     for(int p=0;p<2;p++)
 40         for(int i=0;i<e[p][k].size();i++){
 41             if (u==f)continue;
 42             las[u]=p;
 43             if (p^las[k])g[u]=s;
 44             else g[u]=g[k];
 45             dfs(u,k,s+1);
 46         }
 47 }
 48 void dp(int k,int fa){
 49     vis[k]=1;
 50     for(int p=0;p<2;p++)
 51         for(int i=0;i<e[p][k].size();i++)
 52             if (u!=fa){
 53                 dp(u,k);
 54                 f[k][p^1]+=f[u][p^1];
 55                 f[k][p]+=min(f[u][0],f[u][1]);
 56             }
 57 }
 58 void get_plan(int k,int fa,int type){
 59     vis[k]=1;
 60     if (type==2)type=(f[k][1]<f[k][0]);
 61     bl[k]=type;
 62     for(int p=0;p<2;p++)
 63         for(int i=0;i<e[p][k].size();i++)
 64             if (u!=fa){
 65                 if (p!=type)get_plan(u,k,type);
 66                 else get_plan(u,k,2);
 67             }
 68 }
 69 void calc(int l,int r,int x,int y){
 70     if (x==y){
 71         for(int i=l;i<=r;i++)ans[a[i]]=v[x];
 72         return;
 73     }
 74     sort(a+l,a+r+1,cmp1);
 75     st[0]=0;
 76     vv.clear();
 77     ve.clear();
 78     for(int j=l;j<=r;j++){
 79         vv.push_back(a[j]);
 80         if (!st[0]){
 81             st[++st[0]]=a[j];
 82             continue;
 83         }
 84         int k=lca(a[j],st[st[0]]);
 85         while ((st[0]>1)&&(dep[k]==dep[lca(a[j],st[st[0]-1])])){
 86             ve.push_back(make_pair(st[st[0]-1],st[st[0]]));
 87             st[0]--;
 88         }
 89         if (st[st[0]]!=k){
 90             vv.push_back(k);
 91             ve.push_back(make_pair(k,st[st[0]]));
 92             st[st[0]]=k;
 93         }
 94         st[++st[0]]=a[j];
 95     }
 96     for(int i=0;i<vv.size();i++){
 97         e[0][vv[i]].clear(),e[1][vv[i]].clear();
 98         vis[vv[i]]=f[vv[i]][0]=f[vv[i]][1]=0;
 99     }
100     for(;st[0]>1;st[0]--)ve.push_back(make_pair(st[st[0]-1],st[st[0]]));
101     for(int i=0;i<ve.size();i++){
102         int xx=ve[i].first,yy=ve[i].second;
103         if (dep[xx]>=g[yy]){
104             e[las[yy]][xx].push_back(yy);
105             e[las[yy]^1][yy].push_back(xx);
106         }
107     }
108     int mid=(x+y>>1),tot=0;
109     for(int j=l;j<=r;j++){
110         if (a[j]==x0)tot++;
111         if (a[j]==y0)tot++;
112     }
113     if (tot<2){
114         for(int j=l;j<=r;j++){
115             f[a[j]][0]=1LL*abs(d[a[j]]-v[mid])*w[a[j]];
116             f[a[j]][1]=1LL*abs(d[a[j]]-v[mid+1])*w[a[j]];
117         }
118         for(int j=0;j<vv.size();j++)
119             if (!vis[vv[j]])dp(vv[j],0);
120         for(int j=0;j<vv.size();j++)vis[vv[j]]=0;
121         for(int j=0;j<vv.size();j++)
122             if (!vis[vv[j]])get_plan(vv[j],0,2);
123     }
124     else{
125         ll sum0=0,sum1=0;
126         for(int p=0;p<2;p++){
127             for(int j=l;j<=r;j++){
128                 f[a[j]][0]=1LL*abs(d[a[j]]-v[mid])*w[a[j]];
129                 f[a[j]][1]=1LL*abs(d[a[j]]-v[mid+1])*w[a[j]];
130             }
131             f[x0][p^1]=oo;
132             if (p)f[y0][0]=oo;
133             for(int j=0;j<vv.size();j++)
134                 if (!vis[vv[j]]){
135                     dp(vv[j],0);
136                     if (!p)sum0+=min(f[vv[j]][0],f[vv[j]][1]);
137                     else sum1+=min(f[vv[j]][0],f[vv[j]][1]);
138                 }
139             for(int j=0;j<vv.size();j++)vis[vv[j]]=0;
140         }
141         if (sum0<sum1){
142             for(int j=l;j<=r;j++){
143                 f[a[j]][0]=1LL*abs(d[a[j]]-v[mid])*w[a[j]];
144                 f[a[j]][1]=1LL*abs(d[a[j]]-v[mid+1])*w[a[j]];
145             }
146             f[x0][1]=oo;
147             for(int j=0;j<vv.size();j++)
148                 if (!vis[vv[j]])dp(vv[j],0);
149             for(int j=0;j<vv.size();j++)vis[vv[j]]=0;
150         }
151         for(int j=0;j<vv.size();j++)
152             if (!vis[vv[j]])get_plan(vv[j],0,2);
153     }
154     sort(a+l,a+r+1,cmp2);
155     for(int j=l;j<=r+1;j++)
156         if ((j>r)||(bl[a[j]])){
157             if (l<j)calc(l,j-1,x,mid);
158             if (j<=r)calc(j,r,mid+1,y);
159             return;
160         }
161 }
162 int main(){
163     scanf("%d%d",&n,&m);
164     for(int i=1;i<=n;i++)scanf("%d",&d[i]);
165     for(int i=1;i<=n;i++)scanf("%d",&w[i]);
166     for(int i=1;i<=n;i++)fa[i][0]=i;
167     for(int i=1;i<=m;i++){
168         scanf("%d%d",&x,&y);
169         if (find(x)==find(y))x0=x,y0=y;
170         else{
171             fa[x][0]=find(y);
172             e[0][x].push_back(y);
173             e[1][y].push_back(x);
174         }
175     } 
176     x=0;
177     dfs(1,1,0);
178     memcpy(v,d,sizeof(v));
179     sort(v+1,v+n+1);
180     int nn=unique(v+1,v+n+1)-v-1;
181     for(int i=1;i<=n;i++)a[dfn[i]]=i;
182     calc(1,n,1,nn);
183     for(int i=1;i<=n;i++)sum+=1LL*w[i]*abs(d[i]-ans[i]);
184     printf("%lld",sum);
185 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/14296829.html