CF623E Transforming Sequence 题解

传送门


【分析】

不难想到,令 \(f_{n, k}\) 表示前 \(n\) 个数,使得 \(b_n\)\(k\)\(1\) 的方案数

则很容易列出转移方程,由于 \(a_i>0\) ,故 \(\displaystyle f_{n, k}=\sum_{i=0}^{k-1} \binom k if_{n-1, i}\cdot 2^i\)

变形得到 \(\displaystyle {f_{n, k}\over k!}x^k=\sum_{i+j=k} {f_{n-1,i}\cdot 2^i\over i!}x^i\cdot {[j>0]\over j!}x^j\)

因此得到生成函数的转移关系 \(\displaystyle \hat F_{n}(x)=\hat F_{n-1}(2x)\cdot (e^x-1)\)

而由 \(f_{0, k}=[k=0]\) 得到 \(\hat F_0(x)=1\)

列举几项得到 \(\hat F_1(x)=e^x-1, \hat F_2(x)=(e^{2x}-1)(e^x-1), \hat F_3(x)=(e^{4x}-1)(e^{2x}-1)(e^x-1)\)

猜想 \(\displaystyle \hat F_n(x)=\prod_{i=0}^{n-1} (e^{2^ix}-1)\) ,若对 \(n\leq m\) 时成立

\(\displaystyle \hat F_{m+1}(x)=\hat F_m(x)(2x)\cdot (e^x-1)=\prod_{i=0}^{m-1}(e^{2^i\cdot 2x}-1)\cdot (e^x-1)=\prod_{i=0}^m (e^{2^ix}-1)\)

由数学归纳法证明猜想正确


设答案为 \(A_{n, k}\) ,则 \(\displaystyle A_{n, k}=\sum_{i=0}^k \binom k i f_{n, i}\Rightarrow \hat A_n(x)=\hat F_n(x)\cdot e^x\)

问题化为如何求解 \(\hat F_n(x)\) 不超时

考虑到 \(\displaystyle \hat F_{2n}(x)=\prod_{i=0}^{2n-1}(e^{2^ix}-1)=\prod_{i=0}^{n-1}(e^{2^ix}-1)\cdot \prod_{i=n}^{2n-1}(e^{2^ix}-1)=\prod_{i=0}^{n-1}(e^{2^ix}-1)\cdot \prod_{i=0}^{n-1}(e^{2^{i+n}x}-1)=\hat F_n(x)\cdot \hat F_n(2^n x)\)

\(\displaystyle F(x)=\sum_{i=0}^\infty f_ix^i\Rightarrow F(Ax)=\sum_{i=0}^\infty f_iA^i\cdot x^i\)

故已知 \(\hat F_n(x)\) ,可以先 \(O(n)\) 求解 \(\hat F(2^n x)\) ,再多项式卷积 \(O(n\log n)\) 求出 \(\hat F_{2n}(x)\)

因而求解 \(\hat F_n(x)\) 时,先递归求解 \(\hat F_{n/2}(x)\) ,若 \(n\) 为奇数,再利用递推式 \(O(n\log n)\) 算出结果

最后再卷上 \(e^x\) 即可得到答案的生成函数

总复杂度 \(\displaystyle T(n)=T({n\over 2})+O(n\log n)=O(n\log n)\) ,常数约为 \(20\) 次 FFT


【代码】

由于需要任意模数,考虑使用 MTT 。但这题貌似精度需求非常高,用 double 会算出问题,得 long double

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef long double db;
#define fi first
#define se second

struct vir{
    db r, i;
    vir(db r_=0, db i_=0):r(r_), i(i_) {}
};
inline vir operator + (const vir &a, const vir &b) { return vir(a.r+b.r, a.i+b.i); }
inline vir operator - (const vir &a, const vir &b) { return vir(a.r-b.r, a.i-b.i); }
inline vir operator * (const vir &a, const vir &b) { return vir(a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r); }
inline vir operator / (const vir &a, db b) { return vir(a.r/b, a.i/b); }
inline vir operator ! (const vir &z) { return vir(z.r, -z.i); }
inline vir operator / (const vir &a, const vir &b) { return a*!b/(b*!b); }
const db pi=acos(-1), eps=1e-6;

