多项式求逆
定义
对于一个多项式\(A(x)\),如果存在多项式\(B(x)\)满足其次数小于等于\(A(x)\)的次数且
则称\(B(x)\)为\(A(x)\)在模\(x^n\)意义下的的逆元,记作\(A^{-1}(x)\)。
具体来说:
假设\(A(x)=a_0+a_1x+a_2x^2+\dots+a_nx^n+a_{n+1}a^{n+1}+\dots\)
那么\(A(x)\ mod \ x^n\)后,\(A(x)_{now}=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}\)
其在\(x^n\)次数以及后的项都会被模掉,而前面的不变。
\(A(x)B(x)\equiv 1 \pmod {x^n}\),\(A(x)\)的次数是\(n-1\)
求\(B(x)\)就是要求构造一个次数小于等于\(n-1\)的多项式,使与\(A(x)\)相乘之后
\(a_0=1,a_1\dots a_{n-1} =0\),然后\(n\)次项及以后都会被模掉。
所以\(A(x)B(x)\equiv 1 \pmod {x^n}\)。
求解
我们考虑当\(n=1\)时,\(A(x)=a_0,a_0B(x)\equiv 1\pmod{x}\)。
那么,\(B(x)=a_0^{-1}\)。
当\(n>1\)时,
假设我们已经求出\(A(x)\ mod\ x^{\lceil \frac{n}{2} \rceil}\)意义下的逆元\(G(x)\)。
因为\(A(x)B(x)\equiv1\pmod{x^n}\),那么\(A(x)B(x)\equiv 1\pmod{x^{\lceil \frac{n}{2} \rceil}}\)
原因可参考上方取模定义的解释,模掉更多项了肯定也成立。
所以:
然后开始推一波式子:
对两边平方:
并且可以得到:
注意式子后面的模数可以变回\(x^n\),因为\(B(x)-G(x)\)在\(\ mod\ x^{\lceil \frac{n}{2} \rceil}\)意义下的\(0\)到\({\lceil \frac{n}{2} \rceil}-1\)次项的系数都为\(0\),所以对于第\(0\leq i \leq 2n-1\)项系数\(a_i\),\(a_i=\sum_{j=0}^i{b_j\times b_{i-j}}\),对于每一对\(b_j\)和\(b_{i-j}\),\(i\)和\(i-j\)中有一个小于等于\(x^{\lceil \frac{n}{2} \rceil}\),所以\(b_j\)和\(b_{i-j}\)中有一个一定为\(0\)。所以\(0\leq i \leq 2n-1\)项系数\(a_i\)都为\(0\),故上式成立。
在式子两边同乘\(A(x)\):
因为:
所以:
移项:
因为已知\(G(x)\),所以可以用NTT求出\(B(x)\),所以就可以从下向上递推了,
由主定理可以得到其复杂度:\(O(nlog_2n)\)。(虽然我也不会证)
下面贴上【模板】多项式求逆 的代码:
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#define N 400010
using namespace std;
int read()
{
int f,x,c;
while(!isdigit(c=getchar())&&c!='-'); c=='-'?(f=1,x=0):(f=0,x=c-'0');
while(isdigit(c=getchar())) x=(x<<3)+(x<<1)+c-'0'; return f?-x:x;
}
const int P=998244353,G=3,Gi=332748118;
int n,a[N],b[N],rev[N],lim=1,num,c[N];
int ksm(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=1ll*ans*a%P;
a=1ll*a*a%P; b>>=1;
}
return ans;
}
void NTT(int *A,int type)
{
for(int i=0;i<lim;i++) if(i<rev[i]) swap(A[i],A[rev[i]]);
for(int mid=1;mid<lim;mid<<=1)
{
int Wn=ksm(type==1?G:Gi,(P-1)/(mid<<1));
for(int r=mid<<1,l=0;l<lim;l+=r)
{
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*Wn%P)
{
int x=A[l+k],y=1ll*w*A[l+mid+k]%P;
A[l+k]=(x+y)%P; A[l+mid+k]=(x-y+P)%P;
}
}
}
}
void work(int deg,int *a,int *b)
{
if(deg==1) {b[0]=ksm(a[0],P-2);return;}
work((deg+1)>>1,a,b);
lim=1,num=0;
while(lim<=deg+deg) lim<<=1,num++;
for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(num-1));
for(int i=0;i<deg;i++) c[i]=a[i];
for(int i=deg;i<lim;i++) c[i]=0;
NTT(c,1); NTT(b,1);
for(int i=0;i<lim;i++)
b[i]=1ll*(2ll-1ll*c[i]*b[i]%P+P)%P*b[i]%P;
NTT(b,-1); int inv=ksm(lim,P-2);
for(int i=0;i<lim;i++) b[i]=1ll*b[i]*inv%P;
for(int i=deg;i<lim;i++) b[i]=0;
}
int main()
{
n=read();
for(int i=0;i<n;i++) a[i]=read();
work(n,a,b);
for(int i=0;i<n;i++) printf("%d ",b[i]);
return 0;
}