HDU 5829 Rikka with Subset

快速数论变换ntt。

早上才刚刚接触了一下FFT,然后就开始撸这题了,所以要详细地记录一下。

看了这篇巨巨的博客才慢慢领会的:http://blog.csdn.net/cqu_hyx/article/details/52194696

FFT的作用是计算卷积。可以简单的理解为计算多项式*多项式最后得到的多项式,暴力计算是O(n*n)的,FFT可以做到O(nlogn)。

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<iostream>
using namespace std;
typedef long long LL;
const double pi=acos(-1.0),eps=1e-8;
void File()
{
    freopen("D:\in.txt","r",stdin);
    freopen("D:\out.txt","w",stdout);
}
template <class T>
inline void read(T &x)
{
    char c = getchar(); x = 0;while(!isdigit(c)) c = getchar();
    while(isdigit(c)) { x = x * 10 + c - '0'; c = getchar();  }
}

const int maxn=300005;
const LL mod=998244353;
const LL G=3;

LL t[maxn],a[maxn],b[maxn],c[maxn],f[maxn],fac[maxn],NI[maxn];
int T,n,m;
LL rev[maxn],N,len,inv;

LL POW[maxn],NiPOW[maxn];

LL power(LL x,LL y)
{
    LL res=1;
    for(;y;y>>=1,x=(x*x)%mod)
    {
        if(y&1)res=(res*x)%mod;
    }
    return res;
}

void init()
{
    while((n+m)>=(1<<len))len++;
    N=(1<<len);
    inv=power(N,mod-2);
    for(int i=0;i<N;i++)
    {
        LL pos=0;
        LL temp=i;
        for(int j=1;j<=len;j++)
        {
            pos<<=1;pos |= temp&1;temp>>=1;
        }
        rev[i]=pos;
    }
}

void ntt(LL *a,LL n,LL re)
{
    for(int i=0;i<n;i++)
    {
        if(rev[i]>i)
        {
            swap(a[i],a[rev[i]]);
        }
    }
    for(int i=2;i<=n;i<<=1)
    {
        int mid=i>>1;

        LL wn=power(G,(mod-1)/i);
        if(re) wn=power(wn,(mod-2));
        for(int j=0;j<n;j+=i)
        {
            LL w=1;
            for(int k=0;k<mid;k++)
            {
                int temp1=a[j+k];
                int temp2=(LL)a[j+k+mid]*w%mod;
                a[j+k]=(temp1+temp2);if(a[j+k]>=mod)a[j+k]-=mod;
                a[j+k+mid]=(temp1-temp2);if(a[j+k+mid]<0)a[j+k+mid]+=mod;
                w=(LL)w*wn%mod;
            }
        }
    }
    if(re)
    {
        for(int i=0;i<n;i++)
        {
            a[i]=(LL)a[i]*inv%mod;
        }
    }
}

bool cmp(LL a,LL b) {return a>b;}

LL extend_gcd(LL a,LL b,LL &x,LL &y)
{
    if(a==0&&b==0) return -1;
    if(b==0){x=1;y=0;return a;}
    LL d=extend_gcd(b,a%b,y,x);
    y-=a/b*x;
    return d;
}

LL mod_reverse(LL a,LL n)
{
    LL x,y;
    LL d=extend_gcd(a,n,x,y);
    if(d==1) return (x%n+n)%n;
    else return -1;
}

int main()
{
    fac[0]=1; for(int i=1;i<=100000;i++) fac[i]=(LL)i*fac[i-1]%mod;
    for(int i=0;i<=100000;i++) NI[i]=mod_reverse(fac[i],mod);
    POW[0]=1; for(int i=1;i<=100000;i++) POW[i]=(LL)2*POW[i-1]%mod;
    for(int i=0;i<=100000;i++) NiPOW[i]=mod_reverse(POW[i],mod);

    scanf("%d",&T); while(T--)
    {
        len=0;  memset(c,0,sizeof c); memset(a,0,sizeof a); memset(b,0,sizeof b);

        scanf("%d",&n); m=n;
        for(int i=1;i<=n;i++) { int x; scanf("%d",&x); t[i]=(LL)x; }

        sort(t+1,t+1+n,cmp);
        for(int i=0;i<n;i++)
        {
            LL x=fac[n]*NI[i]%mod;
            a[i]=x*POW[n-i]%mod;
        }
        for(int i=1;i<=n;i++) b[n-i]=t[i]*fac[i-1]%mod;

        init(); ntt(a,N,0); ntt(b,N,0);
        for(int i=0;i<=N;i++) c[i]=a[i]*b[i]%mod;
        ntt(c,N,1);

        for(int i=0;i<n;i++) f[n-i]=c[i]*NI[n]%mod;
        for(int i=1;i<=n;i++) f[i]=f[i]*NI[i-1]%mod;
        for(int i=1;i<=n;i++) f[i]=f[i]*NiPOW[i]%mod;
        LL ans=0; for(int i=1;i<=n;i++) { ans=(ans+f[i])%mod; printf("%lld ",ans); }
        printf("
");
    }
    return 0;
}
原文地址:https://www.cnblogs.com/zufezzt/p/5768565.html