[清华集训2017]生成树计数

代码:

#include <bits/stdc++.h>
using namespace std;
#define rep(i,h,t) for (int i=h;i<=t;i++)
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define IL inline
#define rint register int
// Reads the next decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; any '-' among them
// makes the result negative. Returns 0 if the input ends before a digit is
// seen. The character is kept in an int, so EOF is detected and isdigit()
// only ever receives unsigned-char values or EOF.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();
    while (c!=EOF && !isdigit(c)) { if (c=='-') neg=true; c=getchar(); }
    while (c!=EOF && isdigit(c)) { x=x*10+(c-'0'); c=getchar(); }
    return neg?-x:x;
}
// 16 MiB stdin buffer for gc(): A is the read cursor, B is one past the last buffered byte.
char ss[1<<24],*A=ss,*B=ss;
// Returns the next byte of stdin as an unsigned char value (0..255), or EOF
// once fread() delivers no more data. Returning int rather than char keeps
// EOF distinct from a 0xFF byte, independent of char's signedness.
inline int gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,sizeof ss,stdin);
        if (A==B) return EOF;
    }
    return static_cast<unsigned char>(*A++);
}
// In-place running maximum: replaces x with y only when y is strictly greater.
template<class T>void maxa(T &x,T y)
{
    if (!(y > x)) return;   // x already at least as large: leave it untouched
    x = y;
}
// In-place running minimum: replaces x with y only when y is strictly smaller.
template<class T>void mina(T &x,T y)
{
    if (!(y < x)) return;   // x already at most as large: leave it untouched
    x = y;
}
// Reads the next decimal integer from the gc() byte stream into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If the input ends before any digit appears, x is set
// to 0 and the function returns instead of spinning forever on EOF.
// NOTE(review): EOF detection relies on gc() yielding the int value EOF; a
// char-returning gc() on an unsigned-char platform would produce 255 instead.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<'0'||c>'9')
    {
        if (c==EOF) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);                                   // c ^ 48 == c - '0' for digits
    while (c=gc(),c>='0'&&c<='9') x=x*10+(c^48);
    x*=f;
}
const int mo=998244353;
// Modular exponentiation: returns x^y mod mo, always in [0, mo).
// x may be any int (it is reduced and made non-negative first). y == 0 gives 1,
// and any y <= 0 also returns 1 (no inverse is taken); inverses are obtained
// by callers as fsp(v, mo - 2). Iterative square-and-multiply, O(log y).
long long fsp(int x,int y)
{
    long long base=((long long)x%mo+mo)%mo;
    long long result=1;
    for (;y>0;y>>=1)
    {
        if (y&1) result=result*base%mo;
        base=base*base%mo;
    }
    return result;
}
// 2-D point / vector with 64-bit integer coordinates (not referenced by the
// code shown here). All operations are const so they work on const objects.
struct cp {
    long long x,y;
    // Component-wise sum.
    cp operator +(cp B) const
    {
        return cp{x+B.x,y+B.y};
    }
    // Component-wise difference.
    cp operator -(cp B) const
    {
        return cp{x-B.x,y-B.y};
    }
    // 2-D cross product: x*B.y - y*B.x.
    long long operator *(cp B) const
    {
        return x*B.y-y*B.x;
    }
    // 1 when y < 0, or y == 0 and x < 0; otherwise 0.
    // Presumably a half-plane key for angular sorting — confirm with any caller.
    int half() const { return y < 0 || (y == 0 && x < 0); }
};
// Plain aggregate of three ints (a, b, c); not referenced by the code shown here.
struct re {
    int a;
    int b;
    int c;
};
// Capacity of every fixed-size coefficient array in this file.
const int N = 600000;
// Generator used by the NTT (3 is a primitive root of 998244353).
const int G = 3;
// f and g are only used by the commented-out solver; the global n is not
// otherwise read in the code shown (fft and main declare their own n).
int f[N], g[N], n;
// Polynomial arithmetic over Z/mo built on an iterative number-theoretic transform.
//
// Conventions shared by the routines below:
//   * a polynomial is a plain int array; `len` is the number of coefficients
//     that matter, and results are truncated mod x^len unless noted;
//   * a and b are shared scratch buffers; every routine that transforms them
//     calls clear() afterwards (zeroing a[0..n], b[0..n]). getcj(C, len)
//     depends on that, because its caller preloads a and b directly;
//   * C and D are scratch buffers used by getln.
struct fft{
  int l,n,m;                    // l = log2(n); n = transform size; m = size requested by the caller
  int r[N],a[N],b[N],w[N],inv[N];
  int C[N],D[N];
  // Fills inv[i] = i^{-1} mod mo for 1 <= i < N with the linear recurrence
  // inv[i] = -(mo / i) * inv[mo % i].
  fft()
  {
    inv[0]=inv[1]=1;
    rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo;
  }
  // Sets n to the smallest power of two strictly greater than m (l = log2 n)
  // and builds the bit-reversal permutation r[0..n-1]. Requires m >= 1.
  IL void ntt_init()
  {
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
  }
  // Zeroes the scratch arrays a and b over [0, n].
  IL void clear()
  {
      rep(i,0,n) a[i]=b[i]=0;
  }
  // In-place length-n NTT of a; entries must lie in [0, mo) and stay there.
  // o == 1: forward transform; o == -1: inverse transform including 1/n scaling.
  void ntt(int *a,int o)
  {
      for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
      for (int i=1;i<n;i<<=1)
      {
          int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
          rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
          for (int j=0;j<n;j+=(i*2))
            for (int k=0;k<i;k++)
            {
                int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
                // x, y are in [0, mo); a sum equal to mo must also be reduced,
                // hence >= (a plain > would leave the value mo in the array).
                a[j+k]=x+y>=mo?x+y-mo:x+y;
                a[i+j+k]=x-y>=0?x-y:x-y+mo;
            }
      }
      if (o==-1)
      {
          // Inverse = forward transform with outputs 1..n-1 reversed, scaled by 1/n.
          reverse(&a[1],&a[n]);
          for (int i=0,inv=fsp(n,mo-2);i<n;i++)
             a[i]=1ll*a[i]*inv%mo;
      }
  }
  // Multiplies the polynomials the caller preloaded into scratch a and b (each
  // with fewer than len... coefficients up to index len-1 at most, every other entry
  // zero) and writes the whole transform-size product to C[0..n-1]; C needs
  // room for n entries (n > 2*len).
  IL void getcj(int *C,int len)
  {
      m=len*2; ntt_init();
      rep(i,0,n) a[i]=(a[i]%mo+mo)%mo,b[i]=(b[i]%mo+mo)%mo;
      ntt(a,1); ntt(b,1);
      for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
      ntt(a,-1);
      for (int i=0;i<n;i++) C[i]=a[i];
      clear();
  }
  // B = A * B mod x^len (reads A[0..len-1] and B[0..len-1], overwrites B[0..len-1]).
  IL void getcj(int *A,int *B,int len)
  {
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
    ntt(a,1); ntt(b,1);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=a[i];
    clear();
  }
  // B = A^{-1} mod x^len by Newton doubling, B <- B * (2 - A * B).
  // Requires A[0] invertible mod mo; writes B[0..len-1]. Whatever the caller
  // left in B beyond the part already computed is ignored.
  IL void getinv(int *A,int *B,int len)
  {
    if (len==1) { B[0]=fsp(A[0],mo-2); return; }
    const int half=(len+1)>>1;
    getinv(A,B,half);
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo;
    // Only B[0..half-1] is valid. Using just that part keeps deg(A*B*B) <= 2*len-2 < n,
    // so the cyclic product cannot wrap around onto the low coefficients.
    for (int i=0;i<half;i++) b[i]=B[i];
    for (int i=half;i<len;i++) b[i]=0;
    ntt(a,1); ntt(b,1);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++)
    {
        const int prev=i<half?B[i]:0;   // stale B[half..len-1] must not enter 2*B
        B[i]=((2*prev-a[i])%mo+mo)%mo;
    }
    clear();
  }
  // B = sqrt(A) mod x^len by Newton iteration B <- (B + A / B) / 2.
  // NOTE(review): the base case takes the real sqrt of A[0], which equals the
  // modular square root only when A[0] is a perfect square (e.g. A[0] == 1).
  IL void getsqrt(int *A,int *B,int len)
  {
    int inv2=fsp(2,mo-2);
    if (len==1) {B[0]=sqrt(A[0]); return;}
    getsqrt(A,B,(len+1)>>1);
    // Zero-initialised heap buffer of len ints instead of an N-int
    // (about 2.4 MB) array on the stack at every recursion level.
    std::vector<int> Cv(len);
    getinv(B,Cv.data(),len);
    getcj(A,Cv.data(),len);
    for (int i=0;i<len;i++) B[i]=1ll*(Cv[i]+B[i])%mo*inv2%mo;
  }
  // b[0..len-2] = derivative of a[0..len-1]; b[len-1] = 0.
  // a and b may alias: a[i] is read before index i is written.
  IL void getDao(int *a,int *b,int len)
  {
    for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo;
    b[len-1]=0;
  }
  // b[1..len] = integral of a[0..len-1] (b[i+1] = a[i] / (i+1)), b[0] = 0.
  // Writes len+1 entries; a and b must not alias.
  IL void getjf(int *a,int *b,int len)
  {
    for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
    b[0]=0;
  }
  // B = ln(A) mod x^len, computed as the integral of A' / A. B[0] is set to 0,
  // so the result is meaningful when A[0] == 1. Also writes B[len] (getjf's
  // extra top coefficient), so B needs len+1 slots. A and B may alias: A is
  // fully read before B is written. Member scratch C, D is zeroed on return.
  IL void getln(int *A,int *B,int len)
  {
    getDao(A,C,len);
    getinv(A,D,len);
    getcj(C,D,len);
    getjf(D,B,len);
    rep(i,0,len) C[i]=0,D[i]=0;
  }
  // B = exp(A) mod x^len by Newton iteration B <- B * (1 - ln B + A), with
  // B[0] = 1 (so A[0] is expected to be 0). Writes B[0..len-1].
  IL void getexp(int *A,int *B,int len)
  {
    if (len==1) {B[0]=1; return;}
    getexp(A,B,(len+1)>>1);
    // Heap buffer sized len+1 (getln also writes index len) instead of an
    // uninitialised N-int (about 2.4 MB) stack array at every recursion level,
    // which could exceed a typical 8 MB default stack for large len.
    std::vector<int> T(len+1);
    getln(B,T.data(),len);
    for(int i=0;i<len;i++) T[i]=((-T[i]+A[i])%mo+mo)%mo;
    T[0]=(T[0]+1)%mo;
    getcj(T.data(),B,len);
  }
}F;
/*

f[i]=sum f[j]*g[i-j]; 

*/
/*
int now[N];
void solve(int h,int t)
{
  if (h>=t) return; 
  if (t-h<=32)
  {
      rep(i,h,t)
        rep(j,h,i)
          f[i]=(f[i]+1ll*f[j]*g[i-j])%mo;
    return;
  }
  int mid=(h+t)/2;
  solve(h,mid);
  rep(i,h,mid) F.a[i-h]=f[i];
  rep(i,1,t-h) F.b[i]=g[i];
  F.getcj(now,(t-h+1)+(mid-h+1));
  rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo;
  solve(mid+1,t);
}
*/
// Global work arrays for solve() and main(): `now` receives products from
// F.getcj; sum and a..e hold polynomial coefficients; jc[i] = i! and
// jc2[i] = (i!)^{-1} mod mo once main() fills them.
int sum[N], now[N];
int a[N], b[N], c[N], d[N], e[N];
long long jc[N], jc2[N];
/*
 * Expands prod_{i=h..t} (1 + a[i] x) in place: on return a[h+k-1] holds the
 * coefficient of x^k for k = 1..t-h+1 (the constant term 1 is implicit).
 * Divide and conquer: both halves are expanded recursively, then multiplied
 * with F.getcj through the scratch arrays F.a / F.b and the global buffer
 * `now`, which this function overwrites.
 */
