哎,我是不是有场 CF 没写题解,咕了咕了(
先考虑对于没有 ?
的情况,即已知了一种将 (c, d) 中的 ?
替换为字母的方案后,如何求合法的 (01) 串二元组 ((s, t)) 数量。
设 (a_x, b_x) 分别表示字符串 (x) 中 A
,B
的数量,考虑以下两种情况:
- (a_c< a_d) 且 (b_c> b_d)。
- (a_c>a_d) 且 (b_c<b_d)。
- (a_c=a_d) 且 (b_c=b_d)。
注意到 (|s|, |t|geq 1),因此对于这三种以外的情况,无论 (|s|, |t|) 如何安排,都不可能让替换完的 (01) 序列长度相等。而前两种情况实际上是对称的,因此这里只考虑第一种。
为了让替换完的 (01) 序列长度相等,可以得到 ((a_d-a_c)|s|=(b_c-b_d)|t|)。设 (g=gcd(a_d-a_c, b_c-b_d)),不难发现,(|s|) 的最小值 (m_s=frac{b_c-b_d}{g}),对应的 (|t|) 的最小值 (m_t=frac{a_d-a_c}{g}),不妨设 (m_tleq m_s)。
由于 (gcd(m_s, m_t)=1),不难发现,任意的 (xin [0, m_s)),都有唯一的 (yin [0, m_s)),使得 (ycdot m_tequiv xpmod {m_s})。换句话说就是,对于一个连续的 (m_t) 个 (s) 组成的序列,我们不断地从最前面截取长度为 (m_t) 的一段,截取下来的恰好是 (s) 的每一个循环(即,将 (s) 的某个前缀取下来,原封不动地接到 (s) 的后面)。不难发现,在这种情况下,(s, t) 只有两种方案:全由 (0) 组成,或全由 (1) 组成。不难发现,如果将这个序列复制正整数遍,再在某些位置插入一些 (s) 或 (t),只要在截取的时候将这些串单独截取掉,上面的性质仍然成立。
上面一段的最后一句话,实际上表示了,对于任意一种满足情况 (1) 的(情况 (2) 类似),确定了 (a_c, a_d, b_c, b_d) 的 (c, d) 串的安排方式,我们只关心 (a_d-a_c) 和 (b_c-b_d) 的值,而并不关心这些 A
,B
的具体位置。并且,我们能以此算出对应的 (s, t) 的方案数。具体来说,首先 (|s|, |t|) 分别是 (m_s, m_t) 的倍数,且 (|s|=kcdot m_s, |t|=kcdot m_t) 时,我们有 (2^k) 种方案(因为不互质的部分无法用上面的方式取到,所以这些位置恰好分成了 (k) 个独立取值的连通块)。这是一个简单的等比数列求和。
因此,我们可以暴力枚举 (a_d-a_c) 的取值。不难发现,由于 ?
在 (c, d) 中的数量是确定的,我们可以直接算出 (b_c-b_d) 的值。还有一个问题是,可能会有多组 ((a_c, a_d, b_c, b_d))。但是感性理解一下可以发现,如果要保持 (a_d-a_c) 和 (b_c-b_d) 不变,那么每在 (c) 中多将一个 ?
替换成 A
,就必须在 (d) 中少将一个 ?
替换成 B
。换句话说,(a_c+b_d) 的值是确定的,并且这种情况下,将 ?
替换为字母的方案数是一个只与 (a_c+b_d) 相关的组合数(实际上它就是范德蒙德恒等式)。
现在还剩情况 (3)。对于情况 (3),发现 (|s|, |t|) 已经没有了限制,因此我们要对每一个 ((|s|, |t|)) 求出合法的 (s, t) 数量。实际上,与上面的讨论差不多,如果 (gcd(|s|, |t|)=1),可以证明 (s, t) 也是要么全 (0),要么全 (1) 的。拓展一下,可以发现我们要求的就是 (sum_{|s|=1}^{n}sum_{|t|=1}^n 2^{gcd(|s|, |t|)})。这是非常简单的莫反,此处不再赘述。而将 ?
替换为字母的方案数,实际上与上面是相同的。
但是,这里还有最后一个坑点。如果 (c=d),也就是说每个位置的 A/B
都相等,那么任意一组长度不超过 (n) 的 ((s, t)) 显然都是满足条件的。因此这一部分要从上面扣除,单独计算。
Code:
#include <bits/stdc++.h>
#define R register
#define mp make_pair
#define ll long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7, N = 310000, M = N << 1;
int n, m, k, sa, sb, ta, tb, sc, tc, ispr[N], mu[N];
ll fac[M], inv[M];
char s[N], t[N];
vector<int> prime;
inline int addMod(int a, int b) {
return (a += b) >= mod ? a - mod : a;
}
inline ll quickpow(ll base, ll pw) {
ll ret = 1;
while (pw) {
if (pw & 1) ret = ret * base % mod;
base = base * base % mod, pw >>= 1;
}
return ret;
}
template <class T>
inline void read(T &x) {
x = 0;
char ch = getchar(), w = 0;
while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
x = w ? -x : x;
return;
}
inline void initComb(int n) {
fac[0] = 1;
for (R int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
inv[n] = quickpow(fac[n], mod - 2);
for (R int i = n; i; --i) inv[i - 1] = inv[i] * i % mod;
}
inline ll comb(int n, int m) {
if (m < 0 || n < m) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int getGcd(int a, int b) {
return b ? getGcd(b, a % b) : a;
}
inline ll calc(int x) {
return addMod(quickpow(2, x + 1), mod - 2);
}
inline int sign(int x) {
return x < 0 ? -1 : x > 0;
}
void initPrime(int n) {
mu[1] = 1;
for (R int i = 2, k; i <= n; ++i) {
if (!ispr[i])
mu[i] = -1, prime.push_back(i);
for (auto &j : prime) {
if ((k = i * j) > n) break;
ispr[k] = 1;
if (i % j == 0) break;
mu[k] = addMod(mod, -mu[i]);
}
}
return;
}
inline ll sq(ll x) {
return x * x % mod;
}
int main() {
scanf("%s%s", s + 1, t + 1), read(k);
n = strlen(s + 1), m = strlen(t + 1);
if (n < m) swap(s, t), swap(n, m);
initComb(n + m), initPrime(k);
for (R int i = 1; i <= n; ++i)
sa += s[i] == 'A', sb += s[i] == 'B', sc += s[i] == '?';
for (R int i = 1; i <= m; ++i)
ta += t[i] == 'A', tb += t[i] == 'B', tc += t[i] == '?';
ll ans = 0;
for (R int i = sa - ta - tc; i <= sa - ta + sc; ++i) {
int j = m - n + i;
if (i == 0 && j == 0) {
ll w = 0, pw = 1;
for (R int d = 1; d <= k; ++d) {
pw = addMod(pw, pw);
for (R int u = 1, v = d; v <= k; ++u, v += d)
w = (w + pw * mu[u] % mod * sq(k / v)) % mod;
}
pw = 1;
for (R int d = 1; pw && d <= n; ++d) {
if (s[d] == '?' && t[d] == '?')
pw = addMod(pw, pw);
else if (s[d] != '?' && t[d] != '?' && s[d] != t[d])
pw = 0;
}
ans = (ans + w * (comb(sc + tc, ta + tc - sa + i) + mod - pw)) % mod;
ans = (ans + pw * sq(calc(k))) % mod;
}
if (sign(i) * sign(j) != 1) continue;
ans = (ans + calc(k / ((i < 0 ? j : i) / getGcd(i, j))) * comb(sc + tc, ta + tc - sa + i)) % mod;
}
cout << ans << endl;
return 0;
}