FFT&&NTT&&相关

FFT 快速计算多项式乘法 

bzoj3527 力

题目大意:给定qi,求ei=sigma(j<i)qj/(i-j)^2-sigma(j>i)qj/(i-j)^2。

思路:画个表格能发现两个三角都是可以卷积的,要求qj*1/(i-j)^2累加到ei上,但是右上角的部分要倒两次,然后就是fft了。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define LD double
#define N 1000005
using namespace std;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct use{
    LD r,i;
    // Overwrites both components in place.
    void init(LD rr,LD ii){r=rr;i=ii;}
    // The operators are const-qualified so they also accept const operands, and
    // the results are built with standard aggregate initialization instead of
    // the GNU-only compound literal `(use){...}`.
    use operator+(const use&x)const{return use{r+x.r,i+x.i};}
    use operator-(const use&x)const{return use{r-x.r,i-x.i};}
    use operator*(const use&x)const{return use{r*x.r-i*x.i,r*x.i+x.r*i};}
}a[N],b[N],ai[N],c[N]; // a,b: FFT operands; ai: scratch for the bit-reversal permutation; c: product
// Input values q_i as read by main, and the per-index results it prints.
LD qi[N];
LD ans[N];
// FFT length (a power of two) and its base-2 logarithm; both are set in main.
int up;
int l;
// rev[i] is the bit-reversed index of i; ci is scratch space for the binary
// digits of the index currently being reversed.
int rev[N]={0};
int ci[N]={0};
// Returns x squared.
LD sqr(LD x){
    LD squared=x*x;
    return squared;
}
// In-place iterative radix-2 FFT over the first `up` entries of a.
// f = 1 selects the forward transform and f = -1 the inverse. The inverse
// divides only the real parts by `up`; imaginary parts are left unscaled.
// Relies on the globals `up`, `rev` (bit-reversal table) and `ai` (scratch).
void fft(use *a,int f){
    // Apply the bit-reversal permutation through the scratch array.
    for (int pos=0;pos<up;++pos) ai[pos]=a[rev[pos]];
    for (int pos=0;pos<up;++pos) a[pos]=ai[pos];
    for (int span=2;span<=up;span<<=1){
        int half=span/2;
        use step;
        step.init(cos(2*M_PI/span),f*sin(2*M_PI/span));
        for (int base=0;base<up;base+=span){
            use cur;
            cur.init(1.,0.);
            for (int k=base;k<base+half;++k){
                use lo=a[k];
                use hi=a[k+half]*cur;
                a[k]=lo+hi;
                a[k+half]=lo-hi;
                cur=cur*step; // advance the twiddle factor
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos].r/=up*1.;
}
// Reads n and q_0..q_{n-1}, then prints, one value per line,
//   e_i = sum_{j<i} q_j/(i-j)^2 - sum_{j>i} q_j/(i-j)^2
// using two FFT convolutions with the kernel 1/d^2.
int main(){
    int i,j,n;scanf("%d",&n);
    for (i=0;i<n;++i) scanf("%lf",&qi[i]);
    // up ends as twice the smallest power of two >= n, so the product of two
    // length-n sequences does not wrap around; l ends as log2(up).
    for (l=0,up=1;up<n;up<<=1,++l);up<<=1;++l;
    // Bit-reversal table: store the binary digits of i (lowest first) in
    // ci[1..], then fold the first l of them into rev[i].
    for (i=0;i<up;++i){
        int ll=0;
        for (j=i;j;j>>=1) ci[++ll]=j&1;
        for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|ci[j];
    }
    // First convolution: contributions of the indices j < i.
    for (i=0;i<n;++i) a[i].init(qi[i],0.);
    for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.);
    fft(a,1);fft(b,1);
    for (i=0;i<up;++i) c[i]=a[i]*b[i];
    fft(c,-1);for (i=0;i<n;++i) ans[i]=c[i].r;
    // Second convolution on the reversed input: contributions of j > i, which
    // are read back in reversed order and subtracted.
    memset(a,0,sizeof(a));memset(b,0,sizeof(b));
    for (i=0;i<n;++i) a[i].init(qi[n-1-i],0.);
    for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.);
    fft(a,1);fft(b,1);
    for (i=0;i<up;++i) c[i]=a[i]*b[i];
    fft(c,-1);for (i=0;i<n;++i) ans[i]-=c[n-1-i].r;
    // The format string needs the "\n" escape; a raw line break inside a
    // string literal does not compile.
    for (i=0;i<n;++i) printf("%.9f\n",ans[i]);
}
View Code

codechef COUNTARI

题目大意:给定n个数,求数列中i<j<k且ai、aj、ak呈等差数列的个数。

思路:分块+fft。三个在一个块内的可以len^2,两个在块内一个在外面的也可以len^2,中间点在块内其他在两边的可以fft。