void solve(int h,int t,int *a)
{
    // A single factor is already in expanded form. An empty range (h > t) has
    // product 1 and nothing to store; >= also keeps solve(h, h-1) from
    // recursing forever.
    if (h>=t) return;
    const int mid=(h+t)/2;
    solve(h,mid,a);
    solve(mid+1,t,a);
    const int leftDeg=mid-h+1, rightDeg=t-mid;   // leftDeg >= rightDeg
    F.a[0]=F.b[0]=1;
    for (int k=1;k<=leftDeg;k++) F.a[k]=a[h+k-1];
    for (int k=1;k<=rightDeg;k++) F.b[k]=a[mid+k];
    F.getcj(now,leftDeg+1);
    for (int k=1;k<=leftDeg+rightDeg;k++) a[h+k-1]=now[k];
}
// Extra work arrays; not referenced by the code shown here.
int sum3[N];
int sum4[N];
/*
 * Reads n, m and a[1..n] from stdin and prints the result mod mo.
 * As computed below:
 *   ans    = a[1] * ... * a[n]
 *   sum[k] = a[1]^k + ... + a[n]^k for k >= 1 (from ln of prod (1 - a[i] x),
 *            differentiated and negated), sum[0] = n
 *   a(x)   = sum_i (i+1)^{2m} x^i / i!,  b(x) = c(x) = sum_i (i+1)^m x^i / i!
 *   c      = exp( sum_k sum[k] * [x^k] ln c(x) * x^k )
 *   d      = c * sum_k sum[k] * [x^k](a / b) * x^k
 *   result = ans * d[n-2] * (n-2)!
 */
