P4542[ZJOI2011]营救皮卡丘【费用流,Floyd】

正题

题目链接:https://www.luogu.com.cn/problem/P4542


题目大意

给出\(n+1\)个点\(m\)条边的无向图,\(k\)个人开始在\(0\)号点,一个人进入\(i\)号点之前必须要有人经过\(i-1\)号点,求第一个人进入\(n\)号点时所有人的最短移动距离和。

\(1\leq n\leq 150,1\leq m\leq 2\times 10^4,1\leq k\leq 10\)


解题思路

显然不能建\(n\times n\)个点跑费用流,考虑怎么优化。

我们可以缩去一些中间路程,对于每个人只留下第一次到达该点的这些点,但是我们需要适当改变边权。

\(Floyd\)求出\(d_{i,j}\)表示从\(i\)走到\(j\)且只走编号不大于\(max\{i,j\}\)的点的最短距离,这样因为如果一个人要走到\(j\),那么它一定是第一个到的,所以不能走过大于\(j\)的点,而前面的我们可以调整每个人的行走顺序来让前面的点都解锁后这个人再出发。

现在问题就变为了求\(k\)条权值和最小的路径覆盖所有点。其实不用上下界,因为是费用流,所以我们每个点拆成出/入点,然后入点向出点连一条\((1,-inf)\)\((inf,0)\)的边(前面是流量,后面是费用)

这样如果一个点不走会多一堆费用,所以肯定会经过所有点。

这样点数就是\(O(n)\)级别了


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define ll long long
using namespace std;
const ll N=310,inf=1e9;
struct node{
	ll to,next,w,c;
}a[N*N*10];
ll n,m,k,s,t,ans,tot=1;
ll ls[N],f[N],mf[N],d[N][N],pre[N];
bool v[N];queue<int> q;
void addl(ll x,ll y,ll w,ll c){
	a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;a[tot].w=w;a[tot].c=c;
	a[++tot].to=x;a[tot].next=ls[y];ls[y]=tot;a[tot].w=0;a[tot].c=-c;
	return;
}
bool spfa(){
	memset(f,0x3f,sizeof(f));
	f[s]=0;q.push(s);v[s]=1;mf[s]=inf;
	while(!q.empty()){
		ll x=q.front();q.pop();v[x]=0;
		for(ll i=ls[x];i;i=a[i].next){
			ll y=a[i].to;
			if(a[i].w&&f[x]+a[i].c<f[y]){
				f[y]=f[x]+a[i].c;pre[y]=i;
				mf[y]=min(mf[x],a[i].w);
				if(!v[y])q.push(y),v[y]=1;
			}
		}
	}
	return f[t]<1e18;
}
void updata(){
	ll x=t;ans+=mf[x]*f[x];
	while(x!=s){
		a[pre[x]].w-=mf[t];
		a[pre[x]^1].w+=mf[t];
		x=a[pre[x]^1].to;
	}
	return;
}
signed main()
{
	scanf("%lld%lld%lld",&n,&m,&k);
	memset(d,0x3f,sizeof(d));n++;
	for(ll i=1;i<=m;i++){
		ll x,y,w;
		scanf("%lld%lld%lld",&x,&y,&w);x++;y++;
		d[x][y]=min(d[x][y],w);
		d[y][x]=min(d[y][x],w);
	}
	for(ll i=1;i<=n;i++)d[i][i]=0;
	for(ll k=1;k<=n;k++)
		for(ll i=1;i<=n;i++)
			for(ll j=1;j<=n;j++)
				if(k<i||k<j)d[i][j]=min(d[i][j],d[i][k]+d[k][j]);
	s=2*n+1;t=s+1;
	addl(s,1,k,0);
	for(ll i=1;i<=n;i++){
		addl(i,i+n,1,-inf);
		addl(i,i+n,inf,0);
		addl(i+n,t,inf,0);
		for(ll j=i+1;j<=n;j++)
			if(d[i][j]<1e18)addl(i+n,j,inf,d[i][j]);
	}
	while(spfa())
		updata();
	printf("%lld\n",ans+n*inf);
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14426187.html