FFT&多项式学习记录

最近补了一发多项式相关的知识点,记录一下放个板子

  1 #include<bits/stdc++.h>
  2 #define ll long long
  3 #define maxn 100005
  4 using namespace std;
  5 const int mod = 998244353;
  6 const int g = 3;
  7 namespace Poly
  8 {
  9     int fastpow(int a,int p)
 10     {
 11         int ans=1;
 12         while(p)
 13         {
 14             if(p&1)ans=1ll*ans*a%mod;
 15             a=1ll*a*a%mod;p>>=1; 
 16         }
 17         return ans;
 18     }
 19     int inv(int a)
 20     {
 21         return fastpow(a,mod-2);
 22     }
 23     const int inv2=inv(2);
 24     int rev[maxn*4],tmp[35];
 25     int getrev(int x,int L)
 26     {
 27         int res=0;
 28         for(int i=0;i<L;++i)tmp[i]=(x>>i)&1;
 29         for(int i=L-1;i>=0;--i)if(tmp[L-1-i])res|=(1<<i);
 30         return res;
 31     }
 32     void ntt(int n,int *a,int tp)//n-1次多项式,其中n为2的幂次,a为补齐之后的多项式系数,tp==1为DFT,tp==-1为IDFT 
 33     {
 34         for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
 35         for(int h=2;h<=n;h<<=1)
 36         {
 37             ll wn=fastpow(g,(mod-1)/h);
 38             if(tp==-1)wn=inv(wn);
 39             int mid=(h>>1);
 40             for(int i=0;i<n;i+=h)
 41             {
 42                 ll w=1;
 43                 for(int j=i;j<i+mid;++j,w=w*wn%mod)
 44                 {
 45                     ll l=a[j],r=1ll*w*a[j+mid]%mod;
 46                     a[j]=(l+r)%mod;
 47                     a[j+mid]=(l-r+mod)%mod;
 48                 }
 49             }
 50         }
 51         if(tp==-1)
 52         {
 53             int k=inv(n);
 54             for(int i=0;i<n;++i)a[i]=1ll*a[i]*k%mod;
 55         }
 56     }
 57     int X[maxn*4],Y[maxn*4];
 58     void poly_mul(int n,int m,int *a,int *b,int *c)//n-1次多项式*m-1次多项式,result存在c中 
 59     {
 60         int N=1,L=0;
 61         for(;N<=n+m-2;N<<=1)L++;
 62         for(int i=0;i<N;++i)rev[i]=getrev(i,L);
 63         for(int i=0;i<N;++i)X[i]=a[i],Y[i]=b[i];
 64         ntt(N,X,1);ntt(N,Y,1);
 65         for(int i=0;i<N;++i)c[i]=1ll*X[i]*Y[i]%mod;
 66         ntt(N,c,-1);
 67         for(int i=0;i<N;++i)X[i]=0,Y[i]=0;
 68     }
 69     int T[maxn*4];
 70     void poly_inv(int n,int *a,int *b)//n-1次多项式求逆 
 71     {
 72         if(n==1)
 73         {
 74             b[0]=inv(a[0]);
 75             return;
 76         }
 77         else
 78         {
 79             poly_inv((n+1)>>1,a,b);
 80             int N=1,L=0;
 81             for(;N<(n<<1);N<<=1)L++;
 82             for(int i=0;i<N;++i)rev[i]=getrev(i,L);
 83             for(int i=0;i<n;++i)T[i]=a[i];
 84             for(int i=n;i<N;++i)T[i]=0;
 85             ntt(N,T,1);ntt(N,b,1);
 86             for(int i=0;i<N;++i)b[i]=1ll*b[i]*(2ll+mod-1ll*b[i]*T[i]%mod)%mod;
 87             ntt(N,b,-1);
 88             for(int i=n;i<N;++i)b[i]=0;
 89             for(int i=0;i<N;++i)T[i]=0; 
 90         }
 91     }
 92     int A[maxn*4],B[maxn*4],C[maxn*4],D[maxn*4];
 93     void poly_div(int n,int m,int *a,int *b,int *c)//n-1次多项式 / m-1次多项式,c中存结果
 94     {
 95         for(int i=0;i<n;++i)A[i]=a[n-i-1];
 96         for(int i=0;i<m;++i)B[i]=b[m-i-1];
 97         for(int i=n-m+1;i<n;++i)A[i]=0;
 98         for(int i=n-m+1;i<m;++i)B[i]=0;
 99         poly_inv(n-m+1,B,D);
100         poly_mul(n,n-m+1,A,D,c);
101         reverse(c,c+n-m+1);
102         for(int i=n-m+1;i<n*2;++i)c[i]=0;
103         for(int i=0;i<n*2;++i)A[i]=0,B[i]=0,D[i]=0;
104     }
105     void poly_mod(int n,int m,int *a,int *b,int *r)//n-1次多项式 % m-1次多项式,r中存结果
106     {
107         poly_div(n,m,a,b,C);
108         for(int i=0;i<n*2;++i)A[i]=0,B[i]=0,D[i]=0;
109         for(int i=0;i<n;++i)A[i]=a[i];
110         for(int i=0;i<m;++i)B[i]=b[i];
111         poly_mul(m,n-m+1,B,C,D);
112         for(int i=0;i<m-1;++i)r[i]=(A[i]-D[i]+mod)%mod;
113         for(int i=0;i<n*2;++i)A[i]=0,B[i]=0,D[i]=0;
114     }
115     int M[maxn*4],Q[maxn*4],K[maxn*4];
116     void poly_sqrt(int n,int *a,int *b)//n-1次多项式开根号,a[0]=1,result存在 b 中 
117     {
118         if(n==1)
119         {
120             b[0]=1;
121             return;
122         }
123         else
124         {
125             poly_sqrt((n+1)>>1,a,b);
126             int N=1,L=0;
127             for(;N<(n<<1);N<<=1)L++;
128             for(int i=0;i<N;++i)rev[i]=getrev(i,L);
129             for(int i=0;i<n;++i)M[i]=a[i];
130             for(int i=n;i<N;++i)M[i]=0;
131             poly_inv(n,b,Q);
132             ntt(N,Q,1);ntt(N,M,1);
133             for(int i=0;i<N;++i)Q[i]=1ll*Q[i]*M[i]%mod;
134             ntt(N,Q,-1);
135             for(int i=0;i<n;++i)b[i]=1ll*inv2*(1ll*b[i]+1ll*Q[i])%mod;
136             for(int i=0;i<N;++i)Q[i]=0;
137         }
138     }
139     void poly_derivation(int n,int *a,int *b)//n-1次多项式求导 
140     {
141         for(int i=0;i<n-1;++i)b[i]=1ll*a[i+1]*(i+1)%mod;
142     }
143     void poly_integration(int n,int *a,int *b)//n-1次多项式积分
144     {
145         b[0]=0;
146         for(int i=0;i<n;++i)b[i+1]=1ll*a[i]*inv(i+1)%mod;
147     }
148     void poly_ln(int n,int *a,int *b)//n-1次多项式 ln 
149     {
150         for(int i=0;i<=n*2;++i)K[i]=0,Q[i]=0,M[i]=0;
151         poly_derivation(n,a,M);
152         poly_inv(n,a,Q);
153         poly_mul(n-1,n,M,Q,K);
154         poly_integration(n,K,b);
155         for(int i=n;i<=n*2;++i)b[i]=0;
156         for(int i=0;i<=n*2;++i)K[i]=0,Q[i]=0,M[i]=0;
157     }
158     int U[maxn*4],V[maxn*4],W[maxn*4];
159     void poly_exp(int n,int *a,int *b)//n-1次多项式exp,a[0]=0 
160     {
161         if(n==1)
162         {
163             b[0]=1;
164             return;
165         }
166         else
167         {
168             poly_exp((n+1)>>1,a,b);
169             poly_ln(n,b,V);
170             int N=1,L=0;
171             for(;N<(n<<1);N<<=1)L++;
172             for(int i=0;i<N;++i)rev[i]=getrev(i,L);
173             for(int i=0;i<n;++i)U[i]=a[i];
174             for(int i=n;i<N;++i)U[i]=0;
175             U[0]=(U[0]+1)%mod;
176             for(int i=0;i<N;++i)U[i]=(1ll*U[i]+mod-V[i])%mod;
177             ntt(N,b,1);ntt(N,U,1);
178             for(int i=0;i<N;++i)b[i]=1ll*b[i]*U[i]%mod;
179             ntt(N,b,-1);
180             for(int i=n;i<N;++i)b[i]=0;
181         }
182     }
183 }
View Code
原文地址:https://www.cnblogs.com/uuzlove/p/10434373.html