洛谷 5205 【模板】多项式开根

题目:https://www.luogu.org/problemnew/show/P5205

不会二次剩余。

牛顿迭代推开根式子:

( f^2(x)-g(x)=0 )

( f(x)=f_0(x)-frac{ f_0^2(x)-g(x) }{ ( f_0^2(x)-g(x) )' } = frac{ f_0^2(x)-g(x) }{ 2f_0(x) } )

实现的时候形如 ( f(x)=frac{ f_0(x)+frac{ g(x) }{ f_0(x) } }{2} )

用的 vector 。慢了很多。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
#define vi vector<int>
#define pb push_back
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=(1<<18)+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n; vi f;
namespace Pl{
  int len,r[N];
  vi ntt(vi a,bool fx)
  {
    for(int i=0;i<len;i++)
      if(i<r[i])swap(a[i],a[r[i]]);
    for(int R=2;R<=len;R<<=1)
      {
    int wn=pw(3,fx?(mod-1)-(mod-1)/R:(mod-1)/R);
    for(int i=0,m=R>>1;i<len;i+=R)
      for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod)
        {
          int x=a[i+j], y=(ll)w*a[i+m+j]%mod;
          a[i+j]=upt(x+y); a[i+m+j]=upt(x-y);
        }
      }
    if(!fx)return a; int inv=pw(len,mod-2);
    for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod;
    return a;
  }
  vi inv(vi f,int n)
  {
    int tp; for(tp=1;tp<n;tp<<=1); tp<<=1;
    vi a,b; a.resize(tp); b.resize(tp);
    b[0]=pw(f[0],mod-2);
    for(int t=2,yt=1;yt<n;yt=t,t=len)
      {
    len=t<<1;
    for(int i=0,j=len>>1;i<len;i++)
      r[i]=(r[i>>1]>>1)+((i&1)?j:0);
    for(int i=0;i<t;i++)a[i]=f[i];
    a=ntt(a,0); b=ntt(b,0);
    for(int i=0;i<len;i++)
      b[i]=upt((ll)b[i]*(2-(ll)a[i]*b[i]%mod)%mod);
    b=ntt(b,1);
    for(int i=t;i<len;i++)a[i]=0;
    for(int i=t;i<len;i++)b[i]=0;
      }
    for(int i=n;i<len;i++)b[i]=0;//
    return b;
  }
  vi sqr(vi f,int n)
  {
    int iv2=pw(2,mod-2);
    int tp; for(tp=1;tp<n;tp<<=1); tp<<=1;
    vi a,b,c; a.resize(tp); b.resize(tp);
    b[0]=1;
    for(int t=2,yt=1;yt<n;yt=t,t=len)
      {
    for(int i=0;i<t;i++)a[i]=f[i];
    c=inv(b,t);
    len=t<<1;  c.resize(len);//resize
    for(int i=0,j=len>>1;i<len;i++)
      r[i]=(r[i>>1]>>1)+((i&1)?j:0);
    a=ntt(a,0); c=ntt(c,0);
    for(int i=0;i<len;i++)
      a[i]=(ll)a[i]*c[i]%mod;
    a=ntt(a,1);
    for(int i=0;i<t;i++)b[i]=(ll)(b[i]+a[i])*iv2%mod;
    for(int i=t;i<len;i++)a[i]=0;
      }
    return b;
  }
}
int main()
{
  n=rdn()-1; int tp;
  for(tp=1;tp<=n;tp<<=1); tp<<=1;
  f.resize(tp);
  for(int i=0;i<=n;i++)f[i]=rdn();
  f=Pl::sqr(f,n+1);
  for(int i=0;i<=n;i++)printf("%d ",f[i]);puts("");
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10658359.html