Codeforces Round #745 (Div. 1) 1580F. Problems for Codeforces

考虑对于两个相邻的元素,必定只有一个大于等于(lceilfrac m2 ceil),那么对于所有大于等于(lceilfrac m2 ceil)的元素,我们将他减去(lceilfrac m2 ceil),则变成了一个子问题,相邻两个元素的和不能超过(lfloorfrac m2 floor)

考虑如何通过子问题答案推出答案,先不考虑首尾相接的情况,将大于等于(lceilfrac m2 ceil)的设为1,其余设为0,则必定是若干个长度为奇数的0101010...10这样的段落拼接起来,两边可能会有长度为偶数的1010...10和0101...01这样的段拼在左右。

但是很快会发现这样存在问题,那就是当(m)为奇数时,有可能会存在(lfloorfrac m2 floor)这样的单个元素组成一个0跟其他的奇数段拼起来,特殊处理即可。

则我们考虑设长度为奇数的段的生成函数为(A),长度为偶数的段的生成函数为(B),整体的生成函数为(F)则对于(m)为偶数的情况,有:

[F=B^2sum_{i=0}A^i+A=frac {B^2}{1-A}+A ]

注意由于咱们不考虑首尾相接,则可能会有一整段是10101...01的情况,因此要在后面在加上一个(A)

m为奇数的情况,则:

[F=B^2sum_{i=0}(A+x)^i+A=frac {B^2}{1-A-x}+A ]

接下来就是如何算答案。

可以发现对于(n)为奇数的情况,要么就是首尾都是一段奇数段,要么就是首尾拼起来成为了一段奇数段,则本质上相当于一个奇数段拼上了一个偶数段,因为(n)为奇数的情况必定存在两个连续的0。

对于(n)为偶数的情况,则不一定存在两个连续的(0),有可能会有01010101...01这种情况,但倘若我们不断把大于等于(lceilfrac m2 ceil)的元素减去(lceilfrac m2 ceil),则最后要么全部变成了(0),要么必定存在一个情况存在两个连续的(0),则我们对于偶数的情况,在每一层递归都算一遍即可。

复杂度是(O(nlog nlog m)),但好像被锤烂了。

Elegia的高级做法,了解一下:
We can solve F in (O(nlog n)), here is a sketch:

When n is even, a1<m−a2>a3<⋯<m−an>a1. So we need to count the sequences 0≤bi≤m where b1b1. You can do inclusion-exclusion for all "<". It can be solved through computing something like (frac 1{1−P}).

When n is odd, the pattern should be a ring with xy…yx where x≤m2 and y>m2. You can let x←⌈m/2⌉−x and then it becomes a <><>⋯ again.

#include <bits/stdc++.h>

#define I inline
#define fi first
#define se second
#define LL long long
#define mp make_pair
#define reg register int
#define pii pair<int,int>
#define fo(i, a, b) for(int i = a; i <= b; i++)
#define fd(i, a, b) for(reg i = a; i >= b; i--)
#define ULL unsigned long long
#define cr const reg&
using namespace std;
const int inf = 2147483647;
const int mod = 998244353;
const int N = 1e6 + 1;

I int _max(cr x, cr y) {return x > y ? x : y;}
I int _min(cr x, cr y) {return x < y ? x : y;}
I LL read() {
	LL x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
	return x * f;
}
I void ptt(LL x) {if(x >= 10) ptt(x / 10); putchar(x % 10 + '0');}
I void put(LL x) {x < 0 ? putchar('-'), ptt(-x) : ptt(x);}
I void pr1(LL x) {put(x), putchar(' ');}
I void pr2(LL x) {put(x), puts("");}

I int pow_mod(reg a, reg k) {reg ans = 1; for(; k; k >>= 1, a = (LL)a * a % mod) if(k & 1) ans = (LL)ans * a % mod; return ans;}

int n, m, ans;
namespace Poly {
	int w[131073], R[131073]; ULL p[131073];
	int a[131073], b[131073], h[131073], d0[131073], d1[131073];
	
