ABC221F Diameter set 题解

题面

题意简述:

给定一棵 (n) 个节点的树,设它的直径是 (D),问有多少个集合满足集合中每两个点的距离都为 (D)

( exttt{Data Range:} 1le nle 2 imes 10^5)

考虑直径的性质:

  • 树的每一条直径一定都经过一个公共点 / 一条公共边。经过的是点还是边取决于直径的长度是奇数还是偶数。

那么按照直径长度的奇偶性分类讨论,直接计算即可。

具体的:

  • 若直径长度为偶数,设中点为 (mid),答案为 (prodlimits_{vin son_{mid}}(cnt_v+1)-scnt-1),其中 (cnt_v)(v) 子树中距离 (v) 长度为 (frac{D}{2}-1) 的个数,(scnt) 为距离 (mid) 长度为 (frac{D}{2}) 的点的个数。
  • 若直径长度为奇数,设中间的边为 ((mid,fmid)),那么答案就是 (cnt_{mid} imes cnt_{fmid})(cnt_{mid})(mid) 子树中距离 (mid) 长度为 (frac{D}{2}) 的点的个数。

代码:

#include <bits/stdc++.h>
#define DC int T = gi <int> (); while (T--)
#define DEBUG fprintf(stderr, "Passing [%s] line %d
", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
#define fi first
#define se second
#define pb push_back
#define mp make_pair

using namespace std;

typedef long long LL;
typedef unsigned long long ULL;
typedef pair <int, int> PII;
typedef pair <LL, LL> PLL;

template <typename T>
inline T gi()
{
	T x = 0, f = 1; char c = getchar();
	while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
	while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
	return f * x;
}

const int N = 200003, M = N << 1, mod = 998244353;

int n;
int tot, head[N], ver[M], nxt[M];
int fa[N];
int mx, lft, rght;
int cnt;

inline void add(int u, int v) {ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;}

inline int qpow(int x, int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1) res = 1ll * res * x % mod;
		x = 1ll * x * x % mod, y >>= 1;
	}
	return res;
}

void dfs(int u, int f, int dis)
{
	if (dis > mx) mx = dis, rght = u;
	fa[u] = f;
	for (int i = head[u]; i; i = nxt[i])
	{
		int v = ver[i];
		if (v == f) continue;
		dfs(v, u, dis + 1);
	}
}

void dfsson(int u, int f, int tar, int dis)
{
	if (tar == dis) ++cnt;
	for (int i = head[u]; i; i = nxt[i])
	{
		int v = ver[i];
		if (v == f) continue;
		dfsson(v, u, tar, dis + 1);
	}
}

int main()
{
	//freopen(".in", "r", stdin); freopen(".out", "w", stdout);
	n = gi <int> ();
	for (int i = 1; i < n; i+=1)
	{
		int u = gi <int> (), v = gi <int> ();
		add(u, v), add(v, u);
	}
	dfs(1, 0, 0);
	lft = rght, mx = 0;
	dfs(lft, 0, 0);
	int d = mx;
	if (d % 2 == 0)
	{
		int mid = rght;
		for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
		int ans = 1, scnt = 0;
		for (int i = head[mid]; i; i = nxt[i])
		{
			int v = ver[i];
			cnt = 0;
			dfsson(v, mid, d / 2 - 1, 0);
			ans = 1ll * ans * (cnt + 1) % mod;
			scnt += cnt;
		}
		printf("%d
", (ans - 1 - scnt + mod) % mod);
	}
	else
	{
		int mid = rght;
		for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
		int fmid = fa[mid];
		dfsson(mid, fmid, d / 2, 0);
		int tcnt = cnt; cnt = 0;
		dfsson(fmid, mid, d / 2, 0);
		printf("%lld
", 1ll * cnt * tcnt % mod);
	}
	return !!0;
}
原文地址:https://www.cnblogs.com/xsl19/p/abc221f.html