wqs 二分学习笔记

又称为带权二分

一种优化凸函数 dp 的方式,明显的标志是选 k 个。

一般这种玩意都是可以强套一个 wqs 二分上去,消一个 O(n) 加一个 (O(log)),而且还是从状态数上消一个。

我们从 LCT 这道题来引入。

首先题目要求选 k+1 条不相交链的权值和最大。

设出 (dp[i][j][0/1/2]) 表示以 (i) 为根的子树在图上的度数为 (0,1,2),他的子树中含有 (j) 条链。并且当度数为 (1) 的时候这条链不计入 (j) 中。

分类讨论的有点繁琐就不写了,看看代码吧(讲真,这个 dp 还挺神仙的

我觉得 xtq 树形 dp 的方式还挺用的

inline void dfs(int x,int ff)
{
	dp[0][x][0] = dp[1][x][0] = dp[2][x][1] = dp[3][x][0] = 0;
	for(int i=head[x],v;i;i=nxt[i])
	{
		v = ver[i];
		if(v == ff) continue;
		dfs(v , x);
		// for(int u = 0;u <= 2;u++)
		// 	for(int j=0;j <= k;j++) aux[u][j] = -INF;
		memset(aux[0],0xcf,sizeof(aux[0])),memset(aux[1],0xcf,sizeof(aux[1]));
		memset(aux[2],0xcf,sizeof(aux[2]));
		for(int u = 0;u <= k;u++)
			for(int q = 0;q + u <= k;q ++)
				aux[0][u + q] = max(aux[0][u + q],dp[0][x][u] + dp[3][v][q]);
		for(int u = 0;u <= k;u ++)
			for(int q = 0;q + u <= k;q ++)
				aux[1][u + q] = max(aux[1][u + q],max(dp[1][x][u] + dp[3][v][q],dp[0][x][u] + dp[1][v][q] + Edge[i]));
		for(int u = 0;u <= k ; u++)
			for(int q = 0;q + u<= k;q ++)
			{
				aux[2][u + q] = max(aux[2][u + q],dp[2][x][u] + dp[3][v][q]);
				// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
			}

		for(int u = 0;u <= k ; u++)
			for(int q = 0;q + u + 1<= k;q ++)
			{
				// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[2][x][u] + dp[3][v][q]);
				aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
			}
			// printf("%d : 
",x);
		for(int u =0 ;u <=2;u++)
			for(int q = 0;q <= k;q ++ )
			{
				// printf("aux[%d][%d] = %d
",u,q,aux[u][q]);
				dp[u][x][q] = aux[u][q];
			} 
		dp[1][x][0] = max(dp[1][x][0] , dp[1][v][0] + Edge[i]);
	}
	// dp[0][x][1] = max(0,dp[0][x][1]);
	for(int i=1;i<=k;i++)
	dp[3][x][i] = max(dp[0][x][i],max(dp[1][x][i - 1],dp[2][x][i]));
	// for(int i=1;i<=)
}

我们现在的 (dp)(O(nk^2)) 的。这个复杂度大的离谱。

这时候请出我们的带权二分来。这里默认我们的 dp 函数是一个凸函数

我们将原来的 (dp)(j) 那维限制去掉,这样就可以将复杂度降到 (O(n))。但是这样不能保证我们恰好选了 k+1 条链,所以我们要对原函数做一些魔改。

设原函数为 (ans(x)),当 (ans'(x)=0) 时,ans(x) 取得最大值。我们不加修改的 dp 求出来的就是这个东西。

现在我们设一个新函数 (g(x) = ans(x) +val imes x),这个函数一阶导为减函数,二阶导为一个上凸函数。所以我们可以通过调节 val (就是斜率)来调节 (g'(x)) 的零点,这样就能调节出当 (g(x)) 取得最值得时候(ans(x)) 恰好取得 (k+1) 条链,这样皆大欢喜。

关于恰好选 k 个是一个凸函数,你要想如果恰好选 (1) 个,那我们肯定选择最大的那个,选两个我会把次大的选上,这样每次的增量都不如上一个大,就会形成一个凸函数。更仔细一点,如果有那么一点点限制,要求恰好选 (k) 个,那我后面选的东西可能会影响到前面选的,并且这时候还要求数量达到我们要求的,就被迫舍弃权值最优,来追求数量,这就导致了凸函数的后半段产生。

复杂度 (O(blog k))

关于 wqs 的实际操作来说,有一点点细节。

对最大值来说:你考虑二分一个惩罚值,当你选的少了的时候,我们想让它下次选得再多一点,就会把惩罚值下调,反之就会上调。

对最小值来说:选得少了的时候,我们想让它下次再多选一点,惩罚值就会下调,反之上调。

对于多点共线的情况,我们优先选物品最少的或者最多的,二分的时候只要物品在 k 的我们指定的一侧时就去更新答案。

P4383 [八省联考2018]林克卡特树

#include<bits/stdc++.h>

using namespace std;

#define int long long
#define pii pair<int,int>

template<typename _T>
inline void read(_T &x)
{
	x=0;char s=getchar();int f=1;
	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
	x*=f;
}
const int np = 3e5 + 5;
int head[np],ver[np * 2],nxt[np * 2],Edge[np * 2];
int tit;
inline void add(int x,int y,int w)
{
	ver[++tit] = y;
	Edge[tit] = w;
	nxt[tit] = head[x];
	head[x] = tit;
}

struct qwq
{
	int f,fanga;

	friend qwq operator+(qwq a,qwq b)
	{
		return (qwq){a.f + b.f , a.fanga + b.fanga};
	}

	inline friend  qwq Max(qwq a,qwq b)
	{
		if(a.f == b.f)
		{
			if(a.fanga > b.fanga) return a;
			else return b;
		} 
		if(a.f > b.f) return a;
		else return b;
	}

}dp[5][np];
// dp[d][i] 表示前以 i 为根的树,度数为 d 的
int n,k,sakura;

inline void dfs(int x,int ff)
{
	dp[0][x] = (qwq){0,0},dp[1][x] = (qwq){0,0},dp[2][x] = (qwq){sakura,1};
	for(int i=head[x],v;i;i=nxt[i])
	{
		v = ver[i];
		if(v == ff) continue;
		dfs(v,x);
		dp[2][x] = Max(dp[2][x] + dp[3][v],dp[1][x] + dp[1][v] + (qwq){Edge[i] + sakura,1});
		dp[1][x] = Max(dp[1][x] + dp[3][v],dp[0][x] + dp[1][v] + (qwq){Edge[i],0});
		dp[0][x] = dp[0][x] + dp[3][v];
	}
	dp[3][x] = Max(dp[0][x],Max(dp[1][x] + (qwq){sakura,1},dp[2][x]));
}

inline void judging(int x)
{
	sakura = x;
	dfs(1,0);
}

signed main()
{
	read(n),read(k);
	k++;
	for(int i=1,a,b,w;i<=n - 1;i ++ )
	{
		read(a),read(b),read(w);
		add(a,b,w);
		add(b,a,w);
	}
	int l = -1e8,r = 1e8,Ans=0;
	while(l <= r)
	{
		int mid = l + r >> 1;
		judging(mid);
		if(dp[3][1].fanga >= k)
		{
			Ans = dp[3][1].f - k * mid;
			// printf("%lld %lld
",dp[3][1].fanga,Ans);
			r = mid - 1;
		}
		else l = mid + 1;
	}
	printf("%lld",Ans);
}

P6246 [IOI2000] 邮局 加强版

#include<bits/stdc++.h>

using namespace std;

#define int long long
#define pii pair<int,int>

template<typename _T>
inline void read(_T &x)
{
	x=0;char s=getchar();int f=1;
	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
	x*=f;
}
const int np = 1e6 + 5;
int a[np];
int sum[np],n,k;

inline int Abs(int x)
{
	return x < 0?-x:x;
}

inline int calc(int l,int r)
{
	int op = l + r >> 1;
	return a[op] * (op - l + 1) - sum[op]+sum[l - 1] + Abs(a[op] * (r-op+1) - sum[r] + sum[op - 1]);
}

struct qwq
{
	int f,fanga;
	friend qwq operator+(qwq a,qwq b)
	{
		return (qwq){a.f + b.f,a.fanga + b.fanga};
	}
	friend bool operator<(qwq a,qwq b)
	{
		if(a.f == b.f) return a.fanga < b.fanga;
		else return a.f < b.f;
	}
}dp[np];

int sakura;
// int l_[2333],r_[2333],juec[2333];
struct qaq
{
	int l_,r_,juec;
	// int nx
}que[np * 2];
int top = 0;

inline int binary(qaq u,int op)
{
	int l = u.l_,r = u.r_,opt = u.juec,ans =  u.r_ + 1;//l <= op?op:0;
	while(l <= r)
	{
		int mid = l + r >> 1;
		if(dp[op] + (qwq){calc(op + 1,mid) + sakura,1} < dp[opt] + (qwq){calc(opt + 1,mid) + sakura,1}) ans = mid,r = mid - 1;
		else l = mid + 1;
	}
	return ans;
}

inline void solve()
{
	int head = 1,tail = 1;
	dp[0] = (qwq){0,0};
	que[head] = (qaq){1,n,0};
	for(int i=1;i<=n;i++)
	{
		while(head < tail && que[head].r_ < i) head++;
		int j = que[head].juec;
		dp[i] = dp[j] + (qwq){calc(j + 1,i) + sakura,1};// + sakura;
		int spilt = 0;
		while(head < tail && que[tail].l_ == binary(que[tail],i)) spilt = que[tail].l_,tail--;
		spilt = binary(que[tail],i);
		if(spilt)
		{
			que[tail].r_ = spilt - 1;
	//		printf("%lld ",spilt);
			que[++tail] = (qaq){spilt,n,i};			
		}
	}
	// if(sakura == 0)
//	for(int i=1;i<=n;i++)
//	{
//		printf("%lld ",dp[i].f);
//	}
//	printf("
");
}

namespace subtask{
	int fp[500][4333];
	inline int bbinary(int c,qaq u,int op)
	{
		int l = u.l_,r = u.r_,opt = u.juec,ans =  u.r_ + 1;//l <= op?op:0;
		while(l <= r)
		{
			int mid = l + r >> 1;
			if(fp[c - 1][op] + calc(op + 1,mid) < fp[c - 1][opt] + calc(opt + 1,mid)) ans = mid,r = mid - 1;
			else l = mid + 1;
		}
		return ans;
	}
	inline void solve1(int c)
	{
		int head = 1,tail = 1;
		fp[c][0] = 0;
		que[head] = (qaq){1,n,0};
		for(int i=1;i<=n;i++)
		{
			while(head < tail && que[head].r_ < i) head++;
			int j = que[head].juec;
			fp[c][i] = fp[c - 1][j] + calc(j + 1,i);// + sakura;
			int spilt = 0;
			while(head < tail && que[tail].l_ == bbinary(c,que[tail],i)) spilt = que[tail].l_,tail--;
//			int spilt = 0;
			spilt = bbinary(c,que[tail],i);
			if(spilt)
			{
				que[tail].r_ = spilt - 1;
		//		printf("%lld ",spilt);
				que[++tail] = (qaq){spilt,n,i};			
			}
		}
	}

	inline void Main()
	{
		memset(fp,0x3f,sizeof(fp));
		fp[1][0] = 0;
		for(int i=1;i<=n;i++) fp[1][i] = calc(1,i);
		for(int i=2;i <= k;i++)
		{
			solve1(i);
//			for(int j=1;j<=n;j++)
//			printf("%lld ",fp[i][j]);
//			printf("
");
		}
		printf("%lld",fp[k][n]);
	}
}

inline void judging(int x)
{
	sakura = x;
	solve();
}

signed main()
{
	read(n),read(k);
	for(int i=1;i<=n;i++)
	{
		read(a[i]);
		sum[i] = sum[i - 1] + a[i];
	}	
	int l = -1e8,r = 1e8,Ans(0);
	 while(l <= r)
	 {
	 	int mid = l + r >> 1;
	 	judging(mid);
//	 	printf("%lld %lld
",dp[n].f,dp[n].fanga);
	 	if(dp[n].fanga <= k)
	 	{
	 		Ans = dp[n].f - k * sakura;
//			printf("%lld
",Ans);
	 		r = mid - 1;
	 	}
	 	else l = mid + 1;
	 }
	printf("%lld",Ans);
}

据说这个东西还能用来优化个模拟费用流啥的,笑死,根本不会费用流。

原文地址:https://www.cnblogs.com/-Iris-/p/15340311.html