【题解】CSP2019 简要题解

D1T1 code

签到题,大家都会。

可以从高位往低位确定,如果遇到 \(1\),则将排名取反一下。

注意要开 unsigned long long

#include <bits/stdc++.h>

typedef unsigned long long u64; 

const int MaxN = 100; 

u64 n, K; 
bool ans[MaxN]; 

inline void solve(u64 dep, u64 k)
{
	if (dep == 0)
		return; 
	
	u64 lsze = 1ull << (dep - 1); 
	if (k < lsze)
	{
		ans[dep] = false; 
		solve(dep - 1, k); 
	}
	else
	{
		ans[dep] = true; 
		solve(dep - 1, lsze - (k - lsze) - 1); 
	}
}

int main()
{
    freopen("code.in", "r", stdin); 
    freopen("code.out", "w", stdout); 
    
	std::cin >> n >> K; 
	
	solve(n, K); 
	
	for (int i = n; i >= 1; --i)
		putchar(ans[i] ? '1' : '0'); 
	
	return 0; 
}

D1T2 brackets

简单题,大家都会。

大家的做法都好巨,我只会奇奇怪怪的做法。

考虑每次加一个括号之后答案的增量,显然只有加右括号的时候答案会增加。

我们记两个量, \(lst\_lef_i\) 表示 \(i\to 1\) 的路径上最后一个没有匹配的左括号,\(lst\_blk_i\) 表示以 \(i\) 结尾的合法子串个数(本质上是数出类似这种串 ...(...)(...)(...)(...) 的极长合法后缀可以分成几段 (...))。

这两个量可以直接线性推出来,然后就做完了。时间复杂度 \(O(n)\)

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = getchar())); 
	x = ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
}

typedef long long s64; 

const int MaxNV = 5e5 + 5; 
const int MaxNE = MaxNV; 

int n; 
int fa[MaxNV]; 
char s[MaxNV]; 

int lst_lef[MaxNV]; 
int lst_blk[MaxNV]; 

s64 ans[MaxNV], xor_ans; 

int main()
{
    freopen("brackets.in", "r", stdin); 
    freopen("brackets.out", "w", stdout); 
    
	scanf("%d%s", &n, s + 1); 
	for (int i = 2; i <= n; ++i)
		read(fa[i]); 
	
	if (s[1] == '(')
		lst_lef[1] = 1; 
	
	for (int u = 2; u <= n; ++u)
	{
		ans[u] = ans[fa[u]]; 
		
		if (s[u] == '(')
		{
			lst_lef[u] = u; 
			lst_blk[u] = 0; 
		}
		else
		{
			if (lst_lef[fa[u]])
			{
				int lef_u = lst_lef[fa[u]]; 
				
				lst_lef[u] = lst_lef[fa[lef_u]]; 
				lst_blk[u] = lst_blk[fa[lef_u]] + 1; 
				ans[u] += lst_blk[u]; 
			}
			else
			{
				lst_lef[u] = 0; 
				lst_blk[u] = 0; 
			}
        }
		xor_ans ^= 1LL * u * ans[u]; 
	}
	
	std::cout << xor_ans << std::endl; 
	
	return 0; 
}

D1T3 tree

细节题。这个题的 idea 挺好的,就是容易分类讨论挂。

难度其实不大,放在 D1T3 其实没啥毛病。可能出题人高估了我的代码能力。我太菜了,考场上调不出来。

一个显然的贪心就是从小到大枚举数字,然后判断这个数字最终能送到那个位置。显然每次我们都贪心地选取最小的位置。

考虑一条路径 \(u_1\to u_2\to \dots \to u_k\),假设我们要将 \(u_1\) 的原来数字送到 \(u_k\),那么需要满足下列条件:

  • \((u_1,u_2)\) 是和 \(u_1\) 相连的所有边中第一个被删除的
  • \((u_{k - 1}, u_k)\) 是和 \(u_k\) 相连的所有边中最后一个被删除的
  • \((u_{i-1},u_i)\) 必须比 \((u_i,u_{i+1})\) 先删除,并且在删除 \((u_{i-1},u_i)\) 后,删除 \((u_i,u_{i+1})\) 之前,不能有和 \(u_i\) 相连的其他边被删除。

