P4512 【模板】多项式除法

P4512 【模板】多项式除法

链接

分析  

  多项式除法

注意的地方:

Mul 函数中向 TA、TB 拷贝系数的两行(代码第 75、76 行)最开始是这样写的:

memcpy(TA,a,sizeof(int)*(n+1));memset(TA+n+1,0,sizeof(TA));
memcpy(TB,b,sizeof(int)*(m+1));memset(TB+m+1,0,sizeof(TB));

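结果开 O2 时无法通过。最后发现问题出在后面的 memset:它从 TA+n+1(TB+m+1)开始清零,长度却写成了 sizeof(TA),也就是整个数组的字节数,于是写入范围超出数组末尾 (n+1)*sizeof(int) 个字节,属于数组越界(未定义行为),会破坏内存中紧随其后的其他数组。正确的长度应为 sizeof(int)*(N-n-1)。而在本地开 O2 测试时,却可以通过样例。。。 ~ 惊!~ 吓!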

代码

  1 #include<cstdio>
  2 #include<algorithm>
  3 #include<cstring>
  4 #include<cmath>
  5 #include<iostream>
  6 #include<cctype>
  7 
  8 #define P 998244353
  9 #define G 3
 10 #define Gi 332748118 
 11 #define N 270000
 12 
 13 using namespace std;
 14 
 15 int A[N],B[N],D[N],TA[N],TB[N],DR[N],Binv[N],R[N];
 16 
 17 inline int read() {
 18     int x = 0,f = 1;char ch=getchar();
 19     for (; !isdigit(ch); ch=getchar()) if(ch=='-')f=-1;
 20     for (; isdigit(ch); ch=getchar()) x=x*10+ch-'0';
 21     return x*f;
 22 }
 23 
 24 int ksm(int a,int b) {
 25     int ans = 1;
 26     while (b) {
 27         if (b & 1) ans = (1ll * ans * a) % P;
 28         a = (1ll * a * a) % P;
 29         b >>= 1;
 30     }
 31     return ans;
 32 }
 33 void NTT(int *a,int n,int ty) {
 34     for (int i=0,j=0; i<n; ++i) {
 35         if (i < j) swap(a[i],a[j]);
 36         for (int k=(n>>1); (j^=k)<k; k>>=1);
 37     }
 38     for (int w1,m=2; m<=n; m<<=1) {
 39         if (ty == 1) w1 = ksm(G,(P-1)/m);
 40         else w1 = ksm(Gi,(P-1)/m);
 41         for (int i=0; i<n; i+=m) {
 42             int w = 1;
 43             for (int k=0; k<(m>>1); ++k) {
 44                 int u = a[i+k],t = 1ll * w * a[i+k+(m>>1)] % P;
 45                 a[i+k] = (u + t) % P;
 46                 a[i+k+(m>>1)] = (u - t + P) % P;
 47                 w = 1ll * w * w1 % P;
 48             }
 49         }
 50     }
 51     if (ty==-1) {
 52         int inv = ksm(n,P-2);
 53         for (int i=0; i<n; ++i) a[i] = 1ll * a[i] * inv % P;
 54     }
 55     
 56 }
 57 void Inv(int *A,int *B,int n) { 
 58     int len = 1;
 59     while (len <= n) len <<= 1; 
 60     B[0] = ksm(A[0],P-2);
 61     for (int m=2; m<=len; m<<=1) {
 62         for (int i=0; i<m; ++i) TA[i] = A[i],TB[i] = B[i];
 63         NTT(TA,m<<1,1);
 64         NTT(TB,m<<1,1);
 65         for (int i=0; i<(m<<1); ++i) TA[i] = 1ll*TA[i]*TB[i]%P*TB[i]%P;
 66         NTT(TA,m<<1,-1);
 67         for (int i=0; i<m; ++i) B[i] = (1ll*2*B[i]%P-TA[i]+P)%P;
 68     }
 69     memset(TA,0,sizeof(TA));
 70     memset(TB,0,sizeof(TB));
 71 }
 72 void Mul(int *a,int *b,int *C,int n,int m) {
 73     int len = 1;
 74     while (len <= n+m) len <<= 1;
 75     for (int i=0; i<=n; ++i) TA[i] = a[i];
 76     for (int i=0; i<=m; ++i) TB[i] = b[i];
 77     NTT(TA,len,1);
 78     NTT(TB,len,1);
 79     for (int i=0; i<len; ++i) C[i] = (1ll * TA[i] * TB[i]) % P,TA[i] = TB[i] = 0;
 80     NTT(C,len,-1);
 81 }
 82 int main() {
 83     int n = read() ,m = read() ;
 84     for (int i=0; i<=n; ++i) A[i] = read();
 85     for (int i=0; i<=m; ++i) B[i] = read();
 86     
 87     reverse(A,A+n+1);reverse(B,B+m+1); 
 88     
 89     Inv(B,Binv,n-m); // 求B转换后的逆元 
 90     Mul(A,Binv,D,n-m,n-m); // 求转换后的D 
 91     reverse(D,D+n-m+1);  
 92     for (int i=0; i<=n-m; ++i) printf("%d ",D[i]);puts("");
 93     
 94     reverse(A,A+n+1);reverse(B,B+m+1);
 95     Mul(D,B,R,n-m,m); // 求D*B 
 96     for (int i=0; i<m; ++i) R[i] = (A[i] - R[i] + P) % P; // 求R 
 97     for (int i=0; i<m; ++i) printf("%d ",R[i]);
 98     
 99     return 0;
100 }
原文地址:https://www.cnblogs.com/mjtcn/p/9157378.html