CodeForces 794 G.Replace All

CodeForces 794 G.Replace All

解题思路

首先如果字符串 (A, B) 没有匹配,那么二元组 ((S, T)) 合法的一个必要条件是存在正整数对 ((x,y)),使得 (xS=yT),其中 (xS) 是将字符串 (S) 复制 (x) 遍后得到的字符串,(yT) 是将字符串 (T) 复制 (T) 遍后得到的字符串。由于 (A,B) 直接匹配的情况比较容易讨论,下面没有特殊说明,都是 (A,B) 没有直接匹配的情况。

这个条件的实际意义是通过这个二元组 ((S,T)) 转化后,能将 (x)('a') 组成的子串与 (y)('b') 组成的子串通配,必要性可以根据这个感性理解,下面来证明满足这个条件的二元组的一些性质。

对于字符串 (C=xS=yT) ,显然存在周期 (x) 和周期 (y) ,根据周期定理 (gcd(x,y)) 也是 (C) 的一个周期,其也是 (S,T) 的一个周期,我们令 (C[1:gcd(x,y)]=D) ,那么 (S=dfrac{y}{gcd(x,y)}D,T=dfrac{x}{gcd(x,y)}D) 。用 (+) 表示字符串的拼接,可以得到 (S+T=T+S=dfrac{xy}{gcd(x,y)}D) 。也就是用这个二元组转化后,任意两个相邻的字符 ('a','b') 交换后得到的字符串不变,最终的字符串只与字符 ('a','b') 的数量用关。

假设将 ('?') 填好之后,令 (Delta a) 表示 (A)('a') 的数量与 (B)('a') 的数量之差,(Delta b) 表示 (A)('b') 的数量与 (B)('b') 的数量之差,此时如果 (Delta a Delta bgeq0)(Delta a,Delta b) 不同时等于 (0) ,那么不存在满足条件的合法二元组。

如果 (Delta a=0,Delta b=0) ,那么任意一个满足条件的合法二元组都可以,其中一个 (gcd(x,y)=g) 的合法二元组的方案数就是 (2^g) (考虑字符串 (D) 的每一位是怎么填的即可),那么只需要容斥出所有 (gcd(x,y)=i) 的对数即可。

否则,满足条件的 (x, y) 的比值是 (dfrac{|Delta a|}{|Delta b|}) ,令 (g=gcd(Delta a,Delta b)) ,枚举 (D) 的长度,方案数就是

[2^{dfrac{n}{dfrac{max(|Delta a|,|Delta b|)}{g}}+1}-2 ]

现在考虑 ('?') 的影响,令 (cntA)(A)('?') 个数,(cntB)(B)('?') 个数,那么填 ('?') 的影响就是让 (Delta a) 加上一个整数 (d)(Delta b) 加上 (cntA-cntB-d),这样选的方案数是 ({cntA+cntB}choose{cntB+d}) ,推导可以把方案数的和式列出来然后展开,当然你要做卷积也是可以的。(模数 (10^9+7)) ,然后只需要枚举一下 ('?') 贡献的 (d) 的值这一部分就算出来了。

注意前面提到的都是 (A,B) 没有直接匹配的情况,对于 (A,B) 在填完 ('?') 之后能直接匹配的情况,所有的二元组都是合法的,只需要把之前没算的部分在这里算上即可,总的复杂度是 (mathcal O(nlog n))

code

/*program by mangoyang*/
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
	int ch = 0, f = 0; x = 0;
	for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
	for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
	if(f) x = -x;
}
const int N = 1000005, mod = 1e9 + 7;
char s[N], t[N];
int inv[N], js[N], pw[N], f[N], cnts, cntt, da, db, lens, lent, n, ans, total;
inline int Pow(int a, int b){
	int ans = 1;
	for(; b; b >>= 1, a = 1ll * a * a % mod)
		if(b & 1) ans = 1ll * ans * a % mod;
	return ans;
}
inline int C(int x, int y){ 
	return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod; 
}
int main(){
	scanf("%s", s + 1), scanf("%s", t + 1), read(n);
	lens = strlen(s + 1), lent = strlen(t + 1);
	for(int i = 1; i <= lens; i++){
		if(s[i] == 'A') da++; if(s[i] == 'B') db++; if(s[i] == '?') cnts++;
	}
	for(int i = 1; i <= lent; i++){
		if(t[i] == 'A') da--; if(t[i] == 'B') db--; if(t[i] == '?') cntt++;
	}
	inv[0] = js[0] = pw[0] = 1;
	for(int i = 1; i <= n + 1; i++) pw[i] = 2ll * pw[i-1] % mod;
	for(int i = 1; i <= lens + lent; i++) 
		js[i] = 1ll * js[i-1] * i % mod, inv[i] = Pow(js[i], mod - 2);
	for(int i = n; i; i--){
		f[i] = 1ll * (n / i) * (n / i) % mod;
		for(int j = i + i; j <= n; j += i) (f[i] += mod - f[j]) %= mod;
		total = (total + 1ll * f[i] * pw[i] % mod) % mod;
	}
	for(int d = -cntt; d <= cnts; d++){
		int A = da + d, B = db + cnts - cntt - d, x = C(cnts + cntt, cntt + d);
		if(!A && !B) (ans += 1ll * x * total % mod) %= mod;
		if(1ll * A * B >= 0) continue;
		int g = __gcd(abs(A), abs(B)); 
		A = abs(A) / g, B = abs(B) / g;
		(ans += (1ll * x * (pw[n/max(A,B)+1] - 2) + mod) % mod) %= mod;
	}
	if(lens == lent){
		int flag = 1, res = 1;
		for(int i = 1; i <= lens; i++){
			if(s[i] != '?' && t[i] != '?' && s[i] != t[i]) flag = 0;
			if(s[i] == '?' && t[i] == '?') res = 2ll * res % mod;
		}
		if(!flag) return cout << ans << endl, 0;
		(ans += 1ll * res * (1ll * (pw[n+1] - 2) * (pw[n+1] - 2) % mod - total) % mod) %= mod;
		ans = (ans % mod + mod) % mod;
	}
	cout << ans << endl;
	return 0;
}
原文地址:https://www.cnblogs.com/mangoyang/p/10712594.html