[题解][Codeforces]Codeforces Round #635 (Div. 1) 简要题解

  • Chinese Round 果然对中国选手十分友好(

  • 原题解

A

题意

  • 给定一棵 (n) 个节点的有根树和一个 (k),满足 (1le kle n)

  • 选出 (k) 个点为黑点,其他点为白点

  • 求所有黑点到根的路径上白点个数之和的最大值

  • (1le nle 2 imes 10^5)

做法:贪心

  • 显然一个点为黑点则其子树全为黑点

  • 故问题可以视为 (k) 次,每次删掉一个叶子 (u),贡献为原树(dep_u-size_u)

  • 由于父亲的 (dep-size) 一定小于子节点,故取 (dep-size) 从大到小排序之后前 (k) 大的即可

  • (O(nlog n))

  • 利用 nth_element 可以做到 O(n)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 2e5 + 5, M = N << 1;

int n, k, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N], d[N], sze[N], a[N];
ll ans;

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}

void dfs(int u, int fu)
{
	fa[u] = fu; dep[u] = dep[fu] + 1; sze[u] = 1;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu) dfs(v, u), d[u]++, sze[u] += sze[v];
}

int main()
{
	int x, y;
	read(n); read(k);
	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
	dfs(1, 0);
	for (int i = 1; i <= n; i++) a[i] = dep[i] - sze[i];
	std::sort(a + 1, a + n + 1);
	for (int i = n - k + 1; i <= n; i++) ans += a[i];
	return std::cout << ans << std::endl, 0;
}

B

题意

  • 给定三个长度分别为 (n_r,n_g,n_b) 的数组 (r,g,b)

  • 从三个数组中各选一个数,设为 (x,y,z),求 ((x-y)^2+(y-z)^2+(z-x)^2) 的最小值

  • (1le n_r,n_g,n_ble 10^5)(1le r_i,g_i,b_ile 10^9)

做法:枚举+双指针

  • 假设 (xle yle z),则最优情况下 (x) 要尽可能大,(y) 要尽可能小

  • 故把三个数组排序,枚举 (x,y,z) 大小关系的 (6) 种排列之后,枚举 (y) 的值,用指针维护最大的 (x) 和最小的 (z)

  • (O(n_rlog n_r+n_glog n_g+n_blog n_b))

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 1e5 + 5;
const ll INF = 5e18;

int nr, ng, nb, r[N], g[N], b[N];

ll sqr(int x) {return 1ll * x * x;}

ll solve(int na, int nb, int nc, int *a, int *b, int *c)
{
	ll ans = INF;
	for (int i = 1, j = 1, k = 1; j <= nb; j++)
	{
		while (i <= na && a[i] <= b[j]) i++;
		while (k <= nc && b[j] > c[k]) k++;
		if (i > 1 && k <= nc) ans = std::min(ans,
			sqr(a[i - 1] - b[j]) + sqr(b[j] - c[k]) + sqr(c[k] - a[i - 1]));
	}
	return ans;
}