int main()
{
   ios::sync_with_stdio(false);
   cin.tie(nullptr);
   int n,m;
   cin>>n>>m;
   ll ans=1;
   rep(i,1,n) cin>>a[i],ans=ans*((a[i]%mo+mo)%mo)%mo;
   // The final step reads d[n-2] and jc[n-2], i.e. index -1 when n == 1.
   // The x^{-1} coefficient of a power series is 0, so print 0 directly.
   // NOTE(review): confirm 0 is the expected output for n == 1 (notably m == 0).
   if (n<=1) { cout<<0<<endl; return 0; }
   // sum[i] = -a[i] mod mo, so solve() expands prod (1 - a[i] x).
   rep(i,1,n) sum[i]=(mo-a[i]%mo)%mo;
   solve(1,n,sum);
   sum[0]=1;
   F.getln(sum,sum,n+2);
   F.getDao(sum,sum,n+2);
   dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo;
   sum[0]=n;
   jc[0]=jc2[0]=1;
   rep(i,1,n) jc[i]=jc[i-1]*i%mo;
   jc2[n]=fsp(jc[n],mo-2);
   dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo;
   rep(i,0,n) a[i]=b[i]=c[i]=0;
   rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo;
   rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo;
   F.getln(c,e,n+1);
   rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo;
   F.getexp(e,c,n+1);
   F.getinv(b,d,n+1);
   F.getcj(a,d,n+1);
   rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo;
   F.getcj(c,d,n+1);
   ans=ans*d[n-2]%mo*jc[n-2]%mo;
   cout<<ans<<endl;
   return 0;
}
View Code
原文地址:https://www.cnblogs.com/yinwuxiao/p/15143666.html