[题解] LuoguP4292 [WC2010]重建计划

https://www.luogu.com.cn/problem/P4292

线段树写炸调了一个小时,不愧是我

发现这是个01分数规划问题,我们的目的是在树上找一条路径使得下面这个东西最大((len_{s,t})表示(s)(t)路径上的边权和,(tot_{s,t})表示(s)(t)的路径上有多少条边)

[frac{len_{s,t}}{tot_{s,t}} ]

(满足限制(L le tot_{s,t} le U)

直接做并不是很好做,所以考虑二分,当二分到(mid)的时候,任务变为判断是否有一条路径使得

[frac{len_{s,t}}{tot_{s,t}} ge mid ]

稍微变个形

[len_{s,t}ge tot_{s,t} imes mid ]

[len_{s,t}-tot_{smt} imes midge 0 ]

我们把每条边的边权都减少(mid),然后问题变为了判断是否存在一条长度在([L,U])以内的路径满足路径上的边权和大于等于(0)

然后树上路径瞎想了一下点分治,发现二分一个(log),分治一个(log),还要用线段树啥的维护,再加一个(log),然后我们就收获了三个(log)的优秀复杂度...

由于窝太菜了...不会淀粉质的做法...所以来写一个板一点的长链剖分QAQ

考虑(DP),令(f[u][i])表示从(u)出发向下走(i)条边的最大边权和(默认以(1)节点为根)。

那么在枚举(u)的儿子(v)的时候,先枚举(i),用(f[v][i]+f[u][j])更新答案((L le i+j+1 le U)),更新完答案后再枚举(i)(f[v][i])更新(f[u][i+1])

这样(check)的复杂度是(n^2)的,但注意到以深度为下标,所以考虑长链剖分。

总之先长链剖分,然后先重链剖分那样两遍(Dfs),然后第二遍(Dfs)的时候优先遍历重儿子。

遍历的时候打上时间戳(dfn[u])

然后对于(f[u][i])直接放在(f[dfn[u]+i])上。

这样很自然的对于一点(u),他的重儿子会自动继承他的(f)

先对重儿子(dp)完后不要忘了把(f[dfn[u]+1]...f[dfn[u]+len[u]])(len[u])表示(u)沿着长链向下走到叶子节点的步数)都加上(w_{u,son[u]})(即(u)到其重儿子的边的边权),然后将(f[dfn[u]])设为(0)(f[u][0]=0))。

于是线段树维护即可,复杂度(O(Tn log n))(T)为二分次数)。

(Code:)

#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;++i)
#define per(i,a,n) for (int i=n-1;i>=a;--i)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define all(x) (x).begin(),(x).end()
#define SZ(x) ((int)(x).size())
typedef double db;
typedef long long ll;
typedef pair<int,int> PII;
typedef vector<int> VI;

const int N=1e5+10;
const db INF=1e18,eps=1e-5;
int n,L,U;
vector<pair<int,db>> e[N];

namespace SegmentTree {
	struct node {
		int l,r; db add,mx;
	}t[N<<2];

	inline void pushup(int x) {t[x].mx=max(t[x<<1].mx,t[x<<1|1].mx);}
	inline void pushdown(int x) {
		if (t[x].add==0) return;
		int lc=x<<1,rc=x<<1|1; db add=t[x].add;
		t[lc].add+=add,t[rc].add+=add;
		t[lc].mx+=add,t[rc].mx+=add;
		t[x].add=0;
	}

	void build(int x,int l,int r) {
		t[x].add=0,t[x].l=l,t[x].r=r,t[x].mx=-INF;
		if (l==r) return; int mid=(l+r)>>1; 
		build(x<<1,l,mid),build(x<<1|1,mid+1,r);
		pushup(x);
	}

	void mdf(int x,int p,db v) {
		int l=t[x].l,r=t[x].r,mid=(l+r)>>1;
		if (l==r) {t[x].mx=max(t[x].mx,v);return;}
		pushdown(x);
		if (p<=mid) mdf(x<<1,p,v); else mdf(x<<1|1,p,v);
		pushup(x);
	}