那么我们就对每个点,维护出与其相连的所有边的限制。这些限制具体可以用一个链表来表示,并且需要记录每个点强制限制的第一个删除的边,和最后一个删除的边。

实现的时候,就从当前枚举的数字所在的点开始 dfs,显然满足第一个和第三个条件的点构成一个联通块。我们只需要在 dfs 的时候顺带判断这些条件能否满足即可。

时间复杂度 \(O(n^2)\)。细节比较多,我是根据考场的混乱思路瞎写的,相信读者一定有比我更优秀的实现方法。

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = getchar())); 
	x = ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
}

template <class T>
inline void putint(T x)
{
	static char buf[25], *tail = buf; 
	if (!x)
		putchar('0'); 
	else
	{
		for (; x; x /= 10) *++tail = x % 10 + '0'; 
		for (; tail != buf; --tail) putchar(*tail); 
	}
}

const int MaxN = 2e3 + 5; 

int n; 
int idx[MaxN], col[MaxN], fa[MaxN]; 
int adj[MaxN][MaxN], deg[MaxN]; 

int ans[MaxN]; 
int fir[MaxN], lst[MaxN]; 
int head[MaxN][MaxN], sze[MaxN][MaxN]; 
int pre[MaxN][MaxN], suf[MaxN][MaxN]; 

inline void init()
{
	read(n); 
	for (int i = 1; i <= n; ++i)
	{
		fir[i] = lst[i] = 0; 
		deg[i] = ans[i] = fa[i] = 0; 

		for (int j = 1; j <= n; ++j)
		{
			head[i][j] = j; 
			sze[i][j] = 1; 
			pre[i][j] = suf[i][j] = 0; 
		}
	}

	for (int i = 1; i <= n; ++i)
	{
		read(idx[i]); 
		col[idx[i]] = i; 
	}

	for (int i = 1; i < n; ++i)
	{
		int u, v; 
		read(u), read(v); 
		adj[u][++deg[u]] = v; 
		adj[v][++deg[v]] = u; 
	}
}

inline void dfs(int u, int src)
{
	if (u != src)
	{
		bool flg = true; 

		if (deg[u] != 1)
		{
			flg &= fir[u] != fa[u] && !suf[u][fa[u]]; 
			flg &= !lst[u] || lst[u] == fa[u]; 
			if (fir[u] && head[u][fir[u]] == head[u][fa[u]])
				flg &= sze[u][head[u][fa[u]]] == deg[u]; 
		}

		if (flg)
		{
			if (!ans[src] || u < ans[src])
				ans[src] = u; 
		}
	}
	for (int i = 1; i <= deg[u]; ++i)
	{
		int v = adj[u][i]; 
		if (v == fa[u])
			continue; 

		fa[v] = u; 

		bool flg = true; 
		if (u == src)
		{
			if (deg[u] != 1)
			{
				flg &= lst[u] != v && !pre[u][v]; 
				if (lst[u] && head[u][lst[u]] == head[u][v])
					flg &= sze[u][head[u][v]] == deg[u]; 
			}
		}
		else
		{
			flg &= !suf[u][fa[u]] || suf[u][fa[u]] == v; 
			flg &= !pre[u][v] || pre[u][v] == fa[u]; 
			flg &= suf[u][fa[u]] == v || head[u][v] != head[u][fa[u]]; 
			flg &= head[u][fir[u]] != head[u][v] && head[u][lst[u]] != head[u][fa[u]]; 
			if (head[u][lst[u]] == head[u][v] && head[u][fir[u]] == head[u][fa[u]])
				flg &= sze[u][head[u][lst[u]]] + sze[u][head[u][fir[u]]] == deg[u]; 
		}

		if (flg)
			dfs(v, src); 
	}
}