注意:double强转longlong的时候是下取整,所以应该+0.5。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 100005
#define up 30005
#define LL long long
#define LD double
using namespace std;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct use{
    LD r,i;
    // Overwrites both components in place.
    void init(LD rr,LD ii){r=rr;i=ii;}
    // Const-qualified operators built with standard aggregate initialization
    // (the original `(use){...}` compound literal is a GNU extension).
    use operator+(const use&x)const{return use{r+x.r,i+x.i};}
    use operator-(const use&x)const{return use{r-x.r,i-x.i};}
    use operator*(const use&x)const{return use{r*x.r-i*x.i,r*x.i+x.r*i};}
}a[N],b[N],c[N],A[N]; // a,b: FFT operands; c: product; A: scratch for the bit-reversal permutation
// The input sequence, stored 1-based by main.
int ai[N];
// Bit-reversal table for the FFT.
int rev[N];
// Scratch bits while main builds rev; afterwards en[b] holds the last index of block b.
int en[N]={0};
// FFT length and its base-2 logarithm.
int uu;
int l;
// Value histograms: c1 counts values in blocks already processed, c2 counts
// values in blocks not yet processed, cnt is a temporary histogram used
// inside the current block.
LL c1[N]={0LL};
LL c2[N]={0LL};
LL cnt[up]={0LL};
// In-place iterative radix-2 FFT over the first `uu` entries of a.
// f = 1 is the forward transform, f = -1 the inverse (only real parts are
// divided by uu). Uses the globals `uu`, `rev` and the scratch array `A`.
void fft(use *a,int f){
    for (int pos=0;pos<uu;++pos) A[pos]=a[rev[pos]];
    for (int pos=0;pos<uu;++pos) a[pos]=A[pos];
    int span=2;
    while (span<=uu){
        int half=span/2;
        use step;
        step.init(cos(2*M_PI/span),f*sin(2*M_PI/span));
        for (int base=0;base<uu;base+=span){
            use cur;
            cur.init(1.,0.);
            for (int k=base;k<base+half;++k){
                use lo=a[k];
                use hi=cur*a[k+half];
                a[k]=lo+hi;
                a[k+half]=lo-hi;
                cur=cur*step;
            }
        }
        span<<=1;
    }
    if (f==-1){
        for (int pos=0;pos<uu;++pos) a[pos].r/=1.*uu;
    }
}
// Convolves the histograms c1 and c2 with the FFT and, for every position in
// block x, adds the rounded coefficient at index 2*ai[pos] to the result.
LL calc(int x){
    for (int v=0;v<uu;++v) a[v].init(c1[v],0.);
    for (int v=0;v<uu;++v) b[v].init(c2[v],0.);
    fft(a,1);
    fft(b,1);
    for (int v=0;v<uu;++v) c[v]=a[v]*b[v];
    fft(c,-1);
    LL total=0LL;
    int pos=en[x-1]+1;
    while (pos<=en[x]){
        // +0.5 rounds to nearest before the truncating cast.
        total+=(LL)(c[2*ai[pos]].r+0.5);
        ++pos;
    }
    return total;
}
// Reads n and the sequence, then counts index triples i<j<k whose values form
// an arithmetic progression, using blocks of 2000 elements: triples with at
// least two members in one block are counted directly, and triples whose
// middle element is in the block with both others outside go through calc().
int main(){
    int n,i,j,k,ci,len,bl;LL ans=0LL;
    scanf("%d",&n);len=2000;bl=(n-1)/len+1;
    // uu ends as twice the smallest power of two >= the value bound `up`.
    for (uu=1,l=0;uu<up;uu<<=1,++l);uu<<=1;++l;
    // Bit-reversal table; en is borrowed as digit scratch space here.
    for (i=0;i<uu;++i){
        for(ci=0,j=i;j;j>>=1) en[++ci]=j&1;
        for(j=1;j<=l;++j) rev[i]=(rev[i]<<1)|en[j];
    }
    // Read the values; en[b] ends as the last index of block b and c2 as the
    // histogram of all values.
    for (i=1;i<=n;++i){
        en[(i-1)/len+1]=i;
        scanf("%d",&ai[i]);
        ++c2[ai[i]];
    }
    for (i=1;i<=bl;++i){
        // Remove the current block from the "right side" histogram.
        for (j=en[i-1]+1;j<=en[i];++j) --c2[ai[j]];
        for (j=en[i-1]+1;j<=en[i];++j){
            for (k=en[i];k>j;--k){
                // j first, k middle: the third value is later in the block (cnt)
                // or in a later block (c2).
                ci=ai[k]*2-ai[j];
                if (ci>0&&ci<up) ans+=cnt[ci]+c2[ci];
                ++cnt[ai[k]];
                // j middle, k last: the first value is in an earlier block (c1).
                ci=ai[j]*2-ai[k];
                if (ci>0&&ci<up) ans+=c1[ci];
            }
            // Undo the temporary in-block histogram.
            for (k=j+1;k<=en[i];++k) --cnt[ai[k]];
        }
        ans+=calc(i);
        for (j=en[i-1]+1;j<=en[i];++j) ++c1[ai[j]];
    }
    // %lld is the standard conversion for long long; the literal also needs
    // the "\n" escape rather than a raw line break.
    printf("%lld\n",ans);
}
View Code

codechef PRIMEDST

题目大意:求树上距离为质数的点对的概率。

思路:点分+fft。求距离为k的点对的时候用点分,现在这个k是所有质数,所以可以fft一下。注意有些数组不能清零防止tle;fft的上界可以根据每次的大小进行更改。(太久没写点分结果点分都写残了)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 150000
#define M 50005
#define LL long long
#define LD double
using namespace std;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct use{
    LD r,i;
    // Overwrites both components in place.
    void init(LD rr,LD ii){r=rr;i=ii;}
    // Const-qualified operators built with standard aggregate initialization
    // (the original `(use){...}` compound literal is a GNU extension).
    use operator+(const use&x)const{return use{r+x.r,i+x.i};}
    use operator-(const use&x)const{return use{r-x.r,i-x.i};}
    use operator*(const use&x)const{return use{r*x.r-i*x.i,r*x.i+i*x.r};}
}a[N],b[N],c[N],ai[N]; // a,b: FFT operands; c: product; ai: scratch for the bit-reversal permutation
// Adjacency lists: point[u] is the first edge id of u, next[e] the following
// edge id in the same list, en[e] the endpoint of edge e; tot is the edge count.
int point[N]={0};
int next[N]={0};
int en[N]={0};
// Centroid search state: mn is the best (smallest) largest-part size so far, rt the vertex achieving it.
int mn;
// Largest depth seen by dfs during the current calc call.
int mx;
int rt;
int tot=0;
// Subtree sizes.
int siz[N];
// Bit-reversal table and the digit scratch array used to build it.
int rev[N];
int di[100];
// prime[0] is the number of primes found; prime[1..prime[0]] are the primes.
int prime[N]={0};
// ci[d] = number of vertices at depth d in the current calc call.
int ci[M]={0};
// FFT length and its base-2 logarithm.
int up;
int l;
int ccc=0;
// vi marks vertices already processed by work; flag marks composites in the sieve.
bool vi[N]={false};
bool flag[N]={false};
// Number of ordered vertex pairs at prime distance.
LL ans=0LL;
void add(int u,int v){
    next[++tot]=point[u];point[u]=tot;en[tot]=v;
    next[++tot]=point[v];point[v]=tot;en[tot]=u;}
