Codeforces 1111E DP + 树状数组 + LCA + dfs序

题意:给你一棵 n 个节点的树,有 q 次询问。每次询问给出 k 个点以及 m 和 r:以 r 为根,把这 k 个点分成至多 m 组,每组要满足两个条件:1. 每组至少有一个点;2. 组内任一点都不能是同组其它点的祖先。问这样的分组方案有多少种?(答案对 10^9+7 取模)

思路:https://blog.csdn.net/BUAA_Alchemist/article/details/86765501

代码:

#include <bits/stdc++.h>
// 64-bit integer type used for counts taken modulo `mod`.
// A type alias instead of a macro: it respects scope and appears in diagnostics.
using LL = long long;
// Lowest set bit of x (the Fenwick-tree step size).
// A function instead of a macro so an argument such as `lowbit(a + b)` is evaluated
// as a whole, rather than being split apart by operator precedence.
constexpr int lowbit(int x) { return x & -x; }
using namespace std;
const LL mod = 1000000007;  // answers are reported modulo 1e9+7
const int maxn = 100010;    // capacity of node-indexed arrays (presumably n <= 1e5 plus slack)
// Adjacency lists of the undirected tree; add() inserts both directions of each edge.
vector<int> G[maxn];
// Nodes of the current query in input order; cleared after each query is answered.
vector<int> a;
// dfn[v]: pre-order index assigned by dfs(); sz[v]: subtree size of v (tree rooted at node 1);
// tot: running pre-order counter; t: highest binary-lifting level used by bfs()/lca().
int dfn[maxn], sz[maxn], tot, t;
// dp[i][j]: number of ways to place the first i (sorted) query nodes into exactly j
// non-empty groups, taken modulo mod. The second dimension must exceed the largest m
// (presumably m <= 300 by the problem constraints - confirm).
LL dp[maxn][310];
// Number of tree nodes; also the size of the Fenwick tree.
int n;
// Store the undirected edge x -- y: y is appended to x's list, then x to y's list.
void add(int x, int y) {
	auto link = [](int from, int to) { G[from].push_back(to); };
	link(x, y);
	link(y, x);
}
// Depth-first walk from x (fa = parent, -1 for the root). It records:
//   dfn[x] - x's pre-order timestamp;
//   sz[x]  - the size of x's subtree.
// A subtree therefore occupies the contiguous dfn range [dfn[x], dfn[x] + sz[x]).
void dfs(int x, int fa) {
	sz[x] = 1;
	dfn[x] = ++tot;
	for (int child : G[x]) {
		if (child != fa) {
			dfs(child, x);
			sz[x] += sz[child];
		}
	}
}
queue<int> q;
// dep[v]: depth of v, with node 1 at depth 1 (0 means "not reached yet").
// f[v][j]: the 2^j-th ancestor of v, or 0 when no such ancestor exists.
int dep[maxn], f[maxn][20];
// Breadth-first traversal from node 1. It fills dep[] and the binary-lifting table
// f[][] for levels 0..t. Parents are dequeued before their children, so each child's
// higher levels can be derived from rows that are already complete.
void bfs() {
	dep[1] = 1;
	q.push(1);
	while (!q.empty()) {
		const int cur = q.front();
		q.pop();
		for (int nxt : G[cur]) {
			if (dep[nxt] != 0)
				continue;  // already visited (this also skips the parent)
			dep[nxt] = dep[cur] + 1;
			f[nxt][0] = cur;
			for (int j = 1; j <= t; ++j) {
				const int half = f[nxt][j - 1];
				f[nxt][j] = f[half][j - 1];
			}
			q.push(nxt);
		}
	}
}

// Lowest common ancestor of x and y in the tree rooted at node 1, using the
// dep[] / f[][] tables built by bfs().
int lca(int x, int y) {
	int lo = x, hi = y;
	if (dep[lo] > dep[hi])
		swap(lo, hi);  // make hi the deeper (or equally deep) node
	// Lift hi until it is level with lo.
	for (int i = t; i >= 0; --i) {
		const int up = f[hi][i];
		if (dep[up] >= dep[lo])
			hi = up;
	}
	if (lo == hi)
		return hi;
	// Climb both nodes together while their ancestors still differ.
	for (int i = t; i >= 0; --i) {
		if (f[lo][i] == f[hi][i])
			continue;
		lo = f[lo][i];
		hi = f[hi][i];
	}
	return f[lo][0];
}
// Fenwick (binary indexed) tree over positions 1..n: point add, prefix-sum query.
struct BIT {
	int c[maxn];

	// Sum of all values added at positions 1..x.
	int ask(int x) {
		int total = 0;
		while (x != 0) {
			total += c[x];
			x -= lowbit(x);
		}
		return total;
	}

	// Add y at position x; positions beyond n are ignored.
	void add(int x, int y) {
		while (x <= n) {
			c[x] += y;
			x += lowbit(x);
		}
	}
};
BIT tr;
int h[maxn], vis[maxn];
int main() {
	int u, v, T;
	scanf("%d%d", &n, &T);
	t = (int)(log(n) / log(2)) + 1;
	for (int i = 1; i < n; i++) {
		scanf("%d%d", &u, &v);
		add(u, v);
	}
	dfs(1, -1);
	bfs();
	int k, m, r, x;
	LL ans = 0;
	while(T--) {
		scanf("%d%d%d",&k, &m, &r);
		ans = 0;
		for (int i = 1; i <= k; i++) {
			scanf("%d", &x);
			vis[x] = 1;
			a.push_back(x);
			tr.add(dfn[x], 1);
			tr.add(dfn[x] + sz[x], -1);
		}
		for (int i = 0; i < k; i++) {
			int LCA = lca(a[i], r);
			h[i + 1] = tr.ask(dfn[a[i]]) + tr.ask(dfn[r]) - 2 * tr.ask(dfn[LCA]) + vis[LCA] - 1;
		}
		sort(h + 1, h + 1 + k);
		dp[0][0] = 1;
		for (int i = 1; i <= k; i++)
			for (int j = 0; j <= min(i, m); j++) {
				if(j > 0)
					dp[i][j] = (LL)((LL)dp[i - 1][j - 1] + ((LL)dp[i - 1][j] * max(0, j - h[i])) % mod) % mod;
			}
		for (int i = 1; i <= m; i++)
			ans = (ans + dp[k][i]) % mod;
		printf("%lld
", ans);
		for (int i = 0; i < k; i++) {
			tr.add(dfn[a[i]], -1);
			tr.add(dfn[a[i]] + sz[a[i]], 1);
			vis[a[i]] = 0;
		}
		a.clear();
	}
} 

  

原文地址:https://www.cnblogs.com/pkgunboat/p/10918148.html