inline void modify(int x, int src)
{
	if (!x)
		return; 

	lst[x] = fa[x]; 

	int y = x; 
	while (fa[y] != src)
	{
		suf[fa[y]][fa[fa[y]]] = y; 
		pre[fa[y]][y] = fa[fa[y]]; 

		if (head[fa[y]][fa[fa[y]]] != head[fa[y]][y])
		{
			int l = head[fa[y]][fa[fa[y]]]; 
			int z = y; 
			while (z)
			{
				++sze[fa[y]][l]; 
				head[fa[y]][z] = l; 
				z = suf[fa[y]][z]; 
			}
		}
		y = fa[y]; 
	}

	fir[src] = y; 
}

inline void solve()
{
	for (int c = 1; c <= n; ++c)
	{
		int u = idx[c]; 
		if (n == 1)
		{
			puts("1"); 
			continue; 
		}

		fa[u] = 0; 
		dfs(u, u); 
		modify(ans[u], u); 

		putint(ans[u]); 
		putchar(" \n"[c == n]); 
	}
}

int main()
{
    freopen("tree.in", "r", stdin); 
    freopen("tree.out", "w", stdout); 
    
	int orzczk; 
	read(orzczk); 

	while (orzczk--)
	{
		init(); 
		solve(); 
	}

	return 0; 
}

D2T1 meal

简单题,就我不会。考场上降智太严重了,会了 \(O(mn^3)\) 竟然不会 \(O(mn^2)\)。我校其他选手全部 AC 此题,水平高下立判。

因为如果有食材超过一半,那么最多只能有一个这样的食材,所以不难想到用总的方案数减去有一个主要食材超过一半的方案数。总的方案数就是

\[\prod_{i=1}^n\left(1+\sum_{j=1}^ma_{i,j}\right)-1 \]

减一是因为不能一个都不选。

考虑如何限制某个食材超过一半,显然我们可以考虑枚举这个食材,然后把用这个食材的菜权值看成 \(+1\),不用这个食材的菜权值看成 \(-1\),那么相当于选的所有菜的总权值要大于 \(0\)

具体地,我们可以用一个背包 DP 实现。显然大家都会,就不讲了。

\(f(i,j)\) 表示前 \(i\) 种方法,选的菜的总权值为 \(j\) 的方案数。

假设现在限制的是第 \(p\) 种食材,那么转移非常显然:

\[\begin{aligned} f(i,j) &\leftarrow f(i-1,j-1)\times a_{i,p}\\ f(i,j) &\leftarrow f(i-1,j+1)\times \sum_{j \neq p}a_{i,j} \end{aligned} \]

为了避免负数下标,可以加一个常数。时间复杂度 \(O(mn^2)\)

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = getchar())); 
	x = ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
}

const int MaxN = 1e2 + 5; 
const int MaxM = 2e3 + 5; 
const int mod = 998244353; 

int n, m; 
int a[MaxN][MaxM], sum[MaxN]; 

int f[MaxN][MaxN << 1]; 

inline void add(int &x, const int &y)
{
	x += y; 
	if (x >= mod)
		x -= mod; 
}

inline void dec(int &x, const int &y)
{
	x -= y; 
	if (x < 0)
		x += mod; 
}

inline int minus(int x, const int &y)
{
	x -= y; 
	return x < 0 ? x + mod : x; 
}

int main()
{
	freopen("meal.in", "r", stdin); 
	freopen("meal.out", "w", stdout); 

	read(n), read(m); 
	for (int i = 1; i <= n; ++i)
	{
		for (int j = 1; j <= m; ++j)
		{
			read(a[i][j]); 
			add(sum[i], a[i][j]); 
		}
	}

	int ans = 1; 
    for (int i = 1; i <= n; ++i)
        ans = 1LL * ans * (sum[i] + 1) % mod; 
    
    dec(ans, 1); 
	for (int p = 1; p <= m; ++p)
	{
		f[0][n] = 1; 
		for (int i = 1; i <= n; ++i)
			for (int j = 0; j <= (n << 1); ++j)
			{
				f[i][j] = f[i - 1][j]; 
				if (j > 0)
					add(f[i][j], 1LL * f[i - 1][j - 1] * a[i][p] % mod); 
				if (j < (n << 1))
					add(f[i][j], 1LL * f[i - 1][j + 1] * minus(sum[i], a[i][p]) % mod); 
			}

		for (int i = 1; i <= n; ++i)
			dec(ans, f[n][i + n]); 
	}

	printf("%d\n", ans); 

	return 0; 
}

