「Codeforces 715C」Digit Tree

Description

程序员 ZS 有一棵树,它可以表示为 (n) 个顶点的无向连通图,顶点编号从 (0)(n-1),它们之间有 (n-1) 条边。每条边上都有一个非零的数字。

一天,程序员 ZS 无聊,他决定研究一下这棵树的一些特性。他选择了一个十进制正整数 (M)(gcd(M,10)=1)

对于一对有序的不同的顶点 ((u, v)),他沿着从顶点 (u) 到顶点 (v)的最短路径,按经过顺序写下他在路径上遇到的所有数字(从左往右写),如果得到一个可以被 (M) 整除的十进制整数,那么就认为 ((u,v)) 是有趣的点对。

帮助程序员 ZS 得到有趣的对的数量。

Hint

  • (1le nle 10^5)
  • (1le mle 10^9,gcd(m, 10) = 1)
  • (1le ext{边权} < 10)

Solution

这种树上路径的统计问题基本都是 点分治,而点分治的重点和难点就是如何 统计经过分治中心的满足条件的路径的个数

这里采用 容斥法:即现分治中心为 (s),当前答案等于整个子树 (s) 的答案减去以 (s) 各个子结点为根的子树的答案。

考虑如何统计。

我们设有一条路径是 (x ightarrow y),分治中心为 (s),路径 (x ightarrow s) 对应的数字为 (pd)(s ightarrow y) 对应 (nd)(s)(y) 的距离为 (l)

那么只有 (pd imes 10^l + nd equiv 0 pmod m) 成立时满足要求。

变形一下:(pd equiv -nd imes 10^{-l}pmod m)

于是我们可以这样搞:把所有的 (pd)map 存起来,记录一下个数,用 pair 数组把 ((nd, l)) 记录下来。

导入所有了路径信息后,枚举 pair 数组,查找 map 中的元素配对即可。

预处理一下 (10) 的幂及其逆元的话,时间复杂度 (O(nlog^2 n))如果用 Hash Table 可以优化到理论 (O(nlog n)),但没什么必要。

Code

/*
 * Author : _Wallace_
 * Source : https://www.cnblogs.com/-Wallace-/
 * Problem : Codeforces 715E Digit Tree
 */
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

using namespace std;
const int N = 1e5 + 5;

namespace Inv {
	void extgcd(long long a, long long b, long long& x, long long& y) {
		if (!b) x = 1, y = 0;
		else extgcd(b, a % b, y, x), y -= a / b * x;
	}
	inline long long get(long long b, long long p) {
		long long x, y;
		extgcd(b, p, x, y);
		x = (x % p + p) % p;
		return x;
	}
}

int n, m;
long long p10[N], invp[N];
long long ans;
struct edge { int to, len; };
vector<edge> G[N];

int root;
int maxp[N], size[N];
bool centr[N];

int getSize(int x, int f) {
	size[x] = 1;
	for (auto y : G[x])
		if (!centr[y.to] && y.to != f)
			size[x] += getSize(y.to, x);
	return size[x];
}
void getCentr(int x, int f, int t) {
	maxp[x] = 0;
	for (auto y : G[x])
		if (!centr[y.to] && y.to != f) {
			getCentr(y.to, x, t);
			maxp[x] = max(maxp[x], size[y.to]);
		}
	maxp[x] = max(maxp[x], t - size[x]);
	if (maxp[x] < maxp[root]) root = x;
}

vector<pair<long long, int> > dat;
map<long long, int> cnt;

void getData(int x, int f, long long pd, long long nd, int dep) {
	if (dep >= 0) cnt[pd]++, dat.push_back(make_pair(nd, dep));
	for (auto y : G[x]) {
		if(centr[y.to] || y.to == f) continue;
		long long tpd = (pd + y.len * p10[dep + 1] % m) % m;
		long long tnd = (nd * 10 % m + y.len) % m;
		getData(y.to, x, tpd, tnd, dep + 1);
	}
}

inline long long count(int x, int d) {
	long long ret = 0;
	cnt.clear(), dat.clear();
	if (d == 0) getData(x, 0, 0, 0, -1);
	else getData(x, 0, d % m, d % m, 0);
	
	for (auto p : dat) {
		long long t = ((-p.first * invp[p.second + 1] % m) + m) % m;
		if (cnt.count(t)) ret += cnt[t];
		if (d == 0 && p.first == 0) ++ret;
	}
	return ret + (d == 0 ? cnt[0] : 0);
}

void solve(int x) {
	maxp[root = 0] = N;
	getCentr(x, 0, getSize(x, 0));
	int s = root; centr[s] = true;
	
	for (auto y : G[s])
		if (!centr[y.to])
			solve(y.to);
	
	ans += count(s, 0);
	for (auto y : G[s])
		if (!centr[y.to])
			ans -= count(y.to, y.len);
	centr[s] = false;
}

signed main() {
	scanf("%d%d", &n, &m);
	for (register int i = 1; i < n; i++)  {
		int u, v, l;
		scanf("%d%d%d", &u, &v, &l);
		++u, ++v;
		G[u].push_back(edge{v, l});
		G[v].push_back(edge{u, l});
	}
	
	p10[0] = 1 % m;
	for (register int i = 1; i <= n; i++)
		p10[i] = p10[i - 1] * 10 % m;
	invp[n] = Inv::get(p10[n], m);
	for (register int i = n - 1; i; i--)
		invp[i] = invp[i + 1] * 10 % m;
	
	ans = 0, solve(1);
	printf("%lld
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/-Wallace-/p/12865700.html