题解-[CEOI2017]Building Bridges

[CEOI2017]Building Bridges

(n) 个桥墩,高 (h_i)(w_i)。连接 (i)(j) 消耗代价 ((h_i-h_j)^2),用不到的桥墩被拆除,代价为 (w_i)。求使 (1)(n) 联通的最小代价。

数据范围:(2le nle 10^5)(0le h_i,|w_i|le 10^6)


非常经典的李超线段树维护 ( exttt{dp}) 的题目,小蒟蒻来分享一下。


很明显 (w_i) 是大片大片消耗的,所以记 (s_i=sum_{j=1}^i w_j)

(f_i) 表示连接到第 (i) 个桥墩的最小代价。可以野蛮推式:

[egin{split} f_i=&min{f_j+(h_i-h_j)^2+s_{i-1}-s_j}\ =&min{f_j+h_i^2-2h_ih_j+h_j^2+s_{i-1}-s_j}\ =&h_i^2+s_{i-1}+min{f_j-2h_ih_j+h_j^2-s_j}\ end{split}\ ]

这貌似是个斜率优化式子,但蒟蒻不管,用李超线段树怎么做呢?

考虑李超线段树的作用:多条线段(直线),求单点最值。

发现这个 (j) 有很多,而 (i) 就只有当前一个:所以可以 (i) 对应单点,(j) 对应线。换句话说,可以把每个 (f_i) 求出来后添加一条直线

[ extrm{let line } g(x)=(-2h_j)x+(f_j+h_j^2-s_j)\ extrm{let } x=h_i extrm{ to get }min{f_j-2h_ih_j+h_j^2-s_j} extrm{.}\ ]


这题有几个坑,本来是应该由你来快乐地调试的,但是既然写了题解,蒟蒻就放出来了:

  1. 因为要计算 (h_j^2),所以要开 ( exttt{long long}) 或用 (1ll) 乘之。
  2. 这个李超线段树是权值线段树,下标要开 (10^6) 个,节点个数要开 (4cdot 10^6) 个。

这个做法貌似有点辜负了这题的难度,但是蒟蒻只会这么做。蒟蒻讲不清楚,还是放个蒻蒻的代码吧:

#include <bits/stdc++.h>
using namespace std;

//Start
#define lng long long
#define db double
#define mk make_pair
#define pb push_back
#define fi first
#define se second
#define rz resize
const int inf=0x3f3f3f3f;
const lng INF=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1e5,M=1e6;
int n,h[N+7];
lng w[N+7],f[N+7];

//Lichaotree
typedef pair<lng,lng> line;
lng g(line&li,int x){return li.fi*x+li.se;}
int inter(line&la,line&lb){return db(lb.se-la.se)/(la.fi-lb.fi);}
line v[(M<<2)+7];
void add(line li,int k=1,int l=0,int r=M){
	int mid((l+r)>>1);
	lng ly1=g(li,l),ry1=g(li,r),ly=g(v[k],l),ry=g(v[k],r);
	if(ly1>=ly&&ry1>=ry);
	else if(ly1<=ly&&ry1<=ry) v[k]=li;
	else {
		int in=inter(li,v[k]);
		if(ly1<=ly){
			if(in<=mid) add(li,k<<1,l,mid);
			else add(v[k],k<<1|1,mid+1,r),v[k]=li;
		} else {
			if(in>mid) add(li,k<<1|1,mid+1,r);
			else add(v[k],k<<1,l,mid),v[k]=li;
		}
	}
}
lng get(int x,int k=1,int l=0,int r=M){
	lng res(g(v[k],x));
	if(l==r) return res;
	int mid((l+r)>>1);
	if(mid>=x) res=min(res,get(x,k<<1,l,mid));
	else res=min(res,get(x,k<<1|1,mid+1,r));
	return res;
}

//Main
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&h[i]);
	for(int i=1;i<=n;i++) scanf("%lld",&w[i]),w[i]+=w[i-1];
	fill(v+1,v+(M<<2)+1,mk(0,INF));
	f[1]=0,add(mk(-2ll*h[1],1ll*h[1]*h[1]-w[1]));
	for(int i=2;i<=n;i++){
		f[i]=1ll*h[i]*h[i]+w[i-1]+get(h[i]);
		add(mk(-2ll*h[i],f[i]+1ll*h[i]*h[i]-w[i]));
	}
	printf("%lld
",f[n]);
	return 0;
}

祝大家学习愉快!

原文地址:https://www.cnblogs.com/George1123/p/12803378.html