void work()
{
	read(nr); read(ng); read(nb);
	for (int i = 1; i <= nr; i++) read(r[i]);
	for (int i = 1; i <= ng; i++) read(g[i]);
	for (int i = 1; i <= nb; i++) read(b[i]);
	std::sort(r + 1, r + nr + 1); std::sort(g + 1, g + ng + 1);
	std::sort(b + 1, b + nb + 1);
	ll ans = solve(nr, ng, nb, r, g, b);
	ans = std::min(ans, solve(nr, nb, ng, r, b, g));
	ans = std::min(ans, solve(nb, nr, ng, b, r, g));
	ans = std::min(ans, solve(nb, ng, nr, b, g, r));
	ans = std::min(ans, solve(ng, nr, nb, g, r, b));
	ans = std::min(ans, solve(ng, nb, nr, g, b, r));
	printf("%lld
", ans);
}

int main()
{
	int T; read(T);
	while (T--) work();
	return 0;
}

C

题意

  • 给定长度为 (n) 的串 (S) 和长度为 (m) 的串 (T)

  • 一开始有一个空串 (A)

  • 每次操作可以选择把 (S) 的第一个字符加入 (A) 的开头或末尾,并把 (S) 的第一个字符删掉

  • 你可以执行任意不超过 (n) 的操作次数,求最后能使得 (T)(A) 的前缀的方案数,对 (998244353) 取模

  • (1le mle nle 3000)

做法:区间 DP

  • (f[l,r]) 表示插入了 (S) 的前 (r-l+1) 个字符,它们组成了最终的 (A) 串的区间 ([l,r]) 的方案数

  • 组成最终的 (A) 串的区间 ([l,r]),也就是说若 (iin[l,r])(ile m),则 (A_i=T_i)

  • 转移即枚举最后一个字符加在左边还是右边,判断其是否符合限制条件即可

  • 答案为 (sum_{i=m}^nf[1,i])

  • (O(n^2))

代码

#include <bits/stdc++.h>

const int N = 3005, djq = 998244353;

int n, m, f[N][N], ans;
char s[N], t[N];

int main()
{
	scanf("%s%s", s + 1, t + 1);
	n = strlen(s + 1); m = strlen(t + 1);
	for (int i = 1; i <= n + 1; i++) f[i][i - 1] = 1;
	for (int l = n; l >= 1; l--)
		for (int r = l; r <= n; r++)
		{
			if (l > m || s[r - l + 1] == t[l]) f[l][r] += f[l + 1][r];
			if (r > m || s[r - l + 1] == t[r]) f[l][r] += f[l][r - 1];
			if (f[l][r] >= djq) f[l][r] -= djq;
			if (l == 1 && r >= m)
				ans = (ans + f[l][r]) % djq;
		}
	return std::cout << ans << std::endl, 0;
}

D

题意

  • 交互题

  • 你有一堆麻将,点数从 (1)(n),每种点数的麻将个数在 ([0,n]) 之间,但你不知道它们具体是多少

  • 初始时可以知道这堆麻将中,碰(大小为 (3) 且点数相同的子集)的个数和吃(大小为 (3) 且点数形成公差为 (1) 的等差数列)的个数

  • 然后你可以加入最多 (n) 次某一种点数的麻将,加入一个麻将之后你可以得到此时碰和吃的个数

  • 还原初始时每种点数的麻将个数

  • (4le nle 100)

做法:数学

  • 当前(i) 种麻将有 (c_i) 个,则加入一个第 (i) 种麻将时会多出 (inom{c_i}2) 个碰和 (c_{i-2}c_{i-1}+c_{i-1}c_{i+1}+c_{i+1}c_{i+2}) 个吃

  • 如果只考虑吃的个数,则如果保证 (c_i>0) 则可以通过碰的个数的增量还原出 (c_i)

  • 考虑求点数为 (1) 的个数,可以得到如果事先加入一个 (1),就能保证 (c_i>0),再加入一个 (1) 即可查出 (ans_1)

  • 而加入 (1) 的好处是吃的个数增量为 (c_2c_3)

  • 于是考虑依次加入 (3,1,2,1),这样第二次吃的个数增量为 (ans_2(ans_3+1)),第四次吃的个数增量为 ((ans_2+1)(ans_3+1))

  • 这两个式子作差即可得到 (ans_3)。由于 (ans_3+1>0),故可以使用除法得到 (ans_2)

  • 而实际上我们也可以得到 (ans_4):考虑第三次吃的个数增量:((ans_3+1)(ans_1+1+ans_4)),也可以利用除法得到

  • 而对于 (i>4),也可以加入一个 (i-2),这时吃的个数增量表达式中只有 (ans_i) 是未知量,可以解出来。不过这样有一个问题:(ans_{i-1}) 可能为 (0),这样的方程会有无穷多个解

  • 故考虑倒着加:(n-1,n-2,dots,3,1,2,1)

  • 易得 (3,1,2,1) 移到最后不影响 (ans_{1dots 4}) 的求解,只是 (n>4) 时这样求解出来的 (ans_4) 需要减 (1)(在 (n-1,n-2,dots 4) 中加上了 (1)

  • 然后 (i)(3)(n-2),利用 (i) 被加入时吃的个数增量来解出 (ans_{i+2}),由于 (i+1) 在之前的过程中加过了 (1),故可以保证 (c_{i+1}) 不为 (0),这个方程一定可以解出来

  • (O(n)),操作次数为 (n)

代码

#include <bits/stdc++.h>

const int N = 110, M = N * N;

int n, ans[N], f[M], a[N], b[N];

void add(int v) {printf("+ %d
", v); fflush(stdout);}

int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n + 1; i++) f[i * (i - 1) >> 1] = i;
	scanf("%*d%*d");
	for (int i = 1; i <= n - 4; i++) add(n - i), scanf("%d%d", &a[i], &b[i]);
	add(3); scanf("%d%d", &a[n - 3], &b[n - 3]);
	add(1); scanf("%d%d", &a[n - 2], &b[n - 2]);
	add(2); scanf("%d%d", &a[n - 1], &b[n - 1]);
	add(1); scanf("%d%d", &a[n], &b[n]);
	ans[1] = f[a[n] - a[n - 1]] - 1;
	ans[3] = (b[n] - b[n - 1]) - (b[n - 2] - b[n - 3]) - 1;
	ans[2] = (b[n] - b[n - 1]) / (ans[3] + 1) - 1;
	ans[4] = (b[n - 1] - b[n - 2]) / (ans[3] + 1) - (ans[1] + 1) - (n > 4);
	for (int i = n - 3; i >= 2; i--)
	{
		int x = n - i;
		ans[x + 2] = (b[i] - b[i - 1] - ans[x - 2] * ans[x - 1] - ans[x - 1]
			* (ans[x + 1] + 1)) / (ans[x + 1] + 1) - (i > 2);
	}
	printf("! ");
	for (int i = 1; i <= n; i++) printf("%d ", ans[i]);
	return puts(""), 0;
}

E1

题意

  • 给定 (n)([0,2^m)) 内的数

  • 对于所有的 (0le ile m),求这些数有多少个子集的异或和,二进制下 (1) 的个数为 (i)

  • (1le nle 2 imes10^5)(0le mle 35)

做法:线性基+枚举((k) 较小)/DP((k) 较大)

  • 由于 E2 比 E1 难太太太多,就分开讲了

  • 显然先求线性基,设这个基由 (k) 个元素组成

  • 原一个子集的异或和可以表示成线性基内一个子集的异或和,再选上线性基外的一部分 (0),也就是线性基内一个子集的贡献为 (2^{n-k})

  • (k) 较小的时候,可以暴力枚举每个基变量是否选上:(O(2^k))

  • (k) 较大的时候,可以高斯消元求出简化阶梯矩阵(若矩阵第 (i) 行第 (i) 列为 (1) 则第 (i) 列的其他元素均为 (0)),然后 DP (f_{i,j,S}) 表示前 (i) 个基变量中选出了 (j) 个,不在基上的位异或和为 (S) 的方案数,统计答案时答案 (ans_{j+popcount(S)}+=f_{m-k,j,S})(O(2^{m-k}k^2))

  • 结合这两种算法可过 E1

代码

#include <bits/stdc++.h>
 
template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}
 
typedef long long ll;
 
const int N = 2e5 + 5, E = 40, C = 17000, djq = 998244353;
 
int n, m, orz = 1, cnt1, p1[N], cnt0, p0[N], f[E][E][C], st[E], ans[E];
ll a[N], b[E];
 
void ins(ll x)
{
	for (int i = m - 1; i >= 0; i--)
	{
		if (!((x >> i) & 1)) continue;
		if (b[i] == -1) return (void) (b[i] = x);
		else x ^= b[i];
	}
	orz = (orz << 1) % djq;
}
 
int cc(ll x)
{
	int res = 0;
	while (x) res += x & 1, x >>= 1;
	return res;
}
 
int main()
{
	read(n); read(m);
	for (int i = 0; i < m; i++) b[i] = -1;
	for (int i = 1; i <= n; i++) read(a[i]), ins(a[i]);
	for (int i = 0; i < m; i++) if (b[i] != -1)
		for (int j = i + 1; j < m; j++)
			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
	for (int i = 0; i < m; i++)
		if (b[i] != -1) p1[++cnt1] = i;
		else p0[++cnt0] = i;
	if (cnt1 <= 20)
	{
		for (int S = 0; S < (1 << cnt1); S++)
		{
			ll T = 0;
			for (int i = 1; i <= cnt1; i++)
				if ((S >> i - 1) & 1) T ^= b[p1[i]];
			ans[cc(T)]++;
		}
	}
	else
	{
		for (int i = 1; i <= cnt1; i++)
			for (int j = 1; j <= cnt0; j++)
				if ((b[p1[i]] >> p0[j]) & 1) st[i] |= 1 << j - 1;
		f[0][0][0] = 1;
		for (int i = 0; i < cnt1; i++)
			for (int j = 0; j <= i; j++)
				for (int S = 0; S < (1 << cnt0); S++)
				{
					f[i + 1][j][S] = (f[i + 1][j][S] + f[i][j][S]) % djq;
					f[i + 1][j + 1][S ^ st[i + 1]] = (f[i + 1][j + 1][S ^ st[i + 1]]
						+ f[i][j][S]) % djq;
				}
		for (int j = 0; j <= cnt1; j++)
			for (int S = 0; S < (1 << cnt0); S++)
			{
				int x = j + cc(S);
				ans[x] = (ans[x] + f[cnt1][j][S]) % djq;
			}
	}
	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
	puts("");
	return 0;
}

E2

题意

  • 同 E1,(0le mle 53)

做法:FWT+组合数学

  • 妙啊!!!( imes 4)

  • 考虑对于 E1 的第二种算法,把复杂度去掉两个 (k)

  • (A_S) 表示 (S) 是否能被线性基表出,(F^c_S) 表示 (S)(1) 的个数是否为 (c)

  • 我们不难 (neng) 想到 (ans_c) 等于 (FWT(A) imes FWT(F^c)) 所有项之和(这里的 ( imes) 是点乘)除以 (2^m) 后的结果(因为要做 IFWT)

  • 接下来考虑 (FWT(A)) 的性质

(FWT(A)) 仅由 (0)(2^k) 组成,且第 (S) 位为 (2^k) 当且仅当 (S) 与线性基内所有变量的交集大小都是偶数

  • 证明:

(S) 与所有基变量的交集大小都是偶数,由于 (S)(Tigoplus U) 的交集大小在奇偶性上等于 (Scap T)(Scap U) 的大小之和,故 (S) 与这个基表出的所有 (2^k) 个数的交集大小都为偶数,由 FWT 的定义可知 (FWT(A)) 的第 (S) 位为 (2^k)
否则 (S) 与这个基表出的所有 (2^k) 个数的交集大小中奇偶各占一半,由 FWT 的定义可知 (FWT(A)) 的第 (S) 位为 (0)

另一个性质:

(FWT(A)) 中为 (2^k) 的位只有 (2^{m-k}) 个,且组成另一个基

  • 证明:

(FWT(A)) 中第 (S) 位为 (2^k) 的条件转化一下:对于一个不在基上的位 (i),如果让第 (i) 位为 (1),则对于每个满足第 (i) 位为 (1) 的基变量 (j),要让 (S) 的第 (j) 位也异或上 (1)
这样就有了 (m-k) 个基变量,由于每个基变量的最低位互不相同,故它们可以组成一个基
但原线性基必须是简化阶梯矩阵,否则在基上的位 (i) 也会对其他在基上的位 (j) 造成影响

  • 于是求出这个大小为 (m-k) 的基后暴力枚举每个变量选或不选,即可得到 (FWT(A)) 中所有为 (2^k) 的位

  • 再考虑 (FWT(F^c)),容易发现 (FWT(F^c)) 的第 (S) 位值只和 (S)(1) 的个数有关

  • 即对于 (S),枚举一个 (1) 的个数为 (c)(T) 贡献 ((-1)^{|Scap T|}),相当于枚举一个 (i) 表示 (S)(T) 表示 (S)(T) 的交集大小

  • 于是 (FWT(F^c)) 包含 (d)(1) 的位值均为:

  • [w_{c,d}=sum_{i=0}^{min(c,d)}(-1)^iinom diinom{m-d}{c-i} ]

  • (FWT(A)) 中含 (c)(1) 的下标有 (q_c)(2^k),则:

  • [ans_c=frac 1{2^{m-k}}sum_{d=0}^mq_dw_{c,d} ]

  • 结合 (k) 较小的暴力枚举,复杂度为 (O(2^{frac m2}+m^3+n))

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 60, djq = 998244353, i2 = 499122177;

int n, m, orz = 1, cnt1, p[N], cnt0, cnt[N], ans[N], C[N][N];
ll b[N], a[N];

void ins(ll x)
{
	for (int i = m - 1; i >= 0; i--)
	{
		if (!((x >> i) & 1)) continue;
		if (b[i] == -1) return (void) (b[i] = x);
		else x ^= b[i];
	}
	orz = (orz << 1) % djq;
}

void dfs(int dep, int tar, ll T)
{
	if (dep == tar + 1) return (void) (ans[__builtin_popcountll(T)]++);
	dfs(dep + 1, tar, T); dfs(dep + 1, tar, T ^ a[dep]);
}

int main()
{
	ll x;
	read(n); read(m);
	for (int i = 0; i < m; i++) b[i] = -1;
	for (int i = 1; i <= n; i++) read(x), ins(x);
	for (int i = 0; i < m; i++) if (b[i] != -1)
		for (int j = i + 1; j < m; j++)
			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
	for (int i = 0; i < m; i++) if (b[i] != -1) a[++cnt1] = b[i];
	if (cnt1 <= 26) dfs(1, cnt1, 0);
	else
	{
		for (int i = 0; i < m; i++) if (b[i] == -1)
		{
			a[++cnt0] = 1ll << i;
			for (int j = i + 1; j < m; j++) if (b[j] != -1 && ((b[j] >> i) & 1))
				a[cnt0] |= 1ll << j;
		}
		dfs(1, cnt0, 0);
		for (int i = 0; i <= m; i++) cnt[i] = ans[i], ans[i] = 0, C[i][0] = 1;
		for (int i = 1; i <= m; i++)
			for (int j = 1; j <= i; j++)
				C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % djq;
		int I = 1;
		for (int i = 1; i <= cnt0; i++) I = 1ll * I * i2 % djq;
		for (int i = 0; i <= m; i++)
			for (int j = 0; j <= m; j++)
			{
				int pl = 0;
				for (int k = 0; k <= j && k <= i; k++)
				{
					int delta = 1ll * C[j][k] * C[m - j][i - k] % djq;
					if (k & 1) pl = (pl - delta + djq) % djq;
					else pl = (pl + delta) % djq;
				}
				ans[i] = (1ll * I * pl % djq * cnt[j] + ans[i]) % djq;
			}
	}
	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
	return puts(""), 0;
}

F

题意

  • 给定 (n) 个节点的树,(m) 条路径和一个 (k)

  • 求有多少对路径的交至少包含 (k) 条边

  • (2le n,mle 1.5 imes10^5)(1le kle n)

做法:分类讨论+倍增+BIT+线段树

  • 任选一个根,先考虑相交的两条路径 LCA 不同的情况

  • 此时可以把一条路径拆成两条((s_i)(lca_i)(t_i)(lca_i))来看待

  • 下面设拆完之后的路径为 ((up_i,down_i))(up_i) 的深度较小

  • 考虑当 (dep_{up_i}<dep_{up_j}) 时,第 (i) 条和第 (j) 条路径交集至少为 (k) 当且仅当 (up_j) 沿着 (down_j) 的方向走 (k) 步之后还在路径 ((down_i,up_i))

  • 用倍增处理出每个 (up_i) 沿着 (down_i) 的方向走 (k) 步之后到达的点,用 DFS序+差分+BIT 进行单点加和路径查询即可

  • 再考虑 LCA 相同的情况,设这个 LCA 为 (u),这时又分两种:

  • (1)设对于所有的 (i) 都有 (s_i) 的 DFS 序小于 (t_i),则 (s_i)(s_j) 都不为 (u) 且在 (u) 的同一棵子树内,(t_i)(t_j) 也一样

  • (2)反之

  • 先考虑(2),设路径 (i)((x_i,u)) 部分和路径 (j)((x_j,u)) 部分有交集((x_i,x_j) 为路径 (i,j) 的端点之一)

  • 同样地,这相当于 (u) 沿着 (x_i) 向下走 (k) 步和沿着 (x_j) 向下走 (k) 步到达的点相同,也可以拆成两条之后用和之前类似的方法处理

  • 而对于(1),考虑 (v=lca(s_i,s_j)),方案合法当且仅当:

  • (1)(u)(v) 的严格祖先

  • (2)(dep_v-dep_uge k)(v) 朝着 (t_i)(dep_v-dep_u+1) 步之后的节点子树内包含 (t_j)

  • (3)(dep_v-dep_u<k)(v) 朝着 (t_i)(k) 步之后的节点子树内包含 (t_j)

  • 这三个条件中(1)满足且(2)(3)满足一者

  • 如果 (i) 的取值集合和 (j) 的取值集合给定(不交),则可以建立 (n) 棵动态开点线段树,维护每个 LCA 的路径的 (t)

  • 把所有 (j) 插入到第 (lca_j) 棵线段树的 (dfn_{t_j}) 位置之后,对于每个 (i) 查询第 (lca_i) 棵线段树上某个节点的子树和即可

  • 回到原问题,可以 dsu-on-tree:对这棵树每个非叶节点找出一个 preferred child(即设 (cnt_u=sum_i[s_i=u]),preferred child 为 (cnt_u) 的和最大的子树),然后 dfs 的过程中,先递归轻儿子并把线段树上的东西清掉,然后递归重儿子,这时不要把线段树上的东西清掉,把重子树以外的所有路径的 (s) 加入并统计答案

  • 期间可用一个 set 维护当前子树内的所有路径

  • (O(mlog^2m+nlog n))

  • 本题的巧妙之处就在于,使用了从交点处移动 (k) 步的方法,来判断两条路径的交长度是否 (ge k)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;
typedef std::set<int>::iterator it;

const int N = 15e4 + 5, M = N << 1, L = 1e7 + 5, E = 20;

int n, m, k, ecnt, nxt[M], adj[N], go[M], times, dfn[N], dep[N], fa[N][E],
s[N], t[N], l[N], p[N], A[N], sze[N], cnt[N], son[N], rt[N], ToT, top, stk[M];
ll ans;
std::set<int> orz[N];
std::vector<int> a[N], b[N];

struct node
{
	int lc, rc, sum;
} T[L];

void change(int l, int r, int pos, int v, int &p)
{
	if (!p) p = ++ToT; T[p].sum += v;
	if (l == r) return;
	int mid = l + r >> 1;
	if (pos <= mid) change(l, mid, pos, v, T[p].lc);
	else change(mid + 1, r, pos, v, T[p].rc);
}

int ask(int l, int r, int s, int e, int p)
{
	if (!p || e < l || s > r) return 0;
	if (s <= l && r <= e) return T[p].sum;
	int mid = l + r >> 1;
	return ask(l, mid, s, e, T[p].lc) + ask(mid + 1, r, s, e, T[p].rc);
}

void change(int x, int v)
{
	for (; x <= n; x += x & -x)
		A[x] += v;
}

void sub(int u) {change(dfn[u], 1); change(dfn[u] + sze[u], -1);}

int ask(int x)
{
	int res = 0;
	for (; x; x -= x & -x) res += A[x];
	return res;
}

inline bool comp(int a, int b)
{
	return dep[l[a]] > dep[l[b]] || (dep[l[a]] == dep[l[b]] && l[a] < l[b]);
}

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}

