算法学习————线段树合并

一、线段树合并的思想

线段树合并,顾名思义,就是建立一棵新的线段树保存原有的两颗线段树的信息。

二、线段树合并的流程

假设我们合并到了两棵树的pos位置

  1. 如果a有pos位置,b没有,那么新的线段树pos位置赋成a,返回

  2. 如果b有pos位置,a没有,赋成b,返回

  3. 如果此时已经合并到两棵线段树的叶子节点了,就把b在pos的值加到a上,把新线段树上的pos位置赋成a,返回

  4. 递归处理左子树

  5. 递归处理右子树

  6. 用左右子树的值更新当前节点

  7. 将新线段树上的pos位置赋成a,返回

代码:

int merge(int a,int b,int l,int r){
	if (!a) return b;
	if (!b) return a;
	int res = a;
	if (l == r){
		t[res].sum = t[a].sum+t[b].sum;
		t[res].ans = l;
		return res; 
	}
	int mid = (l+r >> 1);
	t[res].l = merge(t[a].l,t[b].l,l,mid);
	t[res].r = merge(t[a].r,t[b].r,mid+1,r);
//	cout<<l<<" "<<r<<endl;
//	cout<<"merge = "<<a<<" "<<t[a].sum<<" "<<t[a].ans<<" "<<b<<" "<<t[b].sum<<" "<<t[b].ans<<endl;
//	cout<<"res = "<<res<<" "<<t[res].l<<" "<<t[res].r<<endl;
	pushup(res);
	return res;
}

例题:CF600E Lomsat gelral

线段树怎么维护呢??

建一棵权值线段树,线段树上的每个节点维护两个值,当前区间颜色出现的最大次数,和出现次数为最大次数的颜色的和

这样每次更新的时候,如果左儿子的颜色出现的最大次数大,直接等于左儿子,但注意不要修改左右两个儿子的指针

如果右儿子的颜色出现的最大次数大,直接等于右儿子,如果相等,则把颜色的和更新

每次在dfs回溯的时候把儿子和自己的合并,最后再把自己加进去

代码:

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#define ll long long
#define B cout<<"Breakpoint"<<endl;
#define O(x) cout<<#x<<" "<<x<<endl;
#define o(x) cout<<#x<<" "<<x<<" ";
using namespace std;
int read(){
	int x = 1,a = 0;char ch = getchar();
	while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();} 
	while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
	return x*a; 
}
const int maxn = 2e5+10;
int n,col[maxn];
struct node{
	int to,nxt;
}ed[maxn << 1];
int head[maxn],tot;
void add(int u,int to){
	ed[++tot].to = to;
	ed[tot].nxt = head[u];
	head[u] = tot;
}
struct SEGTree{
	int l,r,sum;
	ll ans;
}t[maxn << 2];
void pushup(int x){
	int l = t[x].l,r = t[x].r;
//	cout<<"pushup =  = "<<l<<" "<<r<<" "<<t[l].sum<<" "<<t[x].ans<<" "<<t[r].sum<<" "<<t[r].ans<<endl;
	if (t[l].sum > t[r].sum){
		t[x].ans = t[l].ans;
		t[x].sum = t[l].sum;
	}
	if (t[r].sum > t[l].sum){
		t[x].ans = t[r].ans;
		t[x].sum = t[r].sum;
	}
	if (t[l].sum == t[r].sum){
		t[x].sum = t[l].sum;
		t[x].ans = t[l].ans+t[r].ans; 
	}
//	cout<<"pushup = "<<t[x].sum<<" "<<t[x].ans<<endl;
}
int cnt;
void modify(int &x,int lst,int l,int r,int p,int k){
	if (!x) x = ++cnt;
	if (l == r){
		t[x].sum = t[lst].sum+1;
		t[x].ans = l;
//		cout<<"modify = "<<x<<" "<<l<<" "<<t[x].sum<<" "<<t[x].ans<<endl;
		return;
	}
	int mid = (l+r >> 1);
	if (p <= mid) t[x].r = t[lst].r,modify(t[x].l,t[lst].l,l,mid,p,k);
	else t[x].l = t[lst].l,modify(t[x].r,t[lst].r,mid+1,r,p,k);
//	cout<<x<<" "<<t[x].l<<" "<<t[x].r<<endl;
	pushup(x);
}
int merge(int a,int b,int l,int r){
	if (!a) return b;
	if (!b) return a;
	int res = a;
	if (l == r){
		t[res].sum = t[a].sum+t[b].sum;
		t[res].ans = l;
		return res; 
	}
	int mid = (l+r >> 1);
	t[res].l = merge(t[a].l,t[b].l,l,mid);
	t[res].r = merge(t[a].r,t[b].r,mid+1,r);
//	cout<<l<<" "<<r<<endl;
//	cout<<"merge = "<<a<<" "<<t[a].sum<<" "<<t[a].ans<<" "<<b<<" "<<t[b].sum<<" "<<t[b].ans<<endl;
//	cout<<"res = "<<res<<" "<<t[res].l<<" "<<t[res].r<<endl;
	pushup(res);
	return res;
}
int root[maxn];
ll ans[maxn];
void dfs(int x,int fa){
	for (int i = head[x];i;i = ed[i].nxt){
		int to = ed[i].to;
		if (to == fa) continue;
		dfs(to,x);
		root[x] = merge(root[x],root[to],1,n);
	} 
	modify(root[x],root[x],1,n,col[x],1);
	ans[x] = t[root[x]].ans;
}
int main(){
	n = read();
	for (int i = 1;i <= n;i++) col[i] = read();
	for (int i = 1;i < n;i++){
		int x = read(),y = read();
		add(x,y),add(y,x);
	}
	dfs(1,0);
	for (int i = 1;i <= n;i++) printf("%lld ",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/little-uu/p/14743781.html