题解-CF755G PolandBall and Many Other Balls

题面

CF755G PolandBall and Many Other Balls

给定 (n)(m)。有一排 (n) 个球,求对于每个 (1le kle m),选出 (k) 个球或相邻的球不能重复的方案数。

数据范围:(1le nle 10^9)(1le m<2^{15})


路标

这题是老经典题了,前人的描述也足够充分了。

但是蒟蒻尝试了这题的 (3) 种做法并记在笔记中后还是忍不住去挣咕值分享给大家。

这是的三种做法对比图(从下到上依次是倍增((Theta(nlog^2 n)))、组合容斥((Theta(nlog n)))和特征方程((Theta(nlog n)))):

图片可能挂了

代码中的 (n)(m) 可能和题面中不是同个东西 /wul


组合容斥

考虑枚举两个球的组的数量,这应该是最容易想到的式子了:

[egin{aligned} f(k)&=sum_{i=0}^k {n-i choose k}{kchoose i}\ end{aligned} ]

可是这个东西很难推,于是重定义它的组合意义:

先从前 (k) 个选任意个,然后再从剩下的(也可以是前 (k) 个!)当中选 (k) 个的方案数。

抓住容斥切口:前后没有重复。

(p(i)) 为重复 (i) 个剩下不重复的方案数。

(q(i)) 为钦定重复 (i) 个剩下随意的方案数。

[q(i)=sum_{x=i}^k {xchoose i} p(x)Longleftrightarrow p(i)=sum_{x=i}^k (-1)^{x-i} {xchoose i} q(x) ]

还可以用“从前 (k) 个中先把两次重复的 (i) 个选了,然后对于前面剩下的每个随便选不选,后面的选 (k-i) 个”得出 (q(i))

[q(i)={kchoose i}{n-ichoose k-i}2^{k-i} ]

所以上式等于

[egin{aligned} f(k)=p(0)&=sum_{i=0}^k(-1)^i q(i)\ &=sum_{i=0}^k(-1)^i {kchoose i}{n-ichoose k-i}2^{k-i}\ &=sum_{i=0}^k(-1)^i frac{k!}{i!(k-i)!}cdot frac{(n-i)!}{(k-i)!(n-k)!}2^{k-i}\ &=frac{k!}{(n-k)!}sum_{i=0}^k frac{(-1)^i(n-i)!}{i!}cdot frac{2^{k-i}}{((k-i)!^2)}\ end{aligned} ]

(n) 太大了,不能直接卷,得化下降幂:

[egin{aligned} f(k)&=frac{k!}{(n-k)!}sum_{i=0}^k frac{(-1)^i(n-i)!}{i!}cdot frac{2^{k-i}}{((k-i)!^2)}\ &=frac{k!n!}{(n-k)!}sum_{i=0}^k frac{(-1)^i(n-i)!}{i!n!}cdot frac{2^{k-i}}{((k-i)!^2)}\ &=k!n^{underline k}sum_{i=0}^k frac{(-1)^i}{i!n^{underline i}}cdot frac{2^{k-i}}{((k-i)!^2)}\ end{aligned} ]


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1<<15,pN=1<<16,mN=N;
int n,m,a[pN],b[pN];

//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
    for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
    return res;
}
int fac[mN],ifac[mN],dfac[mN],idfac[mN],pw=1;

//Poly
const int G=3,iG=Pow(3,mod-2);
int pn,rev[pN];
#define cle(p) fill((p)+n,(p)+pn,0)
int up2(int n){return 1<<int(ceil(log2(n)));}
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
    R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
    for(int mid=1;mid<pn;mid<<=1)
        for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
            for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
                x=p[j],y=1ll*p[mid|j]*w%mod,fmod(p[j]+=y-mod),fmod(p[mid|j]=x-y);
    if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*in*p[i]%mod;}
}

