[CF794G] Replace All【组合数学】【数论】

哎,我是不是有场 CF 没写题解,咕了咕了(

先考虑对于没有 ? 的情况,即已知了一种将 (c, d) 中的 ? 替换为字母的方案后,如何求合法的 (01) 串二元组 ((s, t)) 数量。

(a_x, b_x) 分别表示字符串 (x)AB 的数量,考虑以下两种情况:

  1. (a_c< a_d)(b_c> b_d)
  2. (a_c>a_d)(b_c<b_d)
  3. (a_c=a_d)(b_c=b_d)

注意到 (|s|, |t|geq 1),因此对于这三种以外的情况,无论 (|s|, |t|) 如何安排,都不可能让替换完的 (01) 序列长度相等。而前两种情况实际上是对称的,因此这里只考虑第一种。

为了让替换完的 (01) 序列长度相等,可以得到 ((a_d-a_c)|s|=(b_c-b_d)|t|)。设 (g=gcd(a_d-a_c, b_c-b_d)),不难发现,(|s|) 的最小值 (m_s=frac{b_c-b_d}{g}),对应的 (|t|) 的最小值 (m_t=frac{a_d-a_c}{g}),不妨设 (m_tleq m_s)

由于 (gcd(m_s, m_t)=1),不难发现,任意的 (xin [0, m_s)),都有唯一的 (yin [0, m_s)),使得 (ycdot m_tequiv xpmod {m_s})。换句话说就是,对于一个连续的 (m_t)(s) 组成的序列,我们不断地从最前面截取长度为 (m_t) 的一段,截取下来的恰好是 (s) 的每一个循环(即,将 (s) 的某个前缀取下来,原封不动地接到 (s) 的后面)。不难发现,在这种情况下,(s, t) 只有两种方案:全由 (0) 组成,或全由 (1) 组成。不难发现,如果将这个序列复制正整数遍,再在某些位置插入一些 (s)(t),只要在截取的时候将这些串单独截取掉,上面的性质仍然成立。

上面一段的最后一句话,实际上表示了,对于任意一种满足情况 (1) 的(情况 (2) 类似),确定了 (a_c, a_d, b_c, b_d)(c, d) 串的安排方式,我们只关心 (a_d-a_c)(b_c-b_d) 的值,而并不关心这些 AB 的具体位置。并且,我们能以此算出对应的 (s, t) 的方案数。具体来说,首先 (|s|, |t|) 分别是 (m_s, m_t) 的倍数,且 (|s|=kcdot m_s, |t|=kcdot m_t) 时,我们有 (2^k) 种方案(因为不互质的部分无法用上面的方式取到,所以这些位置恰好分成了 (k) 个独立取值的连通块)。这是一个简单的等比数列求和。

因此,我们可以暴力枚举 (a_d-a_c) 的取值。不难发现,由于 ?(c, d) 中的数量是确定的,我们可以直接算出 (b_c-b_d) 的值。还有一个问题是,可能会有多组 ((a_c, a_d, b_c, b_d))。但是感性理解一下可以发现,如果要保持 (a_d-a_c)(b_c-b_d) 不变,那么每在 (c) 中多将一个 ? 替换成 A,就必须在 (d) 中少将一个 ? 替换成 B。换句话说,(a_c+b_d) 的值是确定的,并且这种情况下,将 ? 替换为字母的方案数是一个只与 (a_c+b_d) 相关的组合数(实际上它就是范德蒙德恒等式)。

现在还剩情况 (3)。对于情况 (3),发现 (|s|, |t|) 已经没有了限制,因此我们要对每一个 ((|s|, |t|)) 求出合法的 (s, t) 数量。实际上,与上面的讨论差不多,如果 (gcd(|s|, |t|)=1),可以证明 (s, t) 也是要么全 (0),要么全 (1) 的。拓展一下,可以发现我们要求的就是 (sum_{|s|=1}^{n}sum_{|t|=1}^n 2^{gcd(|s|, |t|)})。这是非常简单的莫反,此处不再赘述。而将 ? 替换为字母的方案数,实际上与上面是相同的。

但是,这里还有最后一个坑点。如果 (c=d),也就是说每个位置的 A/B 都相等,那么任意一组长度不超过 (n)((s, t)) 显然都是满足条件的。因此这一部分要从上面扣除,单独计算。

Code:

#include <bits/stdc++.h>
#define R register
#define mp make_pair
#define ll long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7, N = 310000, M = N << 1;

int n, m, k, sa, sb, ta, tb, sc, tc, ispr[N], mu[N];
ll fac[M], inv[M];
char s[N], t[N];
vector<int> prime;

inline int addMod(int a, int b) {
	return (a += b) >= mod ? a - mod : a;
}

inline ll quickpow(ll base, ll pw) {
	ll ret = 1;
	while (pw) {
		if (pw & 1) ret = ret * base % mod;
		base = base * base % mod, pw >>= 1;
	}
	return ret;
}

template <class T>
inline void read(T &x) {
	x = 0;
	char ch = getchar(), w = 0;
	while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
	while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
	x = w ? -x : x;
	return;
}

inline void initComb(int n) {
	fac[0] = 1;
	for (R int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
	inv[n] = quickpow(fac[n], mod - 2);
	for (R int i = n; i; --i) inv[i - 1] = inv[i] * i % mod;
}

inline ll comb(int n, int m) {
	if (m < 0 || n < m) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

int getGcd(int a, int b) {
	return b ? getGcd(b, a % b) : a;
}

inline ll calc(int x) {
	return addMod(quickpow(2, x + 1), mod - 2);
}

inline int sign(int x) {
	return x < 0 ? -1 : x > 0;
}

void initPrime(int n) {
	mu[1] = 1;
	for (R int i = 2, k; i <= n; ++i) {
		if (!ispr[i])
			mu[i] = -1, prime.push_back(i);
		for (auto &j : prime) {
			if ((k = i * j) > n) break;
			ispr[k] = 1;
			if (i % j == 0) break;
			mu[k] = addMod(mod, -mu[i]);
		}
	}
	return;
}

inline ll sq(ll x) {
	return x * x % mod;
}

int main() {
	scanf("%s%s", s + 1, t + 1), read(k);
	n = strlen(s + 1), m = strlen(t + 1);
	if (n < m) swap(s, t), swap(n, m);
	initComb(n + m), initPrime(k);
	for (R int i = 1; i <= n; ++i)
		sa += s[i] == 'A', sb += s[i] == 'B', sc += s[i] == '?';
	for (R int i = 1; i <= m; ++i)
		ta += t[i] == 'A', tb += t[i] == 'B', tc += t[i] == '?';
	ll ans = 0;
	for (R int i = sa - ta - tc; i <= sa - ta + sc; ++i) {
		int j = m - n + i;
		if (i == 0 && j == 0) {
			ll w = 0, pw = 1;
			for (R int d = 1; d <= k; ++d) {
				pw = addMod(pw, pw);
				for (R int u = 1, v = d; v <= k; ++u, v += d)
					w = (w + pw * mu[u] % mod * sq(k / v)) % mod;
			}
			pw = 1;
			for (R int d = 1; pw && d <= n; ++d) {
				if (s[d] == '?' && t[d] == '?')
					pw = addMod(pw, pw);
				else if (s[d] != '?' && t[d] != '?' && s[d] != t[d])
					pw = 0;
			}
			ans = (ans + w * (comb(sc + tc, ta + tc - sa + i) + mod - pw)) % mod;
			ans = (ans + pw * sq(calc(k))) % mod;
		}
		if (sign(i) * sign(j) != 1) continue;
		ans = (ans + calc(k / ((i < 0 ? j : i) / getGcd(i, j))) * comb(sc + tc, ta + tc - sa + i)) % mod;
	}
	cout << ans << endl;
	return 0;
}
原文地址:https://www.cnblogs.com/suwakow/p/12748947.html