[BZOJ3991][SDOI2015]寻宝游戏

BZOJ
Luogu

sol

用set维护有宝物的点集。
可以证明行走路径(a[1],a[2]...a[n])一定是按照点的dfs序排列。
因为(dist(u,v)=dep[u]+dep[v]+2*dep[lca(u,v)]),dfs序相邻可以最小化(dep[lca(u,v)])
插入要加上两个新贡献并减去一个旧贡献,删除也是减去两个旧贡献并加上一个新贡献。

code

#include<cstdio>
#include<algorithm>
#include<set>
using namespace std;
#define ll long long
int gi()
{
    int x=0,w=1;char ch=getchar();
    while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
    if (ch=='-') w=0,ch=getchar();
    while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return w?x:-x;
}
const int N = 1e5+5;
struct edge{int to,next,w;}a[N<<1];
int n,m,head[N],cnt,fa[N],dep[N],sz[N],son[N],top[N],dfn[N];ll dis[N],ans;
struct node{
	int u;
	bool operator < (const node &b) const
		{return dfn[u]<dfn[b.u];}
};
set<node>S;
void link(int u,int v,int w){a[++cnt]=(edge){v,head[u],w};head[u]=cnt;}
void dfs1(int u,int f)
{
	fa[u]=f;dep[u]=dep[f]+1;sz[u]=1;
	for (int e=head[u];e;e=a[e].next)
	{
		int v=a[e].to;if (v==f) continue;
		dis[v]=dis[u]+a[e].w;
		dfs1(v,u);
		sz[u]+=sz[v];if (sz[v]>sz[son[u]]) son[u]=v;
	}
}
void dfs2(int u,int up)
{
	top[u]=up;dfn[u]=++cnt;
	if (son[u]) dfs2(son[u],up);
	for (int e=head[u];e;e=a[e].next)
		if (a[e].to!=fa[u]&&a[e].to!=son[u])
			dfs2(a[e].to,a[e].to);
}
int lca(int u,int v)
{
    while (top[u]^top[v])
    {
        if (dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    return dep[u]<dep[v]?u:v;
}
ll dist(int u,int v){return dis[u]+dis[v]-dis[lca(u,v)]*2;}
set<node>::iterator pre(set<node>::iterator t)
{
	if (t==S.begin()) t=S.end();
	t--;
	return t;
}
set<node>::iterator nxt(set<node>::iterator t)
{
	t++;
	if (t==S.end()) t=S.begin();
	return t;
}
set<node>::iterator t,tl,tr;
int main()
{
	n=gi();m=gi();
	for (int i=1,u,v,w;i<n;++i)
	{
		u=gi();v=gi();w=gi();
		link(u,v,w);link(v,u,w);
	}
	dfs1(1,0);cnt=0;dfs2(1,1);
	while (m--)
	{
		int u=gi();
		if (S.find((node){u})==S.end())
		{
			S.insert((node){u});
			t=S.find((node){u});
			tl=pre(t);tr=nxt(t);
			ans+=dist(u,(*tl).u)+dist(u,(*tr).u);
			ans-=dist((*tl).u,(*tr).u);
		}
		else
		{
			if (S.size()==1) {S.erase((node){u});puts("0");continue;}
			t=S.find((node){u});
			tl=pre(t);tr=nxt(t);
			ans-=dist(u,(*tl).u)+dist(u,(*tr).u);
			ans+=dist((*tl).u,(*tr).u);
			S.erase((node){u});
		}
		printf("%lld
",ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/zhoushuyu/p/8463658.html