//Main
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>m>>n,++n,pn=up2(n<<1),revn();
    fac[0]=dfac[0]=1;
    R(i,n-1) fac[i+1]=(1ll+i)*fac[i]%mod;
    R(i,min(n-1,m))dfac[i+1]=1ll*(m-i)*dfac[i]%mod;
    ifac[n-1]=Pow(fac[n-1],mod-2);
    idfac[min(n-1,m)]=Pow(dfac[min(n-1,m)],mod-2);
    L(i,n-1) ifac[i]=(1ll+i)*ifac[i+1]%mod;
    L(i,min(n-1,m)) idfac[i]=1ll*(m-i)*idfac[i+1]%mod;
    R(i,n) b[i]=1ll*ifac[i]*ifac[i]%mod*pw%mod,fmod(pw=(pw<<1)-mod),
        a[i]=1ll*ifac[i]*idfac[i]%mod,(i&1)&&(fmod(a[i]=-a[i]),true);
    ntt(a,false),ntt(b,false);
    R(i,pn) a[i]=1ll*a[i]*b[i]%mod; ntt(a,true),cle(a);
    R(i,n) a[i]=1ll*a[i]*fac[i]%mod*dfac[i]%mod;
    R(i,n)if(i) cout<<a[i]<<' '; cout<<'
'; 
    return 0;
}

倍增

(f_{n,k}) 表示前 (n) 个球分成 (k) 组的方案数,简单的 dp 式:

[f_{n,k}=f_{n-1,k-1}+f_{n-2,k-1}+f_{n-1,k} ]

(F_n(x)=sum_{k=0}^n f_{n,k}x^k),可通过上式推出:

[F_n(x)=(1+x)F_{n-1}(x)+xF_{n-2}(x) ]

这个东西在这里还没用,但是后面特征方程做法的时候会用到。

考虑将两排球合并,分类讨论中间的断口是否切开了一组相邻的球:

[F_{n+m}(x)=F_n(x)F_m(x)+xF_{n-1}(x)F_{m-1}(x) ]

然后就可以倍增卷积了,要维护 (F_{2^k},F_{2^k-1},F_{2^k-2})

[egin{cases}F_{2n}=F_{n}^2+xF_{n-1}^2\F_{2n-1}=F_{n}F_{n-1}+xF_{n-1}F_{n-2}\F_{2n-2}=F_{n-1}^2+xF_{n-2}^2\end{cases} ]

然后设当前的 (F_{m}) 要与倍增数组合并,其实只需要维护 (F_{m})(F_{m-1}) 即可。

[egin{cases}F_{2^k+m}=F_{2^k}F_{m}+xF_{2^k-1}F_{m-1}\F_{2^k+m-1}=F_{2^k-1}F_{m}+xF_{2^k-2}F_{m-1}\end{cases} ]


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1<<15,pN=1<<16,D=31;
int m,n,f[3][pN],ns[2][pN],tmp[3][pN];

//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
    for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
    return res;
}

//Poly
const int G=3,iG=Pow(G,mod-2);
int pn,rev[pN];
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
    R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
    for(int mid=1;mid<pn;mid<<=1)
        for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
            for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
                x=p[j],y=1ll*w*p[mid|j]%mod,fmod(p[j]=x+y-mod),fmod(p[mid|j]=x-y);
    if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*p[i]*in%mod;}
}

//Main
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>m>>n,++n,pn=1;
    while(pn<(n<<1)) pn<<=1; revn();
    R(d,D){
        if(d==0) f[0][0]=f[0][1]=f[1][0]=1,ns[0][0]=1;
        else {
            R(t,3)R(i,pn) tmp[t][i]=(i<n?f[t][i]:0);
            R(t,3) ntt(tmp[t],false);
            R(i,pn){
                f[0][i]=1ll*tmp[0][i]*tmp[0][i]%mod;
                f[1][i]=1ll*tmp[1][i]*tmp[0][i]%mod;
                f[2][i]=1ll*tmp[1][i]*tmp[1][i]%mod;
                tmp[0][i]=1ll*tmp[1][i]*tmp[1][i]%mod;
                tmp[1][i]=1ll*tmp[1][i]*tmp[2][i]%mod;
                tmp[2][i]=1ll*tmp[2][i]*tmp[2][i]%mod;
            }
            R(t,3) ntt(tmp[t],true),ntt(f[t],true);
            R(t,3)R(i,n-1) fmod(f[t][i+1]+=tmp[t][i]-mod);
            R(t,3) fill(f[t]+n,f[t]+pn,0);
        }
        if(m>>d&1^1) continue;
        // cout<<"@DEBUG d="<<d<<'
';
        // R(i,pn) cout<<f[0][i]<<' ';cout<<'
';
        R(t,3)R(i,pn) tmp[t][i]=(i<n?f[t][i]:0);
        R(t,3) ntt(tmp[t],false);
        R(t,2) ntt(ns[t],false);
        R(i,pn){
            int t[2]={ns[0][i],ns[1][i]};
            ns[0][i]=1ll*t[0]*tmp[0][i]%mod;
            ns[1][i]=1ll*t[0]*tmp[1][i]%mod;
            tmp[0][i]=1ll*tmp[1][i]*t[1]%mod;
            tmp[1][i]=1ll*tmp[2][i]*t[1]%mod;
        }
        R(t,2) ntt(ns[t],true),ntt(tmp[t],true);
        R(t,2)R(i,n-1) fmod(ns[t][i+1]+=tmp[t][i]-mod);
        R(t,2) fill(ns[t]+n,ns[t]+pn,0);
    }
    R(i,n)if(i) cout<<ns[0][i]<<' '; cout<<'