D2T2 partition

打表找规律题,考场上来不及了。

开始我们有一个显然的 DP 是,设 \(f(i,j)\) 表示前 \(i\) 个数,最后一段是 \([j+1,i]\),的最小平方和。这样的 DP 实现地优秀一点可以做到 \(O(n^2)\)

强烈的感觉告诉我们,这题有奇妙结论,考场上当然是打表。打表后不难发现在合法范围内,\(f(i,j)\)\(j\) 单调递减

结论的证明参考出题人myy的博客: http://matthew99.blog.uoj.ac/blog/5299

简单总结一下这个证明:

结论: 把所有解的断点从大到小写下来,然后剩下的位置补0,那么最优解对应的序列在所有位置都是最大值。(不难发现,这个定义使最优解唯一)

证明: 结论等价于,对于每个解,从后往前将每一段的和写出来,然后补无限个零,得到一个对应的序列,那么最优解对应的序列任意位置的前缀和都是最小的。

假设这个对应从后往前写出的每一段和的序列为 \(\{b_i\}\),考虑另一个解对应的序列 \(\{c_i\}\),显然不会出现某个位置 \(k\) 的前缀和 \(\sum_{i=1}^kb_i > \sum_{i=1}^kc_i\),否则就会和最优解的定义矛盾。

因此现在需要证明的就是对于任意一个满足

\[\forall k, \sum_{i=1}^kb_i\leq \sum_{i=1}^kc_i \]

的解对应序列 \(c\)\(c\)\(b\) 不同),都有

\[\sum_{i=1}^{+\infty}b_i^2 < \sum_{i=1}^{+\infty}c_i^2 \]

证明的思路是,将序列 \(c\) 经过一些使平方和减小的调整,并且使得任意时刻都满足所有位置的前缀和不小于 \(b\),最后让 \(c\) 变成 \(b\)。这样就能证明 \(c\) 的平方和不小于 \(b\) 的。(注意调整的含义是直接对 \(c\) 进行修改,在调整过程中没有必要保证 \(c\) 存在对应的原序列中的解,我们只关心 \(c\)\(b\) 平方和的大小关系)

注意到对于一个单调不增的序列 \(a\),若 \(i<j\)\(a_i>a_{i+1},a_{j-1}>a_j, a_i-a_j\geq 2\),将 \(a_i\) 减一,\(a_j\) 加一,可以使 \(a\) 仍然单调不增,并且平方和减小。

找到第一个满足 \(c_u>b_u\) 的位置 \(u\),在 \(u\) 的后面一定能找到第一个位置 \(v\) 满足 \(c_v<b_v\)(因为 \(c\)\(b\) 所有元素的总和一样)。因为 \(c\)\(b\) 都单调不增,所以 \(c_u-c_v\geq 2\),即区间 \([u,v]\) 的权值跨度至少为 \(2\)。找到最小的满足 \(i\geq u,c_{i}>c_{i+1}\) 的位置 \(i\),找到最大的满足 \(j\leq v,c_{j-1}>c_j\) 的位置 \(j\),那么 \(u\leq i<j\leq v\),并且 \(c_i-c_j\geq 2\)。于是将 \(c_i\) 减一,\(c_j\) 加一,可以使 \(c\) 仍然递增,并且平方和减小,不难发现,这么操作仍然保证了 \(c\) 的每个位置的前缀和不小于 \(b\) 的。

于是不断这么操作,一定能使 \(c\) 最终变成 \(b\),而在操作过程中平方和不断减小,于是原来的 \(c\) 的平方和是大于 \(b\) 的。\(\square\)