const int LimBit=18, M=1<<LimBit;
int a[M], b[M], c[M];
struct NTT{
    int Stk[M], curStk;
    int N, na, nb, P, m, N_;
    inline ll kpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%P) if(x&1) ans=ans*a%P; return ans; }
    NTT(int P_=1e9+7):P(P_){
        m=1<<15;
        curStk=0;
		N_=-1;
    }

    vir w[2][M], invN;
    int rev[M];
    void work(){
		if(N_==N) return ; N_=N;
        int d = __builtin_ctz(N);
        w[0][0] = w[1][0] = vir(1, 0);
        vir x = vir(cos(2*pi/N), sin(2*pi/N)), y = !x;
        for (int i = 1; i < N; i++) {
            rev[i] = (rev[i>>1] >> 1) | ((i&1) << (d-1));
            w[0][i] = x * w[0][i-1] , w[1][i] = y * w[1][i-1];
        }
        invN=vir(1.0/N, 0);
    }

    inline void fft(vir *a,int f){
        static vir x;
        for(int i=0;i<N;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
        for(int i=1;i<N;i<<=1)
            for(int j=0,t=N/(i<<1);j<N;j+=i<<1)
                for(int k=0,l=0;k<i;++k,l+=t)
                    x=w[f][l]*a[j+k+i], a[j+k+i]=a[j+k]-x, a[j+k]=a[j+k]+x;
        if(f) for(int i=0;i<N;++i) a[i]=a[i]*invN;
    }
    inline int mergeNum(ll a1,ll a0,ll b1,ll b0){
        // return ((a1*m+a0+b1)%p*m+b0+p)%p;
        return (( ((a1<<15)+a0+b1)%P<<15 )+b0)%P;
    }
    inline void doit(int *a,int *b,int na,int nb){
        static vir x, y, p0[M], p1[M];
        for(N=1;N<na+nb-1;N<<=1);
        for(int i=na;i<N;++i) a[i]=0;
        for(int i=nb;i<N;++i) b[i]=0;
        for(int i=0;i<N;++i)
            // p1[i]=pdd(a[i]/m,a[i]%m),p0[i]=pdd(b[i]/m,b[i]%m);
            p1[i]=vir(a[i]>>15, a[i]&(m-1)), p0[i]=vir(b[i]>>15, b[i]&(m-1));
        work(); fft(p1,0); fft(p0,0);
        for(int i=0,j;i+i<=N;++i){
            j=(N-1)&(N-i);
            x=p0[i], y=p0[j];
            p0[i]=(x-!y)*p1[i]*vir(0, -0.5);
            p1[i]=(x+!y)*p1[i]*vir(0.5, 0);
            if(i==j) continue;
            p0[j]=(y-!x)*p1[j]*vir(0, -0.5);
            p1[j]=(y+!x)*p1[j]*vir(0.5, 0);
        }
        fft(p1,1); fft(p0,1);
        for(int i=0;i<N;++i)
            a[i]=mergeNum(p1[i].r+0.5, p1[i].i+0.5, p0[i].r+0.5, p0[i].i+0.5);
    }
}ntt;

ll n;
int k, f[M];
inline void init() {
	b[1]=1;
	for(int i=2, P=ntt.P; i<=3e4; ++i)
		b[i]=P-(ll)b[P%i]*(P/i)%P;
	for(int i=2, P=ntt.P; i<=3e4; ++i)
		b[i]=(ll)b[i-1]*b[i]%P;
}
inline void work(ll n) {
	if(n==0) {
		for(int i=0; i<=k; ++i) f[i]=0;
		f[0]=1;
		return ;
	}
	work(n>>1);
	for(int i=0, x=ntt.kpow(2, n>>1), y=1, P=ntt.P; i<=k; ++i)
		a[i]=(ll)f[i]*y%P, y=(ll)y*x%P;
	ntt.doit(f, a, k+1, k+1);

	if(n+1&1) return ;
	for(int i=0, x=2, y=1, P=ntt.P; i<=k; ++i)
		f[i]=(ll)f[i]*y%P, y=(ll)y*x%P;
	ntt.doit(f, b, k+1, k+1);
}
int main() {
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	init();
	cin>>n>>k;
	work(n);
	++b[0];
	ntt.doit(f, b, k+1, k+1);
	for(int i=0, x=1, P=ntt.P; i<=k; ++i)
		f[i]=(ll)f[i]*x%P, x=(ll)x*(i+1)%P;
	cout<<f[k];
	cout.flush();
	return 0;
}
原文地址:https://www.cnblogs.com/JustinRochester/p/15546108.html