';
    return 0;
}

特征方程

参考倍增卷积做法的前半部分:

[F_n(x)-(1+x)F_{n-1}(x)-xF_{n-2}(x)=0 ]

这东西可以多项式特征方程:设有多项式 (a,b) 满足:

[F_n(x)-bF_{n-1}(x)=a(F_{n-1}(x)-bF_{n-2}(x))\ herefore F_n(x)-(a+b)F_{n-1}(x)+abF_{n-2}(x)=0\ ]

已知 ((a+b)=(x+1),ab=-x),所以 (a,b) 是特征方程的两根:

[A^2-(x+1)A-x=0\ A=frac{x+1pm sqrt{(x+1)^2+4x}}{2}\ A=frac{x+1pm sqrt{x^2+6x+1}}{2}\ ]

(a=frac{x+1+sqrt{x^2+6x+1}}{2})(b=frac{x+1-sqrt{x^2+6x+1}}{2})

[F_n(x)-bF_{n-1}(x)=a(F_{n-1}(x)-bF_{n-2}(x))\ F_n(x)-bF_{n-1}(x)=a^{n-1}(F_1(x)-bF_0(x))\ ]

由于 (Delta>0),那个根号肯定 (>0)(a eq b)

[egin{cases} F_n(x)-bF_{n-1}(x)=a^{n-1}(F_1(x)-bF_0(x))\ F_n(x)-aF_{n-1}(x)=b^{n-1}(F_1(x)-aF_0(x))\ end{cases}\ a(F_n(x)-bF_{n-1}(x))-b(F_n(x)-aF_{n-1}(x))\ =a^{n}(F_1(x)-bF_0(x))-b^{n}(F_1(x)-aF_0(x))\ (a-b)F_n(x)=a^{n}(F_1(x)-bF_0(x))-b^{n}(F_1(x)-aF_0(x))\ ]

加点文字防止看得眼花 /qq

[egin{aligned} F_n&=frac{1}{a-b}(a^{n}(F_1(x)-bF_0(x))-b^{n}(F_1(x)-aF_0(x)))\ &=frac{1}{a-b}(a^n(x+1-b)-b^n(x+1-a))\ &=frac{a^{n+1}-b^{n+1}}{a-b}\ &=frac{left(frac{x+1+sqrt{x^2+6x+1}}{2} ight)^{n+1}-left(frac{x+1-sqrt{x^2+6x+1}}{2} ight)^{n+1}}{sqrt{x^2+6x+1}}^{color{orange}{(1)}}\ &equiv frac{left(frac{x+1+sqrt{x^2+6x+1}}{2} ight)^{n+1}}{sqrt{x^2+6x+1}}pmod{x^{n+1}}\ end{aligned} ]

然后带上多项式全家桶就可以做了。


当时推到 (color{orange}{(1)}) 那里的时候其实已经可以做了,然后蒟蒻点开题解发现和鰰的不一样,蒟蒻不是很懂为什么

[left(frac{x+1-sqrt{x^2+6x+1}}{2} ight)^{n+1}equiv 0pmod {x^{n+1}} ]

然后 qq 群中的 Elegia 队长还有鰰现场回复:因为这个式子常数项为 (0)

学数学学傻了,那个多项式 (x^2+6x+1) 开根后常数项肯定是 (1),其实是可以和前面的 (1) 抵消的。


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1<<15,pN=1<<16,mN=pN;
int n,m,pm,a[pN],b[pN],c[pN];