因此我们对于每个 \(i\) 只要记一个最优决策点即可,每个 \(i\) 的决策点一定是取在最靠右的合法位置。记这个位置为 \(p_i\)。考虑一个决策点 \(j(j<i)\) 是合法的当且仅当 \(s_i-s_j\geq s_j-s_{p_j}\),其中 \(s_i\) 表示 \(1\dots i\) 的前缀和。

移项一下这个条件就是 \(2s_j-s_{p_j} \leq s_i\),到这里可以用一个BIT 做到 \(O(n \log n)\)

但是实际上,不难发现 \(s_i\) 是单增的,也就是说一个决策点 \(j\) 对于当前的 \(i\) 是合法的,那么对于后面的肯定仍是合法的。于是我们考虑维护一个\(2s_j-s_{p_j}\) 单增的单调队列。对于每个 \(i\) 就从队头开始找到最后一个合法决策点,再将 \(i\) 插到队尾,并且将队尾的那些 \(2s_j-s_{p_j}\geq 2s_i-s_{p_i}\) 的决策点弹掉。这样的时间复杂度就是 \(O(n)\) 了。

本题还有个问题就是,你不能一边做这个 DP 一边算这个 DP 值,得把 \(p_i\) 存起来最后算(因为空间不够)。高精还需要压位/二进制。当然,你在 OJ 上用 __int128 也是可以的。

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	static bool opt; 
	while (!isdigit(ch = getchar()) && ch != '-'); 
	x = (opt = ch == '-') ? 0 : ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
	if (opt)
		x = ~x + 1; 
}

template <class T>
inline void putint(T x)
{
	static char buf[45], *tail = buf; 
	if (!x)
		putchar('0'); 
	else
	{
		if (x < 0)
		{
			putchar('-'); 
			x = ~x + 1; 
		}
		for (; x; x /= 10) *++tail = x % 10 + '0'; 
		for (; tail != buf; --tail) putchar(*tail); 
	}
}

typedef long long s64; 

const int MaxN = 4e7 + 5; 
const s64 mod = 1e9; 

int n, type, ql, qr; 
int que[MaxN], maxp[MaxN], b[MaxN]; 
s64 s[MaxN]; 

struct bignum
{
	int len; 
	s64 a[7]; 
	bignum(){}
	bignum(s64 t)
	{
		len = 1; 
		memset(a, 0, sizeof(a)); 

		a[1] = t % mod; 
		if (t >= mod)
			a[++len] = t / mod; 
	}

	inline void operator += (const bignum &rhs)
	{
		len = std::max(len, rhs.len); 
		for (int i = 1; i <= len; ++i)
		{
			a[i] += rhs.a[i]; 
			a[i + 1] += a[i] / mod; 
			a[i] %= mod; 
		}
		if (a[len + 1])
			++len; 
	}

	inline bignum operator * (const bignum &rhs) const
	{
		bignum res(0); 
		res.len = len + rhs.len; 
		for (int i = 1; i <= len; ++i)
			for (int j = 1; j <= rhs.len; ++j)
				res.a[i + j - 1] += a[i] * rhs.a[j]; 
		for (int i = 1; i < res.len; ++i)
		{
			res.a[i + 1] += res.a[i] / mod; 
			res.a[i] %= mod; 
		}
		while (res.len > 1 && !res.a[res.len])
			--res.len; 
		return res; 
	}

	inline void print()
	{
		printf("%d", (int)a[len]); 
		for (int i = len - 1; i >= 1; --i)
			printf("%09d", (int)a[i]); 
	}
}res(0); 

inline s64 calc(int x)
{
	return 2 * s[x] - s[maxp[x]]; 
}