void dfs(int u, int fu)
{
	dep[u] = dep[fa[u][0] = fu] + (sze[u] = 1);
	for (int i = 0; i < 17; i++) fa[u][i + 1] = fa[fa[u][i]][i];
	dfn[u] = ++times;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu) dfs(v, u), sze[u] += sze[v];
}

int lca(int u, int v)
{
	if (dep[u] < dep[v]) std::swap(u, v);
	for (int i = 17; i >= 0; i--)
		if (dep[fa[u][i]] >= dep[v])
			u = fa[u][i];
	if (u == v) return u;
	for (int i = 17; i >= 0; i--)
		if (fa[u][i] != fa[v][i])
			u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}

int J(int u, int k)
{
	for (int i = 17; i >= 0; i--)
		if ((k >> i) & 1) u = fa[u][i];
	return u;
}

void init(int u, int fu)
{
	int mx = -1;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu)
		{
			init(v, u); cnt[u] += cnt[v];
			if (cnt[v] > mx) mx = cnt[v], son[u] = v;
		}
}

void wtf(int u, int i)
{
	if (dfn[l[i]] >= dfn[u] || dfn[u] >= dfn[l[i]] + sze[l[i]]) return;
	int len = dep[u] + dep[t[i]] - dep[l[i]] * 2;
	if (len < k || t[i] == l[i]) return;
	int v = dep[u] - dep[l[i]] >= k ? J(t[i], dep[t[i]] - dep[l[i]] - 1)
		: J(t[i], len - k);
	ans += ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]);
	if (ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]));
}

