【NOI2014】购票(树形dp+树剖+斜率优化)

考虑树形 dp,设 d p i dp_i dpi i i i 节点到 SZ 市的最小费用, d i s i dis_i disi 为 SZ 市到 i i i 节点的距离。

显然初始化 d p 1 = 0 dp_1=0 dp1=0,然后 d i s i dis_i disi 可以提前预处理出来。

然后有 d p u = min ⁡ ( d p v + ( d i s u − d i s v ) × p u + q u ) dp_u=min(dp_v+(dis_u-dis_v) imes p_u+q_u) dpu=min(dpv+(disudisv)×pu+qu)。( v v v u u u 的祖先, d i s u − d i s v ≤ l u dis_u-dis_vleq l_u disudisvlu)。

拆开来: d p u = d i s u × p u + q u + min ⁡ ( − d i s v × p u + d p v ) dp_u=dis_u imes p_u+q_u+min(-dis_v imes p_u+dp_v) dpu=disu×pu+qu+min(disv×pu+dpv)

发现拆出来跟 v v v 有关系的只有 − d i s v × p u + d p v -dis_v imes p_u+dp_v disv×pu+dpv

然后看到这个东西就联想到斜率优化。

那么思路就显而易见了:先对原树进行树剖,然后对于每个点 i i i 维护一条直线 f i ( x ) = − d i s i × x + d p i f_i(x)=-dis_i imes x+dp_i fi(x)=disi×x+dpi

至于如何维护这些直线,我们可以在每个线段树节点开一个 vector,那么可以证明线段树总的空间复杂度是 O ( n log ⁡ n ) O(nlog n) O(nlogn),因为每一层只有 n n n 条直线,一共有 log ⁡ n log n logn 层。

然后要不断插入一些直线,再用类似半平面交的方法维护下凸壳。

至于 l l l 的限制,我们可以先用倍增求出来每一个点最远能跳到哪个点。

顺便再推荐一道类似的好题:陶陶的难题2。

代码如下:(有详细注释)

#include<bits/stdc++.h>

#define LN 18
#define N 200010
#define ll long long
#define INF 0x7fffffffffffffff

using namespace std;

struct Point
{
	double x,y;
	Point(){};
	Point(double xx,double yy){x=xx,y=yy;}
};

struct Line
{
	ll k,b;//直线的斜率、截距
	Line(){};
	Line(ll kk,ll bb){k=kk,b=bb;}
	ll get(ll x)
	{
		return k*x+b;
	}
};

Point intersection(Line a,Line b)//两直线求交点
{
	double x=1.0*(b.b-a.b)/(a.k-b.k);
	double y=1.0*a.k*x+a.b;
	return Point(x,y);
}

bool onright(Point a,Line l)//判断一个点是否在这条直线的上方
{
	return 1.0*l.k*a.x-a.y+l.b<=0;
}

int n,subtask;
int cnt,head[N],nxt[N],to[N];
int tot,size[N],d[N],son[N],id[N],rk[N],top[N],farest[N];
int fa[N][LN];
ll p[N],q[N],limit[N],dis[N],w[N],dp[N],sum[N][LN];//有很多地方要开long long,注意

vector<Line>line[N<<2];

void adde(int u,int v,ll wi)
{
	to[++cnt]=v;
	w[cnt]=wi;
	nxt[cnt]=head[u];
	head[u]=cnt;
}

void dfs(int u)
{
	size[u]=1;
	for(int i=1;i<=17;i++)
	{
		fa[u][i]=fa[fa[u][i-1]][i-1];
		sum[u][i]=sum[u][i-1]+sum[fa[u][i-1]][i-1];
	}
	int now=u;
	for(int i=17;i>=0;i--)//倍增预处理最远能跳到哪个点
	{
		if(fa[now][i]&&limit[u]>=sum[now][i])
		{
			limit[u]-=sum[now][i];
			now=fa[now][i];
		}
	}
	farest[u]=now;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
        d[v]=d[u]+1;
		dis[v]=dis[u]+w[i];
		sum[v][0]=w[i];
		dfs(v);
		size[u]+=size[v];
		if(size[v]>size[son[u]]) son[u]=v;
	}
}

void dfs1(int u,int tp)
{
	top[u]=tp;
	id[u]=++tot;
	rk[tot]=u;
	if(son[u]) dfs1(son[u],tp);
	for(int i=head[u];i;i=nxt[i])
		if(to[i]!=son[u])
			dfs1(to[i],to[i]);
}

void half(int k,Line tmp)//插入新直线
{
	int size=line[k].size();
	while(size>=2&&onright(intersection(line[k][size-2],line[k][size-1]),tmp))//记得加size>=2
	{
		size--;
		line[k].pop_back();
	}
	line[k].push_back(tmp);
}

void insert(int k,int l,int r,int qx,Line tmp)
{
	half(k,tmp);
	if(l==r) return;
	int mid=(l+r)>>1;
	if(qx<=mid) insert(k<<1,l,mid,qx,tmp);
	else insert(k<<1|1,mid+1,r,qx,tmp);
}

ll query(int k,int l,int r,int ql,int qr,ll x)
{
	if(ql<=l&&r<=qr)
	{
		int siz=line[k].size()-1,L=0,R=siz,ans;
		while(L<=R)//二分枚举x在哪条直线上
		{
			int mid=(L+R)>>1;
			double lp=mid>0?intersection(line[k][mid-1],line[k][mid]).x:-INF;//算出直线mid与直线mid-1的交点
			double rp=mid<siz?intersection(line[k][mid],line[k][mid+1]).x:INF;//算出直线mid与直线mid+1的交点
			if(lp<=x&&x<=rp)
			{
				ans=mid;
				break;
			}
			if(x<lp) R=mid-1;
			else L=mid+1;
		}
		return line[k][ans].get(x);//将x代入得到y值
	}
	int mid=(l+r)>>1;
	ll ans=INF;
	if(ql<=mid) ans=min(ans,query(k<<1,l,mid,ql,qr,x));
	if(qr>mid) ans=min(ans,query(k<<1|1,mid+1,r,ql,qr,x));
	return ans;
}

void solve(int u)
{
	insert(1,1,n,id[u],Line(-dis[u],dp[u]));//将直线插入线段树内
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i],a=u,b=farest[v];
		ll ans=INF;
		while(top[a]!=top[b])//树剖求dp
		{
			ans=min(ans,query(1,1,n,id[top[a]],id[a],p[v]));
			a=fa[top[a]][0];
		}
		ans=min(ans,query(1,1,n,id[b],id[a],p[v]));
		dp[v]=ans+dis[v]*p[v]+q[v];
		solve(v);
	}
}

int main()
{
	scanf("%d%d",&n,&subtask);
	for(int i=2;i<=n;i++)
	{
		ll w;
		scanf("%d%lld%lld%lld%lld",&fa[i][0],&w,&p[i],&q[i],&limit[i]);
		adde(fa[i][0],i,w);
	}
	dfs(1);
	dfs1(1,1);
	solve(1);
	for(int i=2;i<=n;i++)
		printf("%lld
",dp[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/ez-lcw/p/14448668.html