int main()
{
	freopen("partition.in", "r", stdin); 
	freopen("partition.out", "w", stdout); 

	read(n), read(type); 
	if (type == 0)
	{
		for (int i = 1; i <= n; ++i)
		{
			int x; 
			read(x); 
			s[i] = s[i - 1] + x; 
		}
	}
	else
	{
		int x, y, z, m; 
		read(x), read(y), read(z), read(b[1]), read(b[2]), read(m); 
		for (int i = 1, lstp = 0; i <= m; ++i)
		{
			int p, l, r; 
			read(p), read(l), read(r); 
			for (int j = lstp + 1; j <= p; ++j)
			{
				if (j > 2)
					b[j] = (1LL * x * b[j - 1] + 1LL * y * b[j - 2] + z) % (1 << 30); 
				s[j] = s[j - 1] + b[j] % (r - l + 1) + l; 
			}
			lstp = p; 
		}
	}
	
	que[ql = qr = 1] = 0; 
	for (int i = 1; i <= n; ++i)
	{
		while (ql < qr && calc(que[ql + 1]) <= s[i])
			++ql; 
		maxp[i] = que[ql]; 
		while (ql <= qr && calc(que[qr]) >= calc(i))
			--qr; 
		que[++qr] = i; 
	}

	int cur = n; 
	while (cur)
	{ 
		res += bignum(s[cur] - s[maxp[cur]]) * bignum(s[cur] - s[maxp[cur]]); 
		cur = maxp[cur]; 
	}

	res.print(); 
	
	return 0; 
}

D2T3 centroid

不难题,就我不会。因为这是 D2T3 所以我只想着写暴力,实际上这题挺简单的。

将题目中的计算方式转化为:对于每个点,计算它成为重心的方案数。

那么考虑一个点 \(u\),将其硬点为,对于它的每个儿子 \(v\),相当于在 \(v\) 的子树再选出一个大小为 \(s_0\) 的子树,这个 \(s_0\) 需要在一个范围。具体地,我们记 \(s'\) 表示 \(u\) 除了 \(v\) 以外的其他儿子的最大子树大小,记 \(s_v\) 表示 \(v\) 的子树大小。

那么就有

\[s_v-s_0 \leq \left\lfloor\frac {n-s_0}2\right\rfloor \\ s_0 \geq 2s_v-n \]

以及

\[s^\prime\leq \left\lfloor \frac {n-s_0}2\right\rfloor\\ s_0\leq n-2s^\prime \]

那么就要求 \(s_0\in [2s_v-n,n-2s^\prime]\)

对应到原树(钦定 \(1\) 为根的有根树)上,这相当于分两类讨论:

  • \(v\)\(u\) 的一个儿子,这个时候直接在 dfs 时用一个 BIT 存下当前结点到根的路径上所有点的询问区间对应的贡献,然后遍历到一个点的时候在 BIT 上单点查询即可。
  • \(v\)\(u\) 的父亲,这时候的统计比较复杂,需要分成几块统计:
    • 原树中不在 \(u\) 子树中且不在 \(u\) 到根路径上的结点的 \(size\in [2(n-s_u)-n,n-2s^\prime]\),这个可以用整棵树\(size\) 在这个区间的点数减去子树内 \(size\) 在这个区间的点数,再减去到根结点的路径上的点 \(size\) 在这个区间的点数。子树询问可以归到第一类的 BIT 中,到根结点的可以另外再维护一个 BIT。
    • 原树中 \(u\) 到根的路径上,向父亲方向的「子树」。这个可以 dfs 的时候归到上面第二个 BIT 维护。

那么这样我们就可以用 BIT 来做了。时间复杂度 \(O(n\log n)\),常数有点大,可能打不过线段树选手。

(补充:我太菜了,其实那些询问看成二维数点然后离线 BIT 就可以了,不用我这么麻烦,但是这样常数好像也很大

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = getchar())); 
	x = ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
}

template <class T>
inline void putint(T x)
{
	static char buf[25], *tail = buf; 
	if (!x)
		putchar('0'); 
	else
	{
		for (; x; x /= 10) *++tail = x % 10 + '0'; 
		for (; tail != buf; --tail) putchar(*tail); 
	}
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y)
		x = y; 
}

template <class T>
inline void tense(T &x, const T &y)
{
	if (x > y)
		x = y; 
}

typedef long long s64; 

const int MaxNV = 3e5 + 5; 
const int MaxNE = MaxNV << 1; 

int n; 
int ect, adj[MaxNV]; 
int nxt[MaxNE], to[MaxNE]; 

