洛谷P4512 【模板】多项式除法

传送门

先膜拜一下两位大佬->这里这里

问题是这样的:给定一个$n$次多项式$A(x)$和一个$m(m≤n)$次多项式$B(x)$,要求求出两个多项式$D(x),R(x)$,满足$$A(x)=D(x)B(x)+R(x)$$

这里$A(x)$为$n$次多项式,$B(x)$为$m$次多项式,那么$D(x)$为$n-m$次多项式,$R(x)$为$m-1$次多项式(如果高次项不存在的话用$0$补齐)

发现这里$R(x)$很麻烦,考虑如何消去

对于$n$次多项式$A(x)$,我们定义一个运算$A^R(x)$,,满足$$A^R(x)=x^nA(frac{1}{x})$$

这个运算的作用是将$A(x)$的系数进行翻转,随便拿一个多项式带进去运算就能发现

然后开始推推推$$A(x)=D(x)B(x)+R(x)$$

我们将$x$用$frac{1}{x}$代入,并在左右同乘上$x^n$,得$$x^nA(frac{1}{x})=x^{n-m}D(frac{1}{x})x^mB(frac{1}{x})+x^{n-m+1}x^{m-1}R(frac{1}{x})$$
$$A^R(x)=D^R(x)B^R(x) + x^{n - m + 1}R^R(x)$$

实际上,在多项式求逆的时候我们就知道,在没有取模的情况下,多项式除法是可以有无数项的。那么我们要在这里进行取模,顺便消去$R(x)$

因为$D(x)$即使在反转之后次数仍然不会高于$n-m$,而$x^{n - m + 1}R^R(x)$的最低次项次数高于$n-m$,所以我们可以把上式放到$pmod{x^{n-m+1}}$意义下,就能把$R(x)$的影响消除掉,且不会影响$D(x)$,而$A(x)$和$B(x)$已知不会有问题,那么原式就变成了$$A^R(x)=D^R(x)B^R(x)pmod{x^{n-m+1}}$$
$$A^R(x)B^{-R}(x)=D^R(x)pmod{x^{n-m+1}}$$
那么只要求出$B^{R}$的逆元,就能求出$D(x)$了

然后$R(x)$的话,只要把$D(x)$带进式子里就可以求得了

 1 //minamoto
 2 #include<iostream>
 3 #include<cstdio>
 4 #include<cstring>
 5 #include<algorithm>
 6 #define swap(x,y) (x^=y,y^=x,x^=y)
 7 #define mul(x,y) (1ll*x*y%P)
 8 #define add(x,y) (x+y>=P?x+y-P:x+y)
 9 #define dec(x,y) (x-y<0?x-y+P:x-y)
10 using namespace std;
11 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
12 char buf[1<<21],*p1=buf,*p2=buf;
13 inline int read(){
14     #define num ch-'0'
15     char ch;bool flag=0;int res;
16     while(!isdigit(ch=getc()))
17     (ch=='-')&&(flag=true);
18     for(res=num;isdigit(ch=getc());res=res*10+num);
19     (flag)&&(res=-res);
20     #undef num
21     return res;
22 }
23 char sr[1<<21],z[20];int C=-1,Z;
24 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
25 inline void print(int x){
26     if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
27     while(z[++Z]=x%10+48,x/=10);
28     while(sr[++C]=z[Z],--Z);sr[++C]=' ';
29 }
30 const int N=1e6+5,P=998244353,Gi=3;
31 inline int ksm(int a,int b){
32     int res=1;
33     while(b){
34         if(b&1) res=mul(res,a);
35         a=mul(a,a),b>>=1;
36     }
37     return res;
38 }
39 int n,r[N],A[N],B[N],O[N],F[N],G[N],Q[N],R[N],Ginv[N],Atmp[N],Btmp[N];
40 inline int getlen(int x){
41     int len=1;
42     while(len<=x) len<<=1;
43     return len;
44 }
45 void NTT(int *A,int type,int len){
46     int limit=1,l=0;
47     while(limit<len) limit<<=1,++l;
48     for(int i=0;i<limit;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
49     for(int i=0;i<limit;++i) if(i<r[i]) swap(A[i],A[r[i]]);
50     for(int mid=1;mid<limit;mid<<=1){
51         int R=mid<<1,Wn=ksm(Gi,(P-1)/R);O[0]=1;
52         for(int j=1;j<mid;++j) O[j]=mul(O[j-1],Wn);
53         for(int j=0;j<limit;j+=R){
54             for(int k=0;k<mid;++k){
55                 int x=A[j+k],y=mul(O[k],A[j+k+mid]);
56                 A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
57             }
58         }
59     }
60     if(type==-1){
61         reverse(A+1,A+limit);
62         for(int i=0,inv=ksm(len,P-2);i<limit;++i)
63         A[i]=mul(A[i],inv);
64     }
65 }
66 void Inv(int *a,int *b,int len){
67     if(len==1) return (void)(b[0]=ksm(a[0],P-2));
68     Inv(a,b,len>>1);
69     for(int i=0;i<len;++i) A[i]=a[i],B[i]=b[i];
70     NTT(A,1,len<<1),NTT(B,1,len<<1);
71     for(int i=0,l=(len<<1);i<l;++i) A[i]=mul(mul(A[i],B[i]),B[i]);
72     NTT(A,-1,len<<1);
73     for(int i=0;i<len;++i) b[i]=dec(1ll*(b[i]<<1)%P,A[i]);
74     for(int i=0,l=(len<<1);i<l;++i) A[i]=B[i]=0;
75 }
76 void Mul(int *a,int *b,int *c,int n,int m){
77     int len=getlen(max(n,m))<<1;
78     for(int i=0;i<=n;++i) A[i]=a[i];
79     for(int i=0;i<=m;++i) B[i]=b[i];
80     NTT(A,1,len),NTT(B,1,len);
81     for(int i=0;i<=len;++i) c[i]=mul(A[i],B[i]),A[i]=B[i]=0;
82     NTT(c,-1,len);
83 }
84 int main(){
85 //    freopen("testdata.in","r",stdin);
86     int n=read(),m=read();
87     for(int i=0;i<=n;++i) F[i]=read();
88     for(int i=0;i<=m;++i) G[i]=read();
89     reverse(F,F+1+n),reverse(G,G+1+m);
90     Inv(G,Ginv,getlen(n-m));
91     Mul(F,Ginv,Q,n-m,n-m);
92     reverse(Q,Q+n-m+1);
93     for(int i=0;i<=n-m;++i) print(Q[i]);sr[++C]='
';
94     reverse(F,F+n+1),reverse(G,G+m+1);
95     Mul(Q,G,R,n-m,m);
96     for(int i=0;i<m;++i) print(dec(F[i],R[i]));
97     Ot();
98     return 0;
99 }
原文地址:https://www.cnblogs.com/bztMinamoto/p/9744682.html