//Math
const int mod=998244353;
void fmod(int &x){x+=x>>31&mod;}
int Pow(int a,int x,int res=1){
    for(;x;x>>=1,a=1ll*a*a%mod) (x&1)&&(res=1ll*res*a%mod);
    return res;
}
int miv[mN];
void math_init(){
    miv[1]=1;
    R(i,mN)if(i>=2) fmod(miv[i]=-1ll*mod/i*miv[mod%i]%mod);
}

//Poly
const int G=3,iG=Pow(G,mod-2);
int pn,rev[pN],red[pN],blue[pN],yel[pN];
#define cle(p) fill((p)+n,(p)+pn,0)
int up2(int n){return 1<<int(ceil(log2(n)));}
void revn(){R(i,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));}
void ntt(int* p,bool t){
    R(i,pn)if(i<rev[i]) swap(p[i],p[rev[i]]);
    for(int mid=1;mid<pn;mid<<=1)
        for(int wn=Pow(t?iG:G,(mod-1)/(mid<<1)),i=0;i<pn;i+=mid<<1)
            for(int j=i,w=1,x,y;j<(mid|i);j++,w=1ll*w*wn%mod)
                x=p[j],y=1ll*w*p[mid|j]%mod,fmod(p[j]=x+y-mod),fmod(p[mid|j]=x-y);
    if(t){int in=Pow(pn,mod-2); R(i,pn) p[i]=1ll*p[i]*in%mod;}
}
void deint(int* p,bool t){
    if(t){L(i,pn)if(i) p[i]=1ll*p[i-1]*miv[i]%mod; p[0]=0;}
    else {R(i,pn-1) p[i]=1ll*p[i+1]*(i+1)%mod; p[pn-1]=0;}
}
void mul(int* p,int* q,int n){
    pn=n<<1,revn(),cle(p),cle(q),ntt(p,false),ntt(q,false);
    R(i,pn) p[i]=1ll*p[i]*q[i]%mod; ntt(p,true),cle(q);
}
void inv(int* p,int* q,int n){
    if(n==1) return void(q[0]=Pow(p[0],mod-2));
    inv(p,q,n>>1),pn=n<<1,revn(),copy(p,p+n,red);cle(red),cle(q),ntt(q,false),ntt(red,false);
    R(i,pn) fmod(q[i]=(-1ll*red[i]*q[i]%mod+2)*q[i]%mod); ntt(q,true),cle(q);
}
void ln(int* p,int* q,int n){
    inv(p,q,n),pn=n<<1,revn(),copy(p,p+n,red),cle(red),cle(q);
    deint(red,false),ntt(q,false),ntt(red,false);
    R(i,pn) q[i]=1ll*q[i]*red[i]%mod; ntt(q,true),deint(q,true),cle(q);
}
void exp(int* p,int* q,int n){
    if(n==1) return void(q[0]=1);
    exp(p,q,n>>1),ln(q,blue,n),pn=n<<1,revn(),fmod(blue[0]-=1);
    R(i,n) fmod(blue[i]=p[i]-blue[i]); ntt(q,false),ntt(blue,false);
    R(i,pn) q[i]=1ll*q[i]*blue[i]%mod; ntt(q,true),cle(q),cle(blue);
}
void sqrt(int* p,int* q,int n){
    if(n==1) return void(q[0]=1);
    sqrt(p,q,n>>1),inv(q,blue,n),pn=n<<1,revn(),copy(p,p+n,yel);
    cle(yel),cle(blue),ntt(blue,false),ntt(yel,false);
    R(i,pn) blue[i]=1ll*blue[i]*yel[i]%mod; ntt(blue,true),cle(blue);
    R(i,n) fmod(q[i]+=blue[i]-mod),q[i]=1ll*(mod+1)/2*q[i]%mod; cle(q);
}
void pow(int* p,int* q,int x,int n){
    ln(p,yel,n); R(i,n) yel[i]=1ll*yel[i]*x%mod; exp(yel,q,n);
}

//Main
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>m>>n,++n,math_init(),pm=up2(n);
    a[0]=1,a[1]=6,a[2]=1,sqrt(a,b,pm),inv(b,a,pm);
    fmod(b[0]+=1-mod),fmod(b[1]+=1-mod);
    R(i,pm) b[i]=499122177ll*b[i]%mod;
    pow(b,c,(m+1)%mod,pm),mul(a,c,pm);
    R(i,n)if(i) cout<<(i<=m?a[i]:0)<<' '; cout<<'
';
    return 0;
}

祝大家学习愉快!

原文地址:https://www.cnblogs.com/George1123/p/14073540.html