int fa[MaxNV], sze[MaxNV]; 
s64 ans, sum[MaxNV], bit_q[MaxNV], bit_u[MaxNV]; 

int max_sze[MaxNV][2]; 

#define trav(u) for (int e = adj[u], v; v = to[e], e; e = nxt[e])

inline void addEdge(int u, int v)
{
	nxt[++ect] = adj[u]; 
	adj[u] = ect; 
	to[ect] = v; 
}

inline void init()
{
	read(n); 

	ect = 0; 
	ans = 0; 
	for (int i = 1; i <= n; ++i)
	{
		adj[i] = sum[i] = 0; 
		bit_q[i] = bit_u[i] = 0; 
		max_sze[i][0] = max_sze[i][1] = 0; 
	}

	for (int i = 1; i < n; ++i)
	{
		int u, v; 
		read(u), read(v); 
		addEdge(u, v); 
		addEdge(v, u); 
	}
}

inline void bit_modify(int x, int val, s64 *bit)
{
	for (; x <= n; x += x & -x)
		bit[x] += val; 
}

inline s64 bit_query(int x, s64 *bit)
{
	s64 res = 0; 
	for (; x; x ^= x & -x)
		res += bit[x]; 
	return res; 
}

inline void seg_modify(int l, int r, int val, s64 *bit)
{
	relax(l, 1), tense(r, n); 
	if (r < l) return; 

	bit_modify(l, val, bit); 
	bit_modify(r + 1, -val, bit); 
}

inline s64 seg_query(int l, int r, s64 *bit)
{
	relax(l, 1), tense(r, n); 
	if (r < l) return 0; 
	return bit_query(r, bit) - bit_query(l - 1, bit); 
}

inline void upt(int x, int s)
{
	if (s >= max_sze[x][0])
	{
		max_sze[x][1] = max_sze[x][0]; 
		max_sze[x][0] = s; 
	}
	else
		relax(max_sze[x][1], s); 
}

inline int max_else(int x, int s)
{
	if (s == max_sze[x][0])
		return max_sze[x][1]; 
	return max_sze[x][0]; 
}

inline void dfs_init(int u)
{
	sze[u] = 1; 
	trav(u)
		if (v != fa[u])
		{
			fa[v] = u; 
			dfs_init(v); 

			sze[u] += sze[v]; 
			upt(u, sze[v]); 
		}
	upt(u, n - sze[u]); 
}

inline void dfs(int u)
{
	int ul = std::max(1, 2 * (n - sze[u]) - n); 
	int ur = std::min(n, n - 2 * max_else(u, n - sze[u])); 
	if (ul <= ur)
		ans += (sum[ur] - sum[ul - 1]) * u; 
	seg_modify(ul, ur, -u, bit_q); 
	ans += bit_query(sze[u], bit_q); 
	ans += seg_query(ul, ur, bit_u) * u; 

	trav(u)
		if (v != fa[u])
		{
			int t_s = max_else(u, sze[v]); 
			int l = 2 * sze[v] - n, r = n - 2 * t_s; 
			
			seg_modify(l, r, u, bit_q); 
			bit_modify(sze[u], -1, bit_u); 
			bit_modify(n - sze[v], 1, bit_u); 

			dfs(v); 

			seg_modify(l, r, -u, bit_q);
			bit_modify(sze[u], 1, bit_u); 
			bit_modify(n - sze[v], -1, bit_u);  
		}

	seg_modify(ul, ur, u, bit_q); 
}

inline void solve()
{
	dfs_init(1); 

	for (int i = 1; i <= n; ++i)
		++sum[sze[i]]; 
	for (int i = 1; i <= n; ++i)
		sum[i] += sum[i - 1]; 

	dfs(1); 

	putint(ans); 
	putchar('\n'); 
}

int main()
{
	freopen("centroid.in", "r", stdin); 
	freopen("centroid.out", "w", stdout); 

	int orzczk; 
	read(orzczk); 
	while (orzczk--)
	{
		init(); 
		solve(); 

	}
	return 0; 
}
原文地址:https://www.cnblogs.com/cyx0406/p/11908087.html