	I int Pre(cr n) {
		reg len = 1; for(; len <= n; len <<= 1);
		for(reg i = 1; i < (len << 1); i <<= 1) {
			reg s = 1, wn = pow_mod(3, (mod - 1) / (i << 1));
			fo(j, 0, i - 1) w[i + j] = s, s = (LL)s * wn % mod;
		} return len;
	}
	
	I int gao(int x) {return x < 0 ? x + mod : x;}
	I void DFT(int y[], cr len) {
		fo(i, 0, len - 1) p[R[i]] = gao(y[i]);
		int b;
		for(reg i = 1; i < len; i <<= 1) for(reg j = 0; j < len; j += i << 1)
			fo(k, 0, i - 1) b = p[i + j + k] * w[i + k] % mod, p[i + j + k] = p[j + k] + mod - b, p[j + k] += b;
		fo(i, 0, len - 1) y[i] = p[i] % mod;
	}
	I void IDFT(int y[], cr len) {
		reverse(y + 1, y + len); DFT(y, len); reg hh = pow_mod(len, mod - 2);
		fo(i, 0, len - 1) y[i] = (LL)y[i] * hh % mod;
	}
	I void clear(int a[], cr len) {memset(a + len, 0, sizeof(a[0]) * len);}
	I void clear(int a[], cr s, cr t) {if(s >= t) return ; memset(a + s, 0, sizeof(a[0]) * (t - s));}
	I void cpy(int a[], int b[], cr len) {memcpy(a, b, sizeof(a[0]) * len), memset(a + len, 0, sizeof(a[0]) * len);}
	I void getinv(int a[], int b[], cr len) {
		if(len == 1) {b[0] = pow_mod(a[0], mod - 2), b[1] = 0; return ;}
		getinv(a, b, len >> 1); cpy(h, a, len), clear(b, len);
		fo(i, 0, (len << 1) - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) ? len : 0);
		DFT(h, len << 1), DFT(b, len << 1);
		fo(i, 0, (len << 1) - 1) b[i] = (2 - (LL)b[i] * h[i]) % mod * b[i] % mod;
		IDFT(b, len << 1); clear(b, len);
	}
	
	I void solve(cr v, cr len) {
		if(v == 1) {
			fo(i, 0, len - 1) a[i] = 1;
			ans = 1;
			return ;
		} solve(v >> 1, len);
		memset(d0, 0, sizeof(int) * (len << 1));
		memset(d1, 0, sizeof(int) * (len << 1));
		fo(i, 0, len - 1) (i & 1 ? d1 : d0)[i] = a[i];
		if(v & 1) d1[1]++, d1[1] = d1[1] >= mod ? d1[1] - mod : d1[1];
		memset(a, 0, sizeof(int) * (len << 1));
		fo(i, 0, len - 1) a[i] = d1[i] ? mod - d1[i] : 0;
		a[0]++, a[0] = a[0] >= mod ? a[0] - mod : a[0];
		getinv(a, b, len);
		if(v == m || !(n & 1)) {
			reg sum = 0;
			for(int i = 1; i <= n; i += 2) sum = (sum + (LL)b[n - i] * d1[i] % mod * i) % mod;
			if(!(n & 1)) ans = (2LL * ans + sum) % mod;
			else ans = sum;
			if(v == m) return ;
		} DFT(d0, len << 1);
		fo(i, 0, (len << 1) - 1) d0[i] = (LL)d0[i] * d0[i] % mod;
		IDFT(d0, len << 1); clear(d0, len);
		DFT(d0, len << 1), DFT(b, len << 1);
		fo(i, 0, (len << 1) - 1) a[i] = (LL)b[i] * d0[i] % mod;
		IDFT(a, len << 1);
		fo(i, 0, len - 1) a[i] = a[i] + d1[i], a[i] = a[i] >= mod ? a[i] - mod : a[i];
		if(v & 1) a[1] ? a[1]-- : a[1] = mod - 1;
	}
}

int main() {
	n = read(), m = read();
	reg len = Poly::Pre(n);
	Poly:: solve(m, len);
	pr2(ans < 0 ? ans + mod : ans);
	return 0;
}
原文地址:https://www.cnblogs.com/xgcxgc/p/15366031.html