	db qry(int x,int ql,int qr) {
		if (ql>qr||ql<0||qr<0) return -INF;
		int l=t[x].l,r=t[x].r,mid=(l+r)>>1;
		if (ql<=l&&r<=qr) return t[x].mx;
		pushdown(x); db ans=-INF;
		if (ql<=mid) ans=max(ans,qry(x<<1,ql,qr));
		if (mid<qr) ans=max(ans,qry(x<<1|1,ql,qr));
		return ans;
	}

	void upd(int x,int ql,int qr,db v) {
		int l=t[x].l,r=t[x].r,mid=(l+r)>>1;
		if (ql<=l&&r<=qr) {t[x].mx+=v,t[x].add+=v;return;}
		pushdown(x);
		if (ql<=mid) upd(x<<1,ql,qr,v);
		if (mid<qr) upd(x<<1|1,ql,qr,v);
		pushup(x);
	}
}
using SegmentTree::mdf;
using SegmentTree::qry;
using SegmentTree::build;
using SegmentTree::upd;

int mdep[N],len[N],dep[N],son[N],fa[N],dfn[N],tim;

void dfs1(int u,int f) {
	fa[u]=f,dep[u]=mdep[u]=dep[f]+1;
	for (auto ei:e[u]) {
		int v=ei.fi; if (v==f) continue;
		dfs1(v,u);
		if (mdep[v]>mdep[u]) 
			son[u]=v,mdep[u]=mdep[v];
	}
	len[u]=mdep[u]-dep[u]+1;  // 嘛...这里len[u]是上面说的len[u]+1...珂以理解的吧...
}
void dfs2(int u) {
	dfn[u]=++tim; if (!son[u]) return;
	dfs2(son[u]);
	for (auto ei:e[u]) if (ei.fi!=fa[u]&&ei.fi!=son[u])
		dfs2(ei.fi);
}

db mxlen;
void dp(int u) {
	if (son[u]) dp(son[u]);
	db ws=0;
	for (auto ei:e[u]) if (ei.fi==son[u]) {ws=ei.se;break;}
	upd(1,dfn[u]+1,dfn[u]+len[u]-1,ws),mdf(1,dfn[u],0);
	for (auto ei:e[u]) {
		int v=ei.fi; db w=ei.se;
		if (v==son[u]||v==fa[u]) continue;
		dp(v);
		rep(i,0,len[v]) {
			int l=L-i-1,r=min(U-i-1,len[u]-1);
			db tmp=qry(1,dfn[v]+i,dfn[v]+i);
			mxlen=max(mxlen,w+tmp+qry(1,dfn[u]+l,dfn[u]+r));
		}
		rep(i,0,len[v]) mdf(1,dfn[u]+i+1,w+qry(1,dfn[v]+i,dfn[v]+i));
	}
	int l=L,r=min(len[u]-1,U);
	mxlen=max(mxlen,qry(1,dfn[u]+l,dfn[u]+r));
}

bool check(db avg) {
	mxlen=-INF;
	rep(i,1,n+1) for (auto &ei:e[i]) ei.se-=avg;
	build(1,1,n),dp(1);
	rep(i,1,n+1) for (auto &ei:e[i]) ei.se+=avg;
	return mxlen>=0;
}

int main() {
#ifdef LOCAL
	freopen("a.in","r",stdin);
#endif
	scanf("%d%d%d",&n,&L,&U);
	rep(i,0,n-1) {
		int u,v; db w;
		scanf("%d%d%lf",&u,&v,&w);
		e[u].pb(mp(v,w));
		e[v].pb(mp(u,w));
	}
	dfs1(1,0),dfs2(1);
	db l=0,r=1e6;
	while (l+eps<r) {
		db mid=(l+r)/2.0;
		if (check(mid)) l=mid; else r=mid;
	}
	printf("%.3f
",l);
	return 0;
}
原文地址:https://www.cnblogs.com/wxq1229/p/12539567.html