BZOJ3631(树链剖分)

差不多可以说是树链剖分的模板题了:对访问序列中每一对相邻节点之间的路径整体加一,最后对每个节点做单点查询即可。路径加用树状数组维护差分来实现;输出时除起点 a[1] 外,每个节点的答案都要减一。


#include <bits/stdc++.h>

using namespace std;

// Counting-loop shorthands; every macro declares its own int loop variable.
// REP(i,n): i runs over 0 .. n-1 in ascending order.
#define REP(i,n)                for (int i = 0; i < (n); i++)
// rep(i,a,b): i runs over a .. b inclusive in ascending order.
#define rep(i,a,b)              for (int i = (a); i <= (b); i++)
// dec(i,a,b): i runs from a down to b inclusive.
#define dec(i,a,b)              for (int i = (a); i >= (b); i--)
// for_edge(i,x): i walks the arc list of x, starting at H[x] and following
// X[i] until the terminating index 0 is reached.
#define for_edge(i,x)           for (int i = H[x]; i != 0; i = X[i])

// Type and member-name shorthands, kept as macros so existing uses still work.
#define LL      long long
#define ULL     unsigned long long
#define MP      make_pair
#define PB      push_back
#define FI      first
#define SE      second
// Parenthesized so the shift is evaluated as a unit. Without the parentheses
// an expression such as INF / 2 expands to 1 << 30 / 2, which is 1 << 15.
#define INF     (1 << 30)


// Capacity constants: each is a round limit plus a small safety margin.
// Only N is referenced by the code visible in this file.
const int N = 300010;   // 300000 + 10
const int M = 10010;    // 10000 + 10
const int Q = 1010;     // 1000 + 10
const int A = 31;       // 30 + 1

// Adjacency lists stored as linked arcs; arc ids start at 1 and 0 ends a list.
int E[2 * N];           // E[e]: vertex that arc e leads to
int H[2 * N];           // H[v]: most recently added arc of vertex v
int X[2 * N];           // X[e]: arc that follows e in the same list
int c[N];               // Fenwick tree storage used by query() and add()
int top[N];             // top[v]: first vertex of the chain that contains v
int fa[N];              // fa[v]: parent of v (0 for the root)
int deep[N];            // deep[v]: depth of v, the root has depth 1
int num[N];             // num[v]: number of vertices in the subtree of v
int son[N];             // son[v]: child whose chain v continues, -1 if none
int fp[N];              // fp[k]: vertex stored at position k (inverse of p)
int p[N];               // p[v]: position of v in chain order, 1..n
int et, pos;            // arcs created so far / positions handed out so far
int a[N];               // the sequence of n vertices read in main
int n, x, y;            // vertex count and two scratch endpoints

// Returns the value of the lowest set bit of x (0 when x is 0), e.g. 12 -> 4.
inline int lowbit(int x){
	return x & -x;
}

// Fenwick prefix sum: returns c-sum over positions 1..x. Because cover() stores
// a difference array, this prefix sum is the accumulated value at position x.
inline int query(int x){
	int total = 0;
	while (x != 0){
		total += c[x];
		x -= lowbit(x);     // drop the lowest set bit to reach the next node
	}
	return total;
}
// Fenwick point update: adds val at position x. Positions above n are ignored
// because the loop never runs for them; x is expected to be at least 1.
inline void add(int x, int val){
	while (x <= n){
		c[x] += val;
		x += lowbit(x);     // move to the next node covering this position
	}
}

// Inserts the undirected edge a-b as two arcs, each pushed onto the front of
// its source vertex's list.
inline void addedge(int a, int b){
	// arc a -> b
	++et;
	E[et] = b;
	X[et] = H[a];
	H[a] = et;

	// arc b -> a
	++et;
	E[et] = a;
	X[et] = H[b];
	H[b] = et;
}

// First traversal: fills deep[], fa[] and num[] for the subtree of x and picks
// son[x], the child with the largest subtree.
//   x   - vertex being visited
//   pre - its parent (0 for the root; deep[0] is 0, so the root gets depth 1)
// Expects son[] to have been filled with -1 beforehand (main does this).
void dfs(int x, int pre){
	deep[x] = deep[pre] + 1;
	fa[x] = pre;
	num[x] = 1;
	for_edge(i, x){
		int v = E[i];
		if (v == pre) continue;     // do not walk back up the tree
		dfs(v, x);
		num[x] += num[v];
		// Take v when no child has been chosen yet or when its subtree is
		// larger than the current choice. The -1 test has to come first so
		// that num[] is never indexed with -1.
		if (son[x] == -1 || num[v] > num[son[x]])
			son[x] = v;
	}
}

void getpos(int x, int sp){
	top[x] = sp;
	p[x] = ++pos;
	fp[p[x]] = x;
	if (son[x] == -1) return;
	getpos(son[x], sp);
	for_edge(i, x){
		int v = E[i];
		if (v != son[x] && v != fa[x]) 
			getpos(v, v);
	}
}

// Adds val to every vertex on the tree path between u and v.
// The Fenwick tree holds a difference array over positions, so adding val to
// the position range [l, r] is add(l, val) followed by add(r + 1, -val).
void cover(int u, int v, int val){
	int f1 = top[u], f2 = top[v];
	while (f1 != f2){
		// Always lift the side whose chain top is deeper.
		if (deep[f1] < deep[f2]){
			swap(f1, f2);
			swap(u, v);
		}
		// Positions p[f1]..p[u] are the part of this chain on the path.
		add(p[f1], val);
		add(p[u] + 1, -val);
		// Continue from the parent of the chain top.
		u = fa[f1];
		f1 = top[u];
	}

	// Both ends are on the same chain now; the shallower one comes first.
	if (deep[u] > deep[v]) swap(u, v);
	add(p[u], val);
	add(p[v] + 1, -val);
}

// Reads n, the sequence a[1..n] and n - 1 tree edges, adds 1 along the path
// between every pair of consecutive sequence entries, then prints one number
// per vertex in index order.
int main(){
#ifndef ONLINE_JUDGE
	freopen("test.txt", "r", stdin);
	freopen("test.out", "w", stdout);
#endif

	// Nothing to do when the vertex count cannot be read.
	if (scanf("%d", &n) != 1) return 0;
	rep(i, 1, n) scanf("%d", a + i);
	rep(i, 1, n - 1){
		scanf("%d%d", &x, &y);
		addedge(x, y);
	}

	// dfs() relies on -1 meaning "no child chosen yet".
	memset(son, -1, sizeof son);
	dfs(1, 0);
	getpos(1, 1);

	rep(i, 1, n - 1){
		x = a[i], y = a[i + 1];
		cover(x, y, 1);
	}

	// Every vertex except a[1] is reported as one less than its accumulated
	// count. NOTE(review): presumably this removes the double count of a stop
	// that ends one path and starts the next - confirm against the statement.
	rep(i, 1, n){
		int visits = query(p[i]);
		if (i != a[1]) --visits;
		printf("%d\n", visits);
	}

	return 0;
}




原文地址:https://www.cnblogs.com/cxhscst2/p/6648803.html