树的统计

树的统计

思路

用线段树维护单点修改,区间查询最大值,区间和

注意

记得修改的点对应的下标在线段树上是(id[u]),因为我们的线段树是按(dfn)形成的

和其他的树链剖分题是一样的

注意

当查询区间最大值的时候,一定记得先把下标设成负无穷,因为我们的答案有可能是负数

代码

#include<bits/stdc++.h>
using namespace std;
const int N=200010;
int ne[N],head[N],ver[N],idx;
int id[N],cnt,nw[N];
int top[N],son[N],sz[N],dep[N],fa[N];
int n,m;
int w[N];
string s;
const int inf=0x3f3f3f3f;
void add(int u,int v)
{
	ne[idx]=head[u];
	ver[idx]=v;
	head[u]=idx;
	idx++;
}

void dfs1(int u,int father,int depth)
{
	fa[u]=father;
	dep[u]=depth;
	sz[u]=1;
	for(int i=head[u];i!=-1;i=ne[i])
	{
		int j=ver[i];
		if(j==father)continue;
		dfs1(j,u,depth+1);
		sz[u]+=sz[j];
		if(sz[son[u]]<sz[j]) son[u]=j;
	}
}

void dfs2(int u,int t)
{
	top[u]=t;
	id[u]=++cnt;
	nw[cnt]=w[u];
	if(!son[u])return ;
	dfs2(son[u],t);
	for(int i=head[u];i!=-1;i=ne[i])
	{
		int j=ver[i];
		if(j==fa[u]||j==son[u]) continue;
		dfs2(j,j);
	}
}
struct node{
	int l,r;
	long long sum,maxx;
}tr[N*4];

void pushup(int p)
{
	tr[p].maxx=max(tr[p<<1].maxx,tr[p<<1|1].maxx);
	tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum;
}

void build(int p,int l,int r)//建树
{
	tr[p]={l,r,nw[r],nw[r]};
	if(l==r)return ;
	int mid=(l+r)/2;
	build(p<<1,l,mid);
	build(p<<1|1,mid+1,r);
	pushup(p);//记得pushup
}


void update(int p,int x,int d)
{
	if(tr[p].l==tr[p].r) {tr[p].maxx=d,tr[p].sum=d;return ;}
	int mid=(tr[p].l+tr[p].r)/2;
	if(x<=mid) update(p<<1,x,d);
	if(x>mid) update(p<<1|1,x,d);
	pushup(p);//记得pushup
}

long long query_max(int p,int l,int r)//区间查询最大值
{
	long long ans=-inf;
	if(tr[p].l>=l&&tr[p].r<=r)
	{
		return tr[p].maxx;
	}
	int mid=(tr[p].l+tr[p].r)/2;
	if(l<=mid) ans=max(ans,query_max(p<<1,l,r));
	if(r>mid) ans=max(ans,query_max(p<<1|1,l,r));
	return ans;
}

long long query_sum(int p,int l,int r)//区间查询和
{
	long long ans=0;
	if(tr[p].l>=l&&tr[p].r<=r)
	{
		return tr[p].sum;
	}
	int mid=(tr[p].l+tr[p].r)/2;
	if(l<=mid) ans+=query_sum(p<<1,l,r);
	if(r>mid) ans+=query_sum(p<<1|1,l,r);
	return ans;
}

long long tree_query_max(int u,int v)//链上查询最大值
{
	long long ans=-inf;
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])
		 swap(u,v);
		 ans=max(ans,query_max(1,id[top[u]],id[u]));
		 u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);
	ans=max(ans,query_max(1,id[v],id[u]));
	return ans;
}

long long tree_query_sum(int u,int v)//链上查询区间和
{
	long long ans=0;
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])
			swap(u,v);
		ans+=query_sum(1,id[top[u]],id[u]);
		u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);
	ans+=query_sum(1,id[v],id[u]);
	return ans;
}
inline int read()
{
	int x=0;
	int f=1;
	char ch;
	ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')
		f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=x*10,x=x+ch-'0';
		ch=getchar();
	}
	return x*f;
}

int main()
{
//	freopen("1.in","r",stdin);
//	freopen("res1.out","w",stdout);
	memset(head,-1,sizeof(head));
	n=read();
	for(int i=1;i<=n-1;i++)
	{
		int a,b;
		a=read();
		b=read();
		add(a,b);
		add(b,a);
	}
	for(int i=1;i<=n;i++)
		w[i]=read();
	dfs1(1,-1,1);
	dfs2(1,1);
	build(1,1,n);
	m=read();
	for(int i=1;i<=m;i++)
	{
		int a,b;
		cin>>s;
		a=read();
		b=read();
		if(s=="QMAX")
		{
			cout<<tree_query_max(a,b)<<endl;	
		}
		else if(s=="QSUM")
		{
			cout<<tree_query_sum(a,b)<<endl;
		}
		else {
			update(1,id[a],b);
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/bangdexuanyuan/p/13926549.html