启发式合并

启发式合并

先看看什么是启发式算法。

启发式算法可以这样定义:一个基于直观或经验构造的算法,在可接受的花费(指计算时间和空间)下给出待解决组合优化问题每一个实例的一个可行解,该可行解与最优解的偏离程度一般不能被预计。现阶段,启发式算法以仿自然体算法为主,主要有蚁群算法、模拟退火法、神经网络等。

(from) 百度百科

再来看看启发式合并

这个东西的原理很简单,就是你考虑合并 (2) 个数据结构,如果直接合并,复杂度去最坏情况,为 (O(n^2))

但是,你只需要记录一下 (size) 然后每次跑的时候把 (size) 小的合并到 (size) 大的。

这个算法感觉和之前的暴力也没什么区别就吧。

但是,你会发现每个元素最多合并 (log_m) 次,(n) 个元素,最坏复杂度 (O(n * log_n))

至于是为什么,不多赘述。因为不会。。

其实启发式合并和线段树合并一样,只是一种工具,一种优化。

不过,启发式合并可以适用于各种不同的数据结构,比如 (set)(splay) 等等。

伪代码就不放了,因为它适用面太广了,没有什么固定的模板。

例题:

P3201 [HNOI2009] 梦幻布丁

题目大意

你有 (n) 个布丁,每个布丁都有它的颜色,一共有 (m) 次操作。

每次操作可以:

第一,把所有颜色为 (x) 的布丁变成颜色为 (y) 的布丁。

第二,问当前一共有多少段颜色。

题解

对于每个颜色维护一个数据结构,在进行操作 (1) 的时候,就把 (x)(y) 所在数据结构进行启发式合并。

思路很简单,但是代码实现却需要处理很多细节。

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <queue>
#define maxn 200020
#define ls x << 1
#define rs x << 1 | 1
#define inf 0x3f3f3f3f
#define inc(i) (++ (i))
#define dec(i) (-- (i))
#define mid ((l + r) >> 1)
// #define int long long
#define XRZ 1000000003
#define debug() puts("XRZ TXDY");
#define mem(i, x) memset(i, x, sizeof(i));
#define Next(i, u) for(register int i = head[u]; i ; i = e[i].nxt)
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout);
#define Rep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i <= i##Limit ; inc(i))
#define Dep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i >= i##Limit ; dec(i))
int dx[10] = {1, -1, 0, 0};
int dy[10] = {0, 0, 1, -1};
using namespace std;
inline int read() {
    register int x = 0, f = 1; register char c = getchar();
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
} int ans, a[maxn], head[maxn], nxt[maxn], num[maxn], S[maxn], fa[maxn];
void merge(int x, int y) {
	for(int i = head[x]; i; i = nxt[i]) ans -= (a[i - 1] == y) + (a[i + 1] == y);
	for(int i = head[x]; i; i = nxt[i]) a[i] = y;
	nxt[S[x]] = head[y], head[y] = head[x], num[y] += num[x];
	head[x] = S[x] = num[x] = 0;
}
signed main() { int n = read(), m = read();
	Rep(i, 1, n) { a[i] = read();
		fa[a[i]] = a[i], ans += a[i] != a[i - 1];
		if(head[a[i]] == 0) S[a[i]] = i;
		inc(num[a[i]]); nxt[i] = head[a[i]], head[a[i]] = i;
	} Rep(i, 1, m) { int opt = read();
		if(opt == 1) { int x = read(), y = read();
			if(x == y) continue;
			if(num[fa[x]] > num[fa[y]]) swap(fa[x], fa[y]);
			if(num[fa[x]] == 0) continue;
			merge(fa[x], fa[y]);
		} else printf("%d
", ans);
	}
	return 0;
}

P5290 [十二省联考2019] 春节十二响

题目描述

给你 (n) 个节点,取出每个节点都要付出相应的代价。

你可以一次取多个 不在 一条从 (1) 出发的链 的节点一起取出,代价为其中最大的。

题解

首先考虑部分分,

如果你只有链的情况。

考虑贪心,最大的只能和另一条链上最大的匹配。

不断配对就可以了。

再扩展到树上。

是不是对于每 (2) 条链做一次合并,最后的复杂度是 (O(n ^ 2))

这个是不可以过的,所以我们需要考虑优化,既然是在启发式合并里面,那就显然是用启发式合并去优化这个暴力合并。

记录 (size) 把小的合并到大的即可,复杂度 (O(n * log_n))

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <queue>
#define maxn 200001
#define ls x << 1
#define rs x << 1 | 1
#define inf 0x3f3f3f3f
#define inc(i) (++ (i))
#define dec(i) (-- (i))
#define mid ((l + r) >> 1)
#define int long long
#define XRZ 1000000003
#define debug() puts("XRZ TXDY");
#define mem(i, x) memset(i, x, sizeof(i));
#define Next(i, u) for(register int i = head[u]; i ; i = e[i].nxt)
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout);
#define Rep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i <= i##Limit ; inc(i))
#define Dep(i, a, b) for(register int i = (a) , i##Limit = (b) ; i >= i##Limit ; dec(i))
int dx[10] = {1, -1, 0, 0};
int dy[10] = {0, 0, 1, -1};
using namespace std;
inline int read() {
    register int x = 0, f = 1; register char c = getchar();
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
} int Ans, a[maxn];//, head[maxn];
priority_queue<int> Q[maxn]; vector<int> s, qwq[maxn];
// struct node { int nxt, to;} e[maxn << 1];
// void add(int x, int y) { e[inc(cnt)] = (node) {head[x], y}; head[x] = cnt;}
void merge(int x, int y) {
	if(Q[x].size() < Q[y].size()) swap(Q[x], Q[y]);
	while(Q[y].size()) {
		s.push_back(max(Q[x].top(), Q[y].top()));
		Q[x].pop(), Q[y].pop();
	} while(s.size()) Q[x].push(s.back()), s.pop_back();
}
void Dfs(int x) {
	// Next(i, x) { int v = e[i].to; Dfs(v); merge(x, v);}
	Rep(i, 0, qwq[x].size() - 1) Dfs(qwq[x][i]), merge(x, qwq[x][i]);
	Q[x].push(a[x]);
}
signed main() { int n = read();
	Rep(i, 1, n) a[i] = read();
	Rep(i, 2, n) { int u = read(); qwq[u].push_back(i); }
	// Rep(i, 2, n) { int u = read(); add(i, u);}
	Dfs(1); while(Q[1].size()) Ans += Q[1].top(), Q[1].pop();
	printf("%lld", Ans);
	return 0;
}
原文地址:https://www.cnblogs.com/Flash-plus/p/13834039.html