void DFS(int u, int fu)
{
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu && v != son[u])
		{
			DFS(v, u);
			for (it x = orz[v].begin(); x != orz[v].end(); x++)
				change(1, n, dfn[t[*x]], -1, rt[l[*x]]);
		}
	if (son[u]) DFS(son[u], u);
	for (it x = orz[u].begin(); x != orz[u].end(); x++)
		wtf(u, *x), change(1, n, dfn[t[*x]], 1, rt[l[*x]]);
	if (son[u])
	{
		for (int e = adj[u], v; e; e = nxt[e])
		{
			if ((v = go[e]) == fu || v == son[u]) continue;
			for (it x = orz[v].begin(); x != orz[v].end(); x++) wtf(u, *x);
			for (it x = orz[v].begin(); x != orz[v].end(); x++)
				change(1, n, dfn[t[*x]], 1, rt[l[*x]]), orz[son[u]].insert(*x);
		}
		for (it x = orz[u].begin(); x != orz[u].end(); x++)
			orz[son[u]].insert(*x);
		std::swap(orz[u], orz[son[u]]);
	}
}

int main()
{
	int x, y;
	read(n); read(m); read(k);
	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
	dfs(1, 0);
	for (int i = 1; i <= m; i++)
	{
		read(s[i]); read(t[i]);
		if (dfn[s[i]] > dfn[t[i]]) std::swap(s[i], t[i]);
		l[i] = lca(s[i], t[i]); p[i] = i;
		orz[s[i]].insert(i); cnt[s[i]]++; a[l[i]].push_back(i);
	}
	std::sort(p + 1, p + m + 1, comp);
	for (int i = 1; i <= m;)
	{
		int nxt = i;
		while (nxt <= m && l[p[i]] == l[p[nxt]]) nxt++;
		for (int j = i; j < nxt; j++)
		{
			int x = p[j], u = s[x], v = t[x], w = l[x];
			ans += ask(dfn[u]) + ask(dfn[v]) - ask(dfn[w]) * 2;
		}
		for (int j = i; j < nxt; j++)
		{
			int x = p[j], u = s[x], v = t[x], w = l[x];
			if (dep[u] - dep[w] >= k) sub(J(u, dep[u] - dep[w] - k));
			if (dep[v] - dep[w] >= k) sub(J(v, dep[v] - dep[w] - k));
		}
		i = nxt;
	}
	memset(A, 0, sizeof(A));
	for (int u = 1; u <= n; u++)
	{
		for (int i = 0; i < a[u].size(); i++)
		{
			int x = a[u][i];
			if (dep[s[x]] - dep[u] >= k)
			{
				ans += A[y = J(s[x], dep[s[x]] - dep[u] - k)]++; stk[++top] = y;
				if (t[x] != u) b[J(t[x], dep[t[x]] - dep[u] - 1)].push_back(y);
			}
			if (dep[t[x]] - dep[u] >= k)
			{
				ans += A[y = J(t[x], dep[t[x]] - dep[u] - k)]++; stk[++top] = y;
				if (s[x] != u) b[J(s[x], dep[s[x]] - dep[u] - 1)].push_back(y);
			}
		}
		while (top--) A[stk[top + 1]] = 0; top = 0;
		for (int e = adj[u], v; e; e = nxt[e])
		{
			if ((v = go[e]) == fa[u][0]) continue;
			for (int i = 0; i < b[v].size(); i++)
				ans -= A[y = b[v][i]]++, stk[++top] = y;
			while (top--) A[stk[top + 1]] = 0; top = 0;
		}
	}
	init(1, 0); DFS(1, 0);
	return std::cout << ans << std::endl, 0;
}
原文地址:https://www.cnblogs.com/xyz32768/p/12738207.html