[题解] [HNOI2014] 世界树

题面

[HNOI2014]世界树

题解

从数据范围很容易看出是个虚树DP(可惜看出来了也还是不会做)

虚树大家应该都会, 不会的话自己去搜吧, 我懒得讲了, 我们在这里只需要考虑如何DP即可

首先我们需要求出每个点被哪个点所控制, 设(u)点被(bl[u])所控制, 两遍DFS即可, 考虑儿子对父亲的影响和父亲对儿子的影响

代码细节相信不要我说, 能做这道题的总不可能不会DFS吧

还是贴一下自己看吧

void dfs1(int u, int fa)
{
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue; 
		dfs1(v, u); int dv = dep[bl[v]] - dep[u], du = bl[u] ? dep[bl[u]] - dep[u] : 0x3f3f3f3f; 
		if(dv < du || (dv == du && bl[v] < bl[u])) bl[u] = bl[v]; 
	}
}
void dfs2(int u, int fa)
{
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue; 
		int dv = dis(bl[v], v), du = dis(bl[u], v); 
		if(du < dv || (du == dv && bl[u] < bl[v])) bl[v] = bl[u]; 
		dfs2(v, u); 
	}
}
//分别在递归前和递归后处理一下就完事了

然后考虑如何计算答案, 对虚树上每一条边讨论

①: 边的两端被同一个节点所控制, 加上这两个点不在虚树中的儿子的sz即可

②: 边的两端被不同点控制, 我们需要找出一个分界点, 满足此分界点归下面那个点的(bl[])所控制, 此分界点的父亲被上面那个点的(bl[])所控制, 用个数据结构维护一下或者倍增跳一下就可以了

至于每个点不在虚树中的儿子的sz, 拿当前点的sz减去所有他在虚树中的儿子的sz即可

具体实现细节参见代码(参考了一下题解的思路嘿嘿嘿)

代码实现

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#define N 300005
using namespace std;

int n, m, Q, head[N], sz[N], son[N], dep[N], dfn[N], f[N][21], top[N], cnt, tp, a[N], b[N], stk[N], bl[N], con[N], ans[N], l[N]; 
struct edge { int from, to, next; } e[N << 1]; 

inline int read()
{
    int x = 0, w = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
    return x * w;
}

inline void add(int u, int v) { e[++cnt] = (edge) { u, v, head[u] }; head[u] = cnt; }

void dfs_sz(int u, int fa)
{
	sz[u] = 1; dep[u] = dep[fa] + 1; f[u][0] = fa; 
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue; 
		dfs_sz(v, u); sz[u] += sz[v]; if(sz[son[u]] < sz[v]) son[u] = v; 
	}
}

void dfs_top(int x, int y)
{
	dfn[x] = ++cnt; top[x] = y; 
	if(!son[x]) return; dfs_top(son[x], y); 
	for(int i = head[x]; i; i = e[i].next) if(e[i].to != f[x][0] && e[i].to != son[x]) dfs_top(e[i].to, e[i].to); 
}

int LCA(int x, int y)
{
	while(top[x] != top[y])
	{
		if(dep[top[x]] < dep[top[y]]) swap(x, y); 
		x = f[top[x]][0]; 
	}
	return dep[x] < dep[y] ? x : y; 
}

bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

void dfs1(int u, int fa)
{
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue; 
		dfs1(v, u); int dv = dep[bl[v]] - dep[u], du = bl[u] ? dep[bl[u]] - dep[u] : 0x3f3f3f3f; 
		if(dv < du || (dv == du && bl[v] < bl[u])) bl[u] = bl[v]; 
	}
}

int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }

void dfs2(int u, int fa)
{
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue; 
		int dv = dis(bl[v], v), du = dis(bl[u], v); 
		if(du < dv || (du == dv && bl[u] < bl[v])) bl[v] = bl[u]; 
		dfs2(v, u); 
	}
}

void dp(int u)
{
	for(int s, mid, nt, dv, du, i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; dp(v); s = mid = v;  
		for(int j = l[dep[v]]; j >= 0; j--) if(dep[f[s][j]] > dep[u]) s = f[s][j]; 
		con[u] -= sz[s];
		if(bl[u] == bl[v]) { ans[bl[u]] += sz[s] - sz[v]; continue; }
		for(int j = l[dep[v]]; j >= 0; j--)
		{
			nt = f[mid][j]; if(dep[nt] <= dep[u]) continue; 
			dv = dis(bl[v], nt), du = dis(bl[u], nt); 
			if(dv < du || (dv == du && bl[v] < bl[u])) mid = nt; 
		}
		ans[bl[u]] += sz[s] - sz[mid]; 
		ans[bl[v]] += sz[mid] - sz[v]; 
	}
	ans[bl[u]] += con[u]; 
}

void query()
{
	m = read(); cnt = tp = 0; 
	for(int i = 1; i <= m; i++) bl[a[++cnt] = b[i] = read()] = b[i]; 
	sort(a + 1, a + cnt + 1, cmp); 
	for(int i = 1; i < m; i++) a[++cnt] = LCA(a[i], a[i + 1]); 
	a[++cnt] = 1; sort(a + 1, a + cnt + 1, cmp); 
	int len = unique(a + 1, a + cnt + 1) - a - 1; 
	cnt = 0; for(int i = 1; i <= len; i++) head[a[i]] = 0, con[a[i]] = sz[a[i]]; 
	for(int i = 1; i <= len; i++)
	{
		while(tp && dfn[a[i]] >= dfn[stk[tp]] + sz[stk[tp]]) tp--; 
		if(tp) add(stk[tp], a[i]); stk[++tp] = a[i]; 
	}
	dfs1(1, 0); dfs2(1, 0); dp(1); 
	for(int i = 1; i <= m; i++) printf("%d%c", ans[b[i]], i == m ? '
' : ' '); 
	for(int i = 1; i <= len; i++) bl[a[i]] = ans[a[i]] = con[a[i]] = 0; 
}

int main()
{
	n = read(); for(int i = 2; i <= n; i++) l[i] = l[i >> 1] + 1; 
	for(int i = 1; i < n; i++) { int u = read(), v = read(); add(u, v); add(v, u); }
	cnt = 0; dfs_sz(1, 0); dfs_top(1, 0); 
	for(int i = 1; i <= n; i++)
		for(int  j = 1; j <= 20; j++) 
			f[i][j] = f[f[i][j - 1]][j - 1]; 
	Q = read(); while(Q--) query(); 
	return 0;
}
原文地址:https://www.cnblogs.com/ztlztl/p/10991790.html