线段树合并

线段树合并

应用范围:将子树的信息合并给父亲节点,并且权值线段树的下标值域和节点数相近。

CF600E Lomsat gelral

题目链接

题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。

(1 <= n<=1e5)

解法:线段树合并,这个东西的时空复杂度都很玄学,姑且认为时间为(O(nlogn)),空间为(常数( imes log_n imes n)),常数一般为(4-8)

#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
	nxt[++tot] = fir[u];
	fir[u] = tot;
	vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
	if(sum[lz[a]] < sum[rz[a]])
	{
		sum[a] = sum[rz[a]];
		ans[a] = ans[rz[a]];
	}
	if(sum[lz[a]] > sum[rz[a]])
	{
		sum[a] = sum[lz[a]];
		ans[a] = ans[lz[a]];
	}
	if(sum[lz[a]] == sum[rz[a]])
	{
		sum[a] = sum[lz[a]];
		ans[a] = ans[lz[a]] + ans[rz[a]];
	}
	return;
}
int merge(int a, int b, int l, int r)
{
	if(a == 0) return b;
	if(b == 0) return a;
	if(l == r)
	{
		sum[a] += sum[b];
		ans[a] = l;
		return a;
	}
	int mid = (l + r) >> 1;
	lz[a] = merge(lz[a], lz[b], l, mid);
	rz[a] = merge(rz[a], rz[b], mid + 1, r);
	pushup(a);
	return a;
}
void update(int &a, int l, int r, int v)
{
	if(!a) a = ++cnt;
	int mid = (l + r) >> 1;
	if(l == r)
	{
		sum[a] += 1;
        ans[a] = l;
		return;
	}
	if(mid >= v) update(lz[a], l, mid, v);
	if(mid < v) update(rz[a], mid + 1, r, v); 
	pushup(a);
	return;
}
void dfs(int u, int fa)
{
	for(int i = fir[u]; i; i = nxt[i])
	{
		int v = vv[i];
		if(v == fa) continue;
		dfs(v, u);
		merge(root[u], root[v], 1, 100000);
	}
	update(root[u], 1, 100000, col[u]);
    ans[u] = ans[root[u]];
}
int main()
{
	scanf("%d", &n); cnt = n;
	for(int i = 1; i <= n; i++) 
	{
		scanf("%d", &col[i]); root[i] = i;
	}
	for(int i = 1; i < n; i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v); add(v, u);
	}
	dfs(1, 0);
	for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
	return 0;
}

雨天的尾巴

题目链接

注意(ans)要在(dfs)时计算,不然当前节点的(root)可能被父亲节点继承,然后就炸了。

#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
	nxt[++tot] = fir[u];
	fir[u] = tot;
	vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
	if(sum[lz[a]] < sum[rz[a]])
	{
		sum[a] = sum[rz[a]];
		ans[a] = ans[rz[a]];
	}
	if(sum[lz[a]] > sum[rz[a]])
	{
		sum[a] = sum[lz[a]];
		ans[a] = ans[lz[a]];
	}
	if(sum[lz[a]] == sum[rz[a]])
	{
		sum[a] = sum[lz[a]];
		ans[a] = ans[lz[a]] + ans[rz[a]];
	}
	return;
}
int merge(int a, int b, int l, int r)
{
	if(a == 0) return b;
	if(b == 0) return a;
	if(l == r)
	{
		sum[a] += sum[b];
		ans[a] = l;
		return a;
	}
	int mid = (l + r) >> 1;
	lz[a] = merge(lz[a], lz[b], l, mid);
	rz[a] = merge(rz[a], rz[b], mid + 1, r);
	pushup(a);
	return a;
}
void update(int &a, int l, int r, int v)
{
	if(!a) a = ++cnt;
	int mid = (l + r) >> 1;
	if(l == r)
	{
		sum[a] += 1;
        ans[a] = l;
		return;
	}
	if(mid >= v) update(lz[a], l, mid, v);
	if(mid < v) update(rz[a], mid + 1, r, v); 
	pushup(a);
	return;
}
void dfs(int u, int fa)
{
	for(int i = fir[u]; i; i = nxt[i])
	{
		int v = vv[i];
		if(v == fa) continue;
		dfs(v, u);
		merge(root[u], root[v], 1, 100000);
	}
	update(root[u], 1, 100000, col[u]);
    ans[u] = ans[root[u]];
}
int main()
{
	scanf("%d", &n); cnt = n;
	for(int i = 1; i <= n; i++) 
	{
		scanf("%d", &col[i]); root[i] = i;
	}
	for(int i = 1; i < n; i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v); add(v, u);
	}
	dfs(1, 0);
	for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/Akaina/p/11843252.html