Re.多项式求逆

前言

emmm暂无


多项式求逆目的

顾名思义 就是求出一个多项式的摸xn时的逆

给定一个多项式F(x),请求出一个多项式G(x),满足F(x)G(x)1(modxn),系数对998244353取模。


多项式求逆主要思路

我们考虑用递推的做法

假设我们当前已知F(x)H(x)=1(mod xi/2)

要求的是F(x)Q(x)=1(mod xi)

因为F(x)Q(x)=1(mod xi)

所以F(x)Q(x)=1(mod xi/2)

可得F(x)(Q(x)-H(x))=0(mod xi/2)

显然可得Q(x)-H(x)=0(mod xi/2)

将上述式子两边平方得H(x)2-2Q(x)H(x)+Q(x)2=0(mod xi)

再将两边同时乘上F(x)

因为F(x)Q(x)=1(mod xi)

所以得F(x)H(x)2-2H(x)+Q(x)=0(mod xi)

移项最后得求的G(x)=2H(x)-F(x)H(x)2(mod xi)

那么就可以递推了

初始状态显然为i等于1的情况G(0)为F(0)的逆元

最后的递推式就为

1.if(x=0)-----G(0)=F(0)p-2(p为模数)

2.if(x>0)-----G(x)=2H(x)-F(x)H(x)2(mod xi)


代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define C getchar()-48
inline ll read()
{
    ll s=0,r=1;
    char c=C;
    for(;c<0||c>9;c=C) if(c==-3) r=-1;
    for(;c>=0&&c<=9;c=C) s=(s<<1)+(s<<3)+c;
    return s*r;
} 
const ll p=998244353,G=3,N=2100000;
ll n; 
ll rev[N];
ll a[N],b[N],c[N];
inline ll ksm(ll a,ll b)
{
    ll ans=1;
    while(b)
    {
        if(b&1) ans=(ans*a)%p;
        a=(a*a)%p;
        b>>=1;
    }
    return ans;
}
inline void ntt(ll *a,ll n,ll kd)
{
    for(ll i=0;i<n;i++)
    if(i<rev[i])
      swap(a[i],a[rev[i]]);
    for(ll i=1;i<n;i<<=1)
    {
        ll gn=ksm(G,(p-1)/(i<<1));
        for(ll j=0;j<n;j+=(i<<1))
        {
            ll t1,t2,g=1;
            for(ll k=0;k<i;k++,g=1ll*g*gn%p)
            {
                t1=a[j+k],t2=1ll*g*a[j+k+i]%p;
                a[j+k]=(t1+t2)%p,a[j+k+i]=(t1-t2+p)%p;
            }
        }
    }
    if(kd==1) return;
    ll ny=ksm(n,p-2); 
    reverse(a+1,a+n);
    for(ll i=0;i<n;i++) a[i]=1ll*a[i]*ny%p;
}
inline void work(ll deg,ll *a,ll *b)
{
    if(deg==1){b[0]=ksm(a[0],p-2);return;}
    work((deg+1)>>1,a,b);
    ll len=0,sum=1;
    for(;sum<(deg<<1);sum<<=1,len++);
    for(ll i=1;i<sum;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    for(ll i=0;i<deg;i++) c[i]=a[i];
    for(ll i=deg;i<sum;i++) c[i]=0;
    ntt(c,sum,1);ntt(b,sum,1);
    for(ll i=0;i<sum;i++) b[i]=1ll*(2-1ll*c[i]*b[i]%p+p)%p*b[i]%p;
    ntt(b,sum,-1);
    for(ll i=deg;i<sum;i++) b[i]=0;
}
int main()
{
    n=read();
    for(ll i=0;i<n;i++) a[i]=read();
    work(n,a,b);
    for(ll i=0;i<n;i++) printf("%lld ",b[i]); 
    return 0;
}
原文地址:https://www.cnblogs.com/1436177712qqcom/p/10473455.html