【YbtOJ#763】攻城略池

题目

题目链接:https://www.ybtoj.com.cn/contest/120/problem/1

(nleq 10^5,l_ileq 10^3,d_ileq 10^8)

思路

(f_x) 是点 (x) 被攻占的时间。显然这个值可以二分,然后枚举子树内的每一个点,计算在二分到的时间内从枚举到的点可以过去多少人。
(mid) 时间内会被攻占当且仅当

[d_xleq sum_{yin ext{subtree}(x)}max(mid-f_y-( ext{dep}_y- ext{dep}_x),0) ]

把括号拆开来,考虑把 (x) 子树内的点扔到权值线段树上,权值线段树上的节点 ([l,r]) 储存所有 ( ext{dep}_y+f_yin[l,r]) 的点的权值之和以及数量。
然后二分可以直接在线段树上二分,当我们到达区间 ([l,r]) 时,记 (c) 为权值在 ([1,mid]) 的点的数量,(v) 为权值在 ([1,mid]) 的点的权值和,那么我们往右边二分当且仅当

[c imes mid-v<a_x ]

然后往上的时候线段树合并就可以了。
线段树值域上界是 (3 imes 10^8),动态开点就可以了。
时间复杂度 (O(nlog (d+nl)))

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=100010,LG=30,MAXN=N*LG*4,Lim=3e8;
int n,ans,tot,a[N],head[N],dep[N],rt[N],f[N];

struct edge
{
	int next,to,dis;
}e[N*2];

void add(int from,int to,int dis)
{
	e[++tot]=(edge){head[from],to,dis};
	head[from]=tot;
}

struct SegTree
{
	int lc[MAXN],rc[MAXN],cnt[MAXN];
	ll sum[MAXN];
	
	int merge(int x,int y)
	{
		if (!x || !y) return x|y;
		sum[x]+=sum[y]; cnt[x]+=cnt[y];
		lc[x]=merge(lc[x],lc[y]);
		rc[x]=merge(rc[x],rc[y]);
		return x;
	}
	
	int update(int x,int l,int r,int v)
	{
		if (!x) x=++tot;
		cnt[x]++; sum[x]+=v;
		if (l==r) return x;
		int mid=(l+r)>>1;
		if (v<=mid) lc[x]=update(lc[x],l,mid,v);
			else rc[x]=update(rc[x],mid+1,r,v);
		return x;
	}
	
	int query(int x,int l,int r,int k,int c,ll s)
	{
		if (l==r) return l;
		int mid=(l+r)>>1;
		ll ans=1LL*(cnt[lc[x]]+c)*mid-(sum[lc[x]]+s);
		if (ans>=a[k]) return query(lc[x],l,mid,k,c,s);
			else return query(rc[x],mid+1,r,k,c+cnt[lc[x]],s+sum[lc[x]]);
	}
}seg;

void dfs(int x,int fa)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			dep[v]=dep[x]+e[i].dis;
			dfs(v,x);
			rt[x]=seg.merge(rt[x],rt[v]);
		}
	}
	f[x]=max(seg.query(rt[x],0,Lim,x,0,0)-dep[x],0);
	rt[x]=seg.update(rt[x],0,Lim,f[x]+dep[x]);
	ans=max(ans,f[x]);
}

signed main()
{
	freopen("conquer.in","r",stdin);
	freopen("conquer.out","w",stdout);
//	return printf("%d
",sizeof(seg)/1024/1024),0;
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1,x,y,z;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z); add(y,x,z);
	}
	dfs(1,0);
	printf("%lld",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/stoorz/p/14426212.html