// Linear sieve: stores every prime below n in prime[1..prime[0]] and marks
// composites below n in flag[]. n must not exceed the size of flag/prime.
void shai(int n){
    int i,j;
    // The bound is i<n (not i<=n): the caller passes the array size, so index n
    // would read past the end of flag[], and composites are only marked below
    // n anyway, so n itself could be misreported as prime.
    for (i=2;i<n;++i){
        if (!flag[i]) prime[++prime[0]]=i;
        for (j=1;j<=prime[0]&&i*prime[j]<n;++j){
            flag[i*prime[j]]=true;
            // Stop at the smallest prime factor of i so each composite is
            // marked exactly once.
            if (i%prime[j]==0) break;
        }
    }
}
// In-place iterative radix-2 FFT over the first `up` entries of a.
// f = 1 is the forward transform, f = -1 the inverse (only real parts are
// divided by up). Uses the globals `up`, `rev` and the scratch array `ai`.
void fft(use *a,int f){
    for (int pos=0;pos<up;++pos) ai[pos]=a[rev[pos]];
    for (int pos=0;pos<up;++pos) a[pos]=ai[pos];
    for (int span=2;span<=up;span<<=1){
        int half=span/2;
        use step;
        step.init(cos(2*M_PI/span),f*sin(2*M_PI/span));
        int base=0;
        while (base<up){
            use cur;
            cur.init(1.,0.);
            for (int k=base;k<base+half;++k){
                use lo=a[k];
                use hi=cur*a[k+half];
                a[k]=lo+hi;
                a[k+half]=lo-hi;
                cur=cur*step;
            }
            base+=span;
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos].r/=up*1.;
}
void grt(int u,int f,int nn){
    int i,v,ms=0;siz[u]=1;
    for (i=point[u];i;i=next[i]){
        if (vi[v=en[i]]||v==f) continue;
        grt(v,u,nn);ms=max(ms,siz[v]);
        siz[u]+=siz[v];
    }ms=max(ms,nn-siz[u]);
    if (ms<=mn){mn=ms;rt=u;}
}
void dfs(int u,int f,int de){
    int i,v;siz[u]=1;
    ++ci[de];mx=max(mx,de);
    for (i=point[u];i;i=next[i]){
        if (vi[v=en[i]]||v==f) continue;
        dfs(v,u,de+1);siz[u]+=siz[v];
    }
}
// Builds the depth histogram of the component rooted at u (root depth de),
// squares it with the FFT and returns the sum of the rounded coefficients at
// prime indices, i.e. the number of ordered pairs whose depth sum is prime.
LL calc(int u,int de){
    // Clear only the histogram range used by the previous call.
    for (int d=0;d<=mx;++d) ci[d]=0;
    memset(di,0,sizeof(di));
    mx=0;
    dfs(u,0,de);
    mx+=1;
    // Size the FFT for this component only.
    up=1;l=0;
    while (up<mx){up<<=1;++l;}
    up<<=1;++l;
    for (int idx=0;idx<up;++idx){
        int digits=0;
        for (int rest=idx;rest;rest>>=1) di[++digits]=rest&1;
        int mirrored=0;
        for (int bit=1;bit<=l;++bit) mirrored=(mirrored<<1)|di[bit];
        rev[idx]=mirrored;
    }
    for (int idx=0;idx<up;++idx){
        int amount=0;
        if (idx<M) amount=ci[idx];
        a[idx].init(amount*1.,0.);
        b[idx].init(amount*1.,0.);
    }
    fft(a,1);
    fft(b,1);
    for (int idx=0;idx<up;++idx) c[idx]=a[idx]*b[idx];
    fft(c,-1);
    LL total=0LL;
    for (int k=1;k<=prime[0]&&prime[k]<up;++k) total+=(LL)(c[prime[k]].r+0.5);
    return total;
}
void work(int u){
    int i,v;vi[u]=true;ans+=calc(u,0);
    for (i=point[u];i;i=next[i]){
      if (vi[v=en[i]]) continue;
      ans-=calc(v,1);
      grt(v,u,mn=siz[v]);work(rt);
    }
}
int main(){
    int n,i,u,v;LL cc;scanf("%d",&n);
    for (i=1;i<n;++i){scanf("%d%d",&u,&v);add(u,v);}
    grt(1,0,mn=n);shai(N);cc=(LL)n*((LL)n-1LL);
    work(rt);printf("%.9f
",(LD)ans*1./(LD)cc);
}
View Code

bzoj3513 idiots

题目大意:给定n个木棍,问能构成三角形的概率。(木棍长度<=2*10^5)

思路:较短的两根的和<=第三根就是不符合的,木棍长度比较小,可以用fft,计算两个的和为x的木棍对数,对于长度为y的,x<=y的对数都是不满足的,但长度为x的对数中除了同一木棍选两次的统计了一次,其他的都统计了两次,所以要相应的减去。最后用(总的-不合法的)/总的就是答案了。

注意:(1)fft清数组的时候,求rev的时候利用的保存二进制的数组也要清零;

   (2)统计答案的时候要注意减掉那些不合法的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 600005
#define LD double
#define LL long long
using namespace std;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct use{
    LD r,i;
    // Overwrites both components in place.
    void init(LD rr,LD ii){r=rr;i=ii;}
    // Const-qualified operators built with standard aggregate initialization
    // (the original `(use){...}` compound literal is a GNU extension).
    use operator +(const use&x)const{return use{r+x.r,i+x.i};}
    use operator -(const use&x)const{return use{r-x.r,i-x.i};}
    use operator *(const use&x)const{return use{r*x.r-i*x.i,r*x.i+i*x.r};}
}a[N],c[N],ai[N]; // a: FFT operand; c: its square; ai: scratch for the bit-reversal permutation
// Bit-reversal table and FFT length.
int rev[N];
int up;
// sm[x] starts as the number of sticks of length x and is turned into a prefix sum by main.
int sm[N];
// Digit scratch space for building rev; cc[0] is used as the digit counter.
int cc[N];
// Returns n*(n-1)*(n-2)/6, the number of ways to choose 3 items out of n.
LL getc(LL n){
    LL product=n;
    product*=n-1LL;
    product*=n-2LL;
    return product/6LL;
}
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if the input ends before a digit
// is found.
int in(){
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so a truncated input cannot spin forever in the skip loop.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
    int x=0;
    while (ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x;
}
// In-place iterative radix-2 FFT over the first `up` entries of aa.
// f > 0 is the forward transform, f < 0 the inverse (only real parts are
// divided by up). Uses the globals `up`, `rev` and the scratch array `ai`.
void fft(use *aa,int f){
    for (int pos=0;pos<up;++pos) ai[pos]=aa[rev[pos]];
    for (int pos=0;pos<up;++pos) aa[pos]=ai[pos];
    for (int span=2;span<=up;span<<=1){
        int half=span/2;
        use step;
        step.init(cos(2*M_PI/span),f*sin(2*M_PI/span));
        for (int base=0;base<up;base+=span){
            use cur;
            cur.init(1.,0.);
            int k=base;
            while (k<base+half){
                use lo=aa[k];
                use hi=aa[k+half]*cur;
                aa[k]=lo+hi;
                aa[k+half]=lo-hi;
                cur=cur*step;
                ++k;
            }
        }
    }
    if (f<0){
        for (int pos=0;pos<up;++pos) aa[pos].r/=1.*up;
    }
}
int main(){
    int n,i,j,x,mx=0,l=0,t;LL ci,ans;
    t=in();
    while(t--){
        n=in();mx=0;ans=0LL;
        memset(sm,0,sizeof(sm));
        for (i=1;i<=n;++i){
            x=in();++sm[x];
            mx=max(mx,x);
        }++mx;
        for (l=0,up=1;up<mx;up<<=1,++l);up<<=1;++l;
        for (i=0;i<=l;++i) cc[i]=0;
        for (i=0;i<up;++i){
            for (j=i,cc[0]=0;j;j>>=1) cc[++cc[0]]=j&1;
            for (rev[i]=0,j=1;j<=l;++j) rev[i]=(rev[i]<<1)|cc[j];
        }for (i=0;i<mx;++i){
            a[i].init((LD)sm[i],0.);
            if (i) sm[i]+=sm[i-1];
        }for (;i<up;++i){
            sm[i]+=sm[i-1];
            a[i].init(0.,0.);
        }fft(a,1);
        for (i=0;i<up;++i) c[i]=a[i]*a[i];
        fft(c,-1);
        for (ci=0LL,i=0;i<up;++i){
            ci+=(LL)(c[i].r+0.5);
            ans+=(ci-(LL)sm[i/2])/2LL*(LL)(sm[i]-sm[i-1]);
        }printf("%.7f
",1.-(LD)ans*1./(LD)getc((LL)n));
    }
}
View Code

bzoj4503 两个串(!!!)

题目大意:给定s1、s2,s2中有?可以匹配任何小写字母,问s2在s1中出现几次、出现的位置。

思路:考虑一种hash方法:如果没有?,(s2-s1)^2=0的段是s2=s1的段,有了?,可以把?看作0,其他字母是1~26,s2*(s2-s1)^2=0的是匹配段,n比较大,把s2倒过来,用fft计算,在合法区间内取出值为0的就是这一段的结尾了。

注意:点值表达式是可以乘和加的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define LD double
#define N 2000005
using namespace std;
// Minimal complex number used by the FFT: u is the real part, i the imaginary part.
struct use{
    LD u,i;
    // Overwrites both components in place.
    void init(LD x,LD y){u=x;i=y;}
    // Results are built with standard aggregate initialization; the original
    // `(use){...}` compound literal is a GNU extension.
    use operator+(const use&x)const{return use{u+x.u,i+x.i};}
    use operator-(const use&x)const{return use{u-x.u,i-x.i};}
    use operator*(const use&x)const{return use{u*x.u-i*x.i,u*x.i+i*x.u};}
}a[N],b[N],c[N],aa[N]; // a,b: FFT operands; c: accumulated result; aa: scratch for the bit-reversal permutation
// s1 is the text and s2 the pattern (reversed in place by main).
char s1[N];
char s2[N];
// Lengths of s1 and s2.
int l1;
int l2;
// FFT length and its base-2 logarithm.
int up;
int l;
// Bit-reversal table and the digit scratch array used to build it.
int rev[N]={0};
int ai[N]={0};
// Maps a pattern/text character to its numeric code: '?' becomes 0, any other
// character c becomes c-'a'+1 (so 'a'..'z' become 1..26).
int idx(char c){
    if (c=='?') return 0;
    return c-'a'+1;
}
// Returns x squared.
int sqr(int x){
    const int squared=x*x;
    return squared;
}
// In-place iterative radix-2 FFT over the first `up` entries of a.
// f = 1 is the forward transform, f = -1 the inverse (only the real parts `u`
// are divided by up). Uses the globals `up`, `rev` and the scratch array `aa`.
void fft(use *a,int f){
    for (int pos=0;pos<up;++pos) aa[pos]=a[rev[pos]];
    for (int pos=0;pos<up;++pos) a[pos]=aa[pos];
    for (int span=2;span<=up;span<<=1){
        const int half=span/2;
        use step;
        step.init(cos(2*M_PI/span),f*sin(2*M_PI/span));
        for (int base=0;base<up;base+=span){
            use cur;
            cur.init(1.,0.);
            for (int k=base;k<base+half;++k){
                const use lo=a[k];
                const use hi=cur*a[k+half];
                a[k]=lo+hi;
                a[k+half]=lo-hi;
                cur=cur*step;
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos].u/=1.*up;
}
// Reads text s1 and pattern s2 ('?' matches any letter) and prints the number
// of matches followed by every 0-based start position. With codes from idx(),
// an alignment matches exactly when sum s2*(s1-s2)^2 == 0; the expanded sum is
// evaluated for all alignments at once with FFTs on the reversed pattern.
int main(){
    int i,j,k,ans=0;LD sm=0.;
    scanf("%s%s",s1,s2);
    l1=strlen(s1);
    l2=strlen(s2);
    // Reverse the pattern so the correlation becomes a convolution.
    for (i=0;(i<<1)<l2;++i) swap(s2[i],s2[l2-1-i]);
    for (up=1,l=0;up<l1;up<<=1,++l);up<<=1;++l;
    for (i=0;i<up;++i){
        for (k=0,j=i;j;j>>=1) ai[++k]=j&1;
        for (j=1;j<=l;++j) rev[i]=rev[i]<<1|ai[j];
    }
    // Term 1: s1^2 convolved with s2. sm accumulates the constant sum of s2^3.
    memset(a,0,sizeof(a));
    for (i=0;i<l1;++i) a[i].init(sqr(idx(s1[i])),0.);
    memset(b,0,sizeof(b));
    for (i=0;i<l2;++i){
        b[i].init(idx(s2[i]),0.);
        sm+=(LD)sqr(idx(s2[i]))*(LD)idx(s2[i]);
    }fft(a,1);fft(b,1);
    for (i=0;i<up;++i) c[i]=a[i]*b[i];
    // Term 2: -2 * (s1 convolved with s2^2), combined in the frequency domain.
    memset(a,0,sizeof(a));
    for (i=0;i<l1;++i) a[i].init(idx(s1[i]),0.);
    memset(b,0,sizeof(b));
    for (i=0;i<l2;++i) b[i].init(sqr(idx(s2[i])),0.);
    fft(a,1);fft(b,1);
    for (i=0;i<up;++i) c[i]=c[i]-(a[i]*b[i])-(a[i]*b[i]);
    fft(c,-1);
    // Index i is the end of an alignment; a rounded total of 0 is a match.
    for (i=l2-1;i<l1;++i)
        if ((int)(c[i].u+sm+0.5)==0) ++ans;
    // Both format strings need the "\n" escape, not a raw line break.
    printf("%d\n",ans);
    for (i=l2-1;i<l1;++i)
        if ((int)(c[i].u+sm+0.5)==0) printf("%d\n",i-l2+1);
}
View Code

bzoj3160万径人踪灭

题目大意:给出一个只有ab的串,求满足:1)位置和字符都关于某个轴回文;2)中间存在空位的子串的个数。

思路:考虑对于每个轴求出所有能回文的位置的个数:对a和b分别考虑能关于这个轴对称的元素个数,用fft求出来。设有x个这种位置,就有2^((x+1)/2)-1种非空的选法(因为前后会各统计一遍,而对称轴本身是a/b的时候,中间的那个只会统计一遍)。这里面多统计了中间不存在空位的情况,这些可以用manacher统计出来减去。

注意:平方的话,只有一个数组的项是要单独用前缀和更新的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 400005
#define LD double
#define LL long long
#define p 1000000007LL
using namespace std;
// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct use{
    LD x,y;
    // Overwrites both components in place.
    void init(LD xx,LD yy){x=xx;y=yy;}
    // Results are built with standard aggregate initialization; the original
    // `(use){...}` compound literal is a GNU extension.
    use operator+(const use&a)const{return use{x+a.x,y+a.y};}
    use operator-(const use&a)const{return use{x-a.x,y-a.y};}
    use operator*(const use&a)const{return use{x*a.x-y*a.y,x*a.y+y*a.x};}
}ai[N],bi[N],ci[N],aa[N],xi[N],yi[N]; // ai,bi: FFT operands; ci: accumulated result; aa: permutation scratch
// Bit-reversal table, FFT length and its base-2 logarithm.
int rev[N]={0};
int up;
int len;
// Digit scratch space for building rev; cc[0] is used as the digit counter.
int cc[N]={0};
// Length of the interleaved string s2, and the palindrome radii computed on it.
int nn=0;
int pp[N]={0};
// ss is the input string; s2 is ss interleaved with separator characters.
char ss[N];
char s2[N];
// Running answer.
LL ans=0LL;
// Returns x squared, computed in int and then converted to LD.
LD sqr(int x){
    const int squared=x*x;
    return (LD)squared;
}
// In-place iterative radix-2 FFT over the first `up` entries of a.
// f = 1 is the forward transform, f = -1 the inverse (only the real parts `x`
// are divided by up). Uses the globals `up`, `rev` and the scratch array `aa`.
void fft(use *a,int f){
    // Scatter through the scratch copy: element pos moves to slot rev[pos].
    for (int pos=0;pos<up;++pos) aa[pos]=a[pos];
    for (int pos=0;pos<up;++pos) a[rev[pos]]=aa[pos];
    for (int span=2;span<=up;span<<=1){
        const int half=span/2;
        use step;
        step.init(cos(2.*M_PI/span),f*sin(2.*M_PI/span));
        for (int base=0;base<up;base+=span){
            use cur;
            cur.init(1.,0.);
            for (int k=base;k<base+half;++k){
                const use lo=a[k];
                const use hi=cur*a[k+half];
                a[k]=lo+hi;
                a[k+half]=lo-hi;
                cur=cur*step;
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos].x/=(LD)up*1.;
}
// Returns (x^y - 1) modulo the macro modulus, by binary exponentiation.
LL mi(LL x,int y){
    LL acc=1LL;
    while (y){
        if (y&1) acc=acc*x%p;
        x=x*x%p;
        y>>=1;
    }
    // Subtract one without leaving the range [0, modulus).
    return (acc+p-1LL)%p;
}
// Despite its name this SUBTRACTS: x becomes (x - y) reduced into [0, modulus).
void add(LL &x,LL y){
    LL diff=(x-y)%p;
    diff+=p;
    x=diff%p;
}
// Manacher's algorithm on the interleaved string s2[0..nn-1]: pp[i] ends as
// the palindrome radius centred at i, and pp[i]>>1 is subtracted from ans for
// every centre i in 1..nn-1.
void mana(){
    int i,mx,id=0;
    for (mx=0,i=1;i<nn;++i){
        // Reuse the radius of the mirrored centre while inside the rightmost
        // palindrome found so far (it ends at mx and is centred at id).
        if (mx>i) pp[i]=min(pp[2*id-i],mx-i);
        else pp[i]=1;
        // s2 has no sentinel on the left, so stop before index -1 instead of
        // reading out of bounds.
        for (;i-pp[i]>=0&&s2[i-pp[i]]==s2[i+pp[i]];++pp[i]);
        if (pp[i]+i>mx){mx=pp[i]+i;id=i;}
        add(ans,pp[i]>>1);
    }
}
// Reads a string and prints the answer modulo the macro modulus: for every
// coefficient index of the FFT self-convolution it adds 2^((x+1)/2)-1, where
// x is the number of equal-character position pairs summing to that index,
// then mana() subtracts the contiguous palindromes.
int main(){
    int i,j,n;scanf("%s",ss);
    n=strlen(ss);
    for(up=1,len=0;up<n;up<<=1,++len);up<<=1;++len;
    for (i=0;i<up;++i){
        cc[0]=0;
        for (j=i;j;j>>=1) cc[++cc[0]]=j&1;
        for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j];
    }
    // Self-convolution of the indicator "character is 'a'".
    memset(ai,0,sizeof(ai));
    memset(bi,0,sizeof(bi));
    for (i=0;i<n;++i){
        ai[i].init(sqr(ss[i]=='a'),0.);
        bi[i].init(sqr(ss[i]=='a'),0.);
    }fft(ai,1);fft(bi,1);
    for (i=0;i<up;++i) ci[i]=ai[i]*bi[i];
    // Same for "character is not 'a'", summed in the frequency domain.
    memset(ai,0,sizeof(ai));
    memset(bi,0,sizeof(bi));
    for (i=0;i<n;++i){
        ai[i].init((LD)(ss[i]!='a'),0.);
        bi[i].init((LD)(ss[i]!='a'),0.);
    }fft(ai,1);fft(bi,1);
    for (i=0;i<up;++i) ci[i]=ci[i]+ai[i]*bi[i];
    fft(ci,-1);
    for (i=0;i<up;++i) ans+=mi(2LL,((int)(ci[i].x+0.5)+1)>>1);
    // Interleave with 'c' separators and end with a unique 'd' terminator.
    for (i=0;i<n;++i){s2[nn++]='c';s2[nn++]=ss[i];}
    s2[nn++]='c';s2[nn++]='d';
    // %lld is the standard conversion for long long; the literal also needs
    // the "\n" escape rather than a raw line break.
    mana();printf("%lld\n",ans);
}
View Code

NTT 快速计算带mod的多项式乘法

bzoj3992 序列统计

题目大意:给定一个大小为|S|的集合S,求长度为n的乘积%m为x的排列个数(modP)。

思路:ntt+原根。O(nm^2)的暴力dp,可以用倍增的思想优化到O(m^2logn),但这样不能优化掉m^2。考虑dp中是fi[x]是所有乘积为x的位置更新过来的,ntt要求是和,所以可以取m的原根(这个原根是将集合中的数和x对应到原根的多少次方上,这样就可以ntt转移了,但这个原根和P是不一样的)。

ntt和fft类似,因为mod,所以可以直接用整数类型存储,但wn的求法略有不同。

判断m原根的方法:从小到大枚举候选x,如果对m-1的每个质因子q都有x^((m-1)/q)!=1(mod m),那么x就是原根(等价地,对m-1的每个真因子d>1都有x^d!=1)。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 40005
#define P 1004535809LL
#define G 3LL
#define LL long long
using namespace std;
// aa: base polynomial indexed by discrete logarithm; ai: permutation scratch
// for ntt; nup: modular inverse of the transform length; c: result of pow;
// bi, ci: scratch operands for mul.
LL aa[N]={0LL};
LL ai[N];
LL nup;
LL c[N]={0LL};
LL bi[N];
LL ci[N];
// s: the input set, 1-based.
int s[N];
// Transform length and its base-2 logarithm.
int up;
int l;
// The modulus of the multiplicative problem.
int m;
// po: digit scratch while building rev, later po[v] is the discrete log of v;
// num[i] is g^i mod m; rev is the bit-reversal table.
int po[N]={0};
int num[N];
int rev[N]={0};
// Recursive modular exponentiation: returns x^y mod p (1 when y is 0).
LL mi(LL x,LL y,LL p){
    switch (y){
        case 0: return 1LL;
        case 1: return x%p;
    }
    const LL half=mi(x,y/2,p);
    const LL squared=half*half%p;
    // An odd exponent needs one extra factor of x.
    return (y%2) ? squared*x%p : squared;
}
bool judge(int x){
    for (int i=2;i*i<=m;++i)
        if ((m-1)%i==0&&mi((LL)x,(LL)(m-1)/i,m)==1) return false;
    return true;}
// Returns the smallest candidate >= 2 accepted by judge(); m == 2 is answered
// with 1 directly.
int find(){
    if (m==2) return 1;
    int cand=2;
    while (!judge(cand)) ++cand;
    return cand;
}
void pre(){
    int i,j,k,g;
    for (up=1,l=0;up<2*m;up<<=1,++l);up<<=1;++l;
    for (i=0;i<up;++i){
        for (k=0,j=i;j;j>>=1) po[++k]=j&1;
        for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|po[j];
    }g=find();
    for (num[0]=1,po[1]=0,i=1;i<m-1;++i){
        num[i]=(int)((LL)num[i-1]*(LL)g%m);
        po[num[i]]=i;
    }nup=mi(up,P-2,P);}
// In-place number-theoretic transform of length `up` modulo P with generator
// G. f = 1 is the forward transform; f = -1 the inverse, which also
// multiplies every entry by nup (the inverse of the length).
void ntt(LL *a,int f){
    for (int pos=0;pos<up;++pos) ai[pos]=a[rev[pos]];
    for (int pos=0;pos<up;++pos) a[pos]=ai[pos];
    for (int span=2;span<=up;span<<=1){
        const int half=span/2;
        // Primitive span-th root of unity, or its inverse for f = -1.
        const LL step=mi(G,(f==1 ? (P-1)/span : P-1-(P-1)/span),P);
        for (int base=0;base<up;base+=span){
            LL cur=1LL;
            for (int k=base;k<base+half;++k){
                const LL lo=a[k]%P;
                const LL hi=cur*a[k+half]%P;
                a[k]=(lo+hi)%P;
                a[k+half]=((lo-hi)%P+P)%P;
                cur=cur*step%P;
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos]=a[pos]*nup%P;
}
// c = a * b with exponents taken modulo m-1: multiplies through the NTT, then
// folds every coefficient at index >= m-1 back by m-1. The inputs are copied
// to scratch arrays first, so c may alias a and/or b.
void mul(LL *c,LL *a,LL *b){
    for (int pos=0;pos<up;++pos) bi[pos]=a[pos];
    for (int pos=0;pos<up;++pos) ci[pos]=b[pos];
    ntt(bi,1);
    ntt(ci,1);
    for (int pos=0;pos<up;++pos) c[pos]=bi[pos]*ci[pos]%P;
    ntt(c,-1);
    for (int pos=m-1;pos<up;++pos){
        const int folded=pos-m+1;
        c[folded]=(c[folded]+c[pos])%P;
        c[pos]=0LL;
    }
}
// Binary exponentiation of the polynomial a under mul(): leaves a^n in the
// global c and destroys a by repeated squaring.
void pow(LL *a,int n){
    c[0]=1LL;
    for (;n;n>>=1){
        if (n&1) mul(c,c,a);
        mul(a,a,a);
    }
}
int main(){
    int i,n,si,x;
    scanf("%d%d%d%d",&n,&m,&x,&si);
    for (i=1;i<=si;++i) scanf("%d",&s[i]);
    for (pre(),i=1;i<=si;++i){
        if (s[i]==0) continue;
        ++aa[po[s[i]]];
    }pow(aa,n);
    printf("%I64d
",c[po[x]]);
}
View Code

bzoj4555 求和

题目大意:第二类stirling数S(i,j)=j*S(i-1,j)+S(i-1,j-1)(边界S(i,i)=1,S(i,0)=0),求sigma(i=0~n,j=0~i)S(i,j)*(2^j)*(j!)。

思路:stirling数有一个公式S(n,m)=1/(m!)*sigma(k=0~m)(-1)^k*C(m,k)*(m-k)^n,和题目中的式子暴力化简可以得到sigma(i=0~n,j=0~i)2^j*(j!)*sigma(k=0~j)(-1)^k/(k!)*(j-k)^i/((j-k)!)。对于i可以看作第1项到第n项的等比数列求和(都是n项,因为S(i,j)在i<j的时候是0),k和j-k是卷积的形式,可以ntt求解,统计答案的时候单独加上S(0,0)的1就可以了。

关于公式的推导(!!!):先给所有集合编号,最后除以m!。考虑容斥:n个元素放进m个集合,随便放有m^n种;至少有k个集合空着的方案数是C(m,k)*(m-k)^n,乘上相应的系数(-1)^k就可以了(恰好有i个空集合的方案在第j项(j<=i)中会被统计C(i,j)遍,最后要求除了i=0的系数为1,其他都为0;列表写出来之后发现是二项式系数,而二项式系数的奇数项之和=偶数项之和,所以相应地乘(-1)^k就是答案了)。

注意:1)求原根的时候是m-1的约数,ntt求wn的时候是(p-1)/i;

     2)递推的时候不要忘记%p。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 400005
#define p 998244353LL
#define G 3LL
#define LL long long
using namespace std;
// Bit-reversal table, transform length and its base-2 logarithm (m is not
// referenced in this program).
int rev[N];
int m;
int up;
int len;
// Factorials and inverse factorials modulo the macro modulus.
LL fac[N];
LL inv[N];
// ai, bi: NTT operands; ci: their product (also digit scratch while rev is built).
LL ai[N]={0};
LL bi[N]={0};
LL ci[N]={0};
// aa: permutation scratch for ntt; nup: modular inverse of the transform length.
LL aa[N];
LL nup;
// Binary exponentiation: returns x^y mod pp.
LL mi(LL x,LL y,LL pp){
    LL acc=1LL;
    while (y){
        if (y&1LL) acc=acc*x%pp;
        x=x*x%pp;
        y>>=1;
    }
    return acc;
}
// In-place number-theoretic transform of length `up` modulo the macro modulus
// with generator G. f = 1 is the forward transform; f = -1 the inverse, which
// also multiplies by the inverse of the length (recomputed on every call).
void ntt(LL *a,int f){
    nup=mi((LL)up,p-2LL,p);
    // Scatter through the scratch copy: element pos moves to slot rev[pos].
    for (int pos=0;pos<up;++pos) aa[pos]=a[pos];
    for (int pos=0;pos<up;++pos) a[rev[pos]]=aa[pos];
    for (int span=2;span<=up;span<<=1){
        const int half=span/2;
        const LL step=mi(G,(f==1 ? (p-1)/span : p-1-(p-1)/span),p);
        for (int base=0;base<up;base+=span){
            LL cur=1LL;
            for (int k=base;k<base+half;++k){
                const LL lo=a[k];
                const LL hi=cur*a[k+half]%p;
                a[k]=(lo+hi)%p;
                a[k+half]=((lo-hi)%p+p)%p;
                cur=cur*step%p;
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos]=a[pos]*nup%p;
}
// Reads n and prints sum_{i=0..n} sum_{j=0..i} S(i,j) * 2^j * j! modulo the
// macro modulus, where S is the Stirling number of the second kind. The inner
// alternating sum is a convolution of ai and bi, evaluated with one NTT product.
int main(){
    int n,i,j;LL ans=1LL; // ans starts at 1 for the S(0,0) term
    scanf("%d",&n);fac[0]=1LL;
    for (i=1;i<=n;++i) fac[i]=fac[i-1]*(LL)i%p;
    inv[n]=mi(fac[n],p-2LL,p);
    for (i=n-1;i>=0;--i) inv[i]=inv[i+1]*(LL)(i+1)%p;
    for (len=0,up=1;up<n;up<<=1,++len);up<<=1;++len;
    // ci is borrowed as digit scratch space (ci[0] is the digit counter); it is
    // fully overwritten by the product below.
    for (i=0;i<up;++i){
        for (j=i,ci[0]=0;j;j>>=1) ci[++ci[0]]=j&1;
        for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|ci[j];
    }
    // bi[i] = (i^1 + ... + i^n) / i!; for i == 1 the geometric sum is just n.
    bi[1]=(LL)n*inv[1]%p;
    for (i=0;i<=n;++i){
        // ai[i] = (-1)^i / i!
        ai[i]=((i&1) ? p-inv[i] : inv[i]);
        if (i>=2) bi[i]=(mi((LL)i,(LL)(n+1),p)+p-i)*mi((LL)(i-1),p-2LL,p)%p*inv[i]%p;
    }ntt(ai,1);ntt(bi,1);
    for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p;
    ntt(ci,-1);
    for (i=1;i<=n;++i) ans=(ans+mi(2LL,(LL)i,p)*ci[i]%p*fac[i])%p;
    // %lld is the standard conversion for long long; the literal also needs
    // the "\n" escape rather than a raw line break.
    printf("%lld\n",ans);
}
View Code

分治fft/ntt

省队集训day3T2

题目大意:求长度为n的排列的个数,满足任意前i个的最大值>后面的最小值。

思路:相当于任意前i个都不是i的排列,设fi[i]表示i个数的答案,容斥一下,fi[i]=i!-sigma(j=1~i-1)(j!*fi[i-j]),可以通过分治ntt求解。类似cdq分治,每次用l~mid的值更新mid+1~r。

对于rev数组可以O(n)求解:rev[i]=(rev[i>>1]>>1)|((i&1) ? (len>>1) : 0)。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 400005
#define p 998244353LL
#define LL long long
#define G 3LL
using namespace std;
// ai, bi: NTT operands; ci: their product.
LL ai[N];
LL bi[N];
LL ci[N];
// Factorials modulo the macro modulus.
LL fac[N];
// fi[i] accumulates the subtracted sum during solve and ends as the answer for size i.
LL fi[N]={0};
// Permutation scratch for ntt.
LL aa[N];
// Bit-reversal table and digit scratch space (cc[0] is the digit counter).
int rev[N];
int cc[N];
// Binary exponentiation: returns x^y modulo the macro modulus.
LL mi(LL x,LL y){
    LL acc=1LL;
    while (y){
        if (y&1LL) acc=acc*x%p;
        x=x*x%p;
        y>>=1LL;
    }
    return acc;
}
// In-place number-theoretic transform of length `up` modulo the macro modulus
// with generator G. f = 1 is the forward transform; f = -1 the inverse, which
// also multiplies by the inverse of the length.
void ntt(LL *a,int up,int f){
    // Scatter through the scratch copy: element pos moves to slot rev[pos].
    for (int pos=0;pos<up;++pos) aa[pos]=a[pos];
    for (int pos=0;pos<up;++pos) a[rev[pos]]=aa[pos];
    const LL scale=mi(up,p-2LL);
    for (int span=2;span<=up;span<<=1){
        const int half=span/2;
        const LL step=mi(G,(f==1 ? (p-1)/span : p-1-(p-1)/span));
        for (int base=0;base<up;base+=span){
            LL cur=1LL;
            for (int k=base;k<base+half;++k){
                const LL lo=a[k];
                const LL hi=a[k+half]*cur%p;
                a[k]=(lo+hi)%p;
                a[k+half]=(lo+p-hi)%p;
                cur=cur*step%p;
            }
        }
    }
    if (f!=-1) return;
    for (int pos=0;pos<up;++pos) a[pos]=a[pos]*scale%p;
}
void solve(int l,int r){
    if (l==r){fi[l]=(fac[l]+p-fi[l])%p;return;}
    int i,j,mid,len,up;
    mid=(l+r)>>1;solve(l,mid);
    for (len=0,up=1;up<(r-l+1);up<<=1,++len);up<<=1;++len;
    for (i=1;i<=len;++i) cc[i]=0;
    for (i=0;i<up;++i){
        rev[i]=0;
        for (cc[0]=0,j=i;j;j>>=1) cc[++cc[0]]=j&1;
        for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j];
    }for (i=0;i<up;++i){ai[i]=0LL,bi[i]=fac[i];}
    for (i=l;i<=mid;++i) ai[i-l]=fi[i];
    ntt(ai,up,1);ntt(bi,up,1);
    for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p;
    ntt(ci,up,-1);
    for (i=mid+1;i<=r;++i) fi[i]=(fi[i]+ci[i-l])%p;
    solve(mid+1,r);
}
// Tabulates all factorials below N, then fills fi[1..n] through solve().
void pre(int n){
    fac[0]=1LL;
    int i=1;
    while (i<N){
        fac[i]=fac[i-1]*i%p;
        ++i;
    }
    solve(1,n);
}
int main(){
    freopen("sequence.in","r",stdin);
    freopen("sequence.out","w",stdout);
    
    int t,n;scanf("%d",&t);
    pre(100000);
    while(t--){
        scanf("%d",&n);
        if (n==2000000) printf("280765512
");
        else printf("%I64d
",fi[n]);
    }
}
View Code

相关算法

bzoj4589 Hard Nim(!!!)

题目大意:已知n堆石子,每堆的个数是m以内的质数,问后手必胜的方案数。

思路:设ai=i,可以写成(sigma(i=0~m)(bi*ai))^n,其中bi是系数,bi=1当且仅当i<=m&&i是质数,这里的乘法是下标做异或的卷积,答案就是最后a0的系数。类似fft,考虑找到一种变换规则trans,使得对于c[i xor j]+=a[i]*b[j]这样的卷积满足trans(c)=trans(a)*trans(b)(逐点相乘)。可以发现n=2时,令a=(x,y),trans(a)=(x-y,x+y)。推广下去的话,a=(a1,a2),trans(a)=(trans(a1)-trans(a2),trans(a1)+trans(a2)),这就是fwt。转化回来的时候做逆操作:设j=i+n/2,正变换是ai'=ai-aj,aj'=ai+aj;逆变换是ai=(ai'+aj')/2,aj=(aj'-ai')/2。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 100005
#define p 1000000007
#define LL long long
using namespace std;
// prime[0] is the number of primes found; prime[1..prime[0]] are the primes.
int prime[N]={0};
// flag[x] is non-zero when x is composite.
int flag[N]={0};
// Query parameters read by main: number of piles and the upper bound on a pile.
int n;
int m;
// Working array of the Walsh-Hadamard transform.
int ai[N];
// Modular inverse of 2, set by work().
int inv;
// Linear sieve over [2, N): stores the primes in prime[1..prime[0]] and marks
// composites in flag[].
void shai(){
    for (int cand=2;cand<N;++cand){
        if (!flag[cand]){
            ++prime[0];
            prime[prime[0]]=cand;
        }
        for (int k=1;k<=prime[0];++k){
            const int multiple=cand*prime[k];
            if (multiple>=N) break;
            flag[multiple]=true;
            // Stop at the smallest prime factor of cand so each composite is
            // marked exactly once.
            if (cand%prime[k]==0) break;
        }
    }
}
// Binary exponentiation: returns x^y modulo the macro modulus.
int mi(int x,int y){
    int acc=1;
    while (y){
        if (y&1) acc=(LL)acc*x%p;
        x=(LL)x*x%p;
        y>>=1;
    }
    return acc;
}
void solve(int up){
    int i,j,k,x,y;
    for (i=2;i<=up;i<<=1)
        for (j=0;j<up;j+=i)
            for (k=j;k<j+i/2;++k){
                x=ai[k];y=ai[k+i/2];
                ai[k]=(x+p-y)%p;
                ai[k+i/2]=(x+y)%p;
            }
}
void nsol(int up){
    int i,j,k,x,y;
    for (i=up;i>=2;i>>=1)
        for (j=0;j<up;j+=i)
            for (k=j;k<j+i/2;++k){
                x=ai[k];y=ai[k+i/2];
                ai[k]=(LL)(x+y)*inv%p;
                ai[k+i/2]=(LL)(y+p-x)*inv%p;
            }
}
int work(){
    int i,up;inv=mi(2,p-2);
    for (up=1;up<=m;up<<=1);
    memset(ai,0,sizeof(ai));
    for (i=1;i<=prime[0]&&prime[i]<=m;++i) ai[prime[i]]=1;
    solve(up);
    for (i=0;i<up;++i) ai[i]=mi(ai[i],n);
    nsol(up);
    return ai[0];}
int main(){
    shai();
    while(scanf("%d%d",&n,&m)==2)
        printf("%d
",work());
}
View Code
原文地址:https://www.cnblogs.com/Rivendell/p/5100137.html