NTT 练习

一 . Rikka with Subset 

题目: http://acm.hdu.edu.cn/showproblem.php?pid=5829

  

 

参考   https://blog.csdn.net/hdxrie/article/details/80961416?utm_source=blogxgwz3

#include <iostream>
#include <cstdio>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <deque>
#include <vector>
#include <queue>
#include <string>
#include <cstring>
#include <map>
#include <stack>
#include <set>
#define LL long long
#define ULL unsigned long long
#define rep(i,j,k) for(int i=j;i<=k;i++)
#define dep(i,j,k) for(int i=k;i>=j;i--)
#define INF 0x3f3f3f3f
#define mem(i,j) memset(i,j,sizeof(i))
#define make(i,j) make_pair(i,j)
#define pb push_back
using namespace std;
// p: modulus for all arithmetic (treated as prime: inverses are taken as x^(p-2));
// N: capacity of every global array; G: root used by the forward NTT;
// Gi: inverse of G modulo p (3 * 332748118 == p + 1).
const LL p = 998244353, N = (1 << 18) + 5, G = 3, Gi = 332748118;
// Binary exponentiation: returns a^b modulo the file-level modulus p
// (intended for b >= 0; a should already lie in [0, p)).
LL ksm(LL a,LL b) {
    LL result = 1;
    for (; b != 0; b >>= 1, a = a * a % p) {
        if (b & 1) result = result * a % p;
    }
    return result;
}
// n: size of the current test case (m is not referenced in the code shown).
int n, m;
// limit: NTT length (smallest power of two > 2n, set in main); a, b: polynomials
// being multiplied; r: bit-reversal permutation for `limit`; l: log2(limit).
LL limit, a[N], b[N], r[N], l;
// Iterative radix-2 number-theoretic transform of A[0, limit) modulo p.
//   type ==  1 : forward transform built from G;
//   type == -1 : inverse transform built from Gi, after which every entry is
//                multiplied by limit^{-1}, so NTT(A, 1) then NTT(A, -1) restores A.
// Uses the globals `limit` (a power of two) and `r[]` (bit-reversal permutation
// for that length); the caller must set both first. Coefficients are expected
// to lie in [0, p).
void NTT(LL *A, int type) {
    for (int i = 0; i < limit; ++i)
        if (i < r[i]) swap(A[i], A[r[i]]);
    const LL root = (type == 1) ? G : Gi;
    for (int mid = 1; mid < limit; mid <<= 1) {
        const int len = mid << 1;
        // len-th root of unity modulo p (assuming G is a primitive root of p).
        const LL Wn = ksm(root, (p - 1) / len);
        for (int j = 0; j < limit; j += len) {
            LL w = 1;
            for (int k = 0; k < mid; ++k, w = w * Wn % p) {
                // Butterfly operands stay LL: narrowing them to int would
                // silently truncate any value outside the int range.
                const LL x = A[j + k];
                const LL y = w * A[j + k + mid] % p;
                A[j + k] = (x + y) % p;
                A[j + k + mid] = (x - y + p) % p;
            }
        }
    }
    if (type == -1) {
        const LL invLimit = ksm(limit, p - 2);
        for (int i = 0; i < limit; ++i) A[i] = A[i] * invLimit % p;
    }
}
// Modular inverses, factorials and inverse factorials modulo p; filled by init().
LL inv[N], fac[N], invfac[N];
void init() {
    invfac[0] = fac[0] = inv[1] = fac[1] = invfac[1] = 1LL;
    rep(i, 2, N - 1) {
        fac[i]=(fac[i-1]*i)%p;
        inv[i] = (p - p / i) * inv[p % i] % p;
        invfac[i] = (invfac[i - 1] * inv[i]) % p;
    }
}
// Input values of the current test case (main sorts them in descending order).
LL A[N];
int main() {
    init();
    int t;
    scanf("%d", &t);
    while( t-- ) {
        scanf("%d", &n);
        mem(a, 0); mem(b, 0);
        for (limit = 1, l = 0; limit <= (n << 1); l++, limit <<= 1);
        rep(i, 0, limit - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        rep(i, 0, n - 1) scanf("%lld", &A[i]);
        sort(A, A + n,greater<LL>());
        rep(i, 0, n - 1) a[i] = ksm(2, n - i) * invfac[i] % p;
        rep(i, 0, n - 1) b[i] = fac[i] * A[i] % p;
        reverse(b, b + n);
        NTT(a, 1); NTT(b, 1);
        rep(i, 0, limit - 1) a[i] = a[i] * b[i] % p;
        NTT(a, -1);
        LL ans = 0LL, preans = 0LL; LL coe = inv[2];
        rep(i, 1, n) {
            ans = coe * invfac[i-1] % p * a[n-i] % p;
            ans = (ans + preans) % p;
            printf("%lld ",ans);
            coe = coe * inv[2] % p; swap(ans, preans);
        }
        puts("");
    }
    return 0;
}
/*#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN (1<<18)+5
#define MOD 998244353LL
#define g 3LL
using namespace std;
int n,m,L,T,A[MAXN],rev[MAXN];
long long inv[MAXN],fac[MAXN],invfac[MAXN];
long long a[MAXN],b[MAXN];
long long ans_i,pre_ans_i,coe;

inline bool cmp(long long a,long long b){return a>b;}

inline long long Quick_MOD(long long a,long long b)
{
    long long res=1,base=a;
    while (b)
    {
        if (b&1) res=(res*base)%MOD;
        base=(base*base)%MOD;
        b>>=1;
    }
    return res;
}

inline void NTT(long long c[],int n,int f)
{
    long long w,wn,x,y;
    for (int i=0;i<n;i++)
        if (i<rev[i]) swap(c[i],c[rev[i]]);
    for (int i=1;i<n;i<<=1)
    {
        wn=Quick_MOD(g,(MOD-1)/(i<<1));
        if (!~f) wn=Quick_MOD(wn,MOD-2);
        for (int p=i<<1,j=0;j<n;j+=p)
        {
            w=1LL;
            for (int k=0;k<i;k++,w=w*wn%MOD)
            {
                x=c[j+k];y=c[j+k+i]*w%MOD;
                c[j+k]=(x+y)%MOD;c[j+k+i]=(x-y+MOD)%MOD;
            }
        }
    }
    if (!~f)
        for (int i=0;i<n;i++) c[i]=c[i]*inv[n]%MOD;
    return ;
}

inline void PreWork()
{
    invfac[0]=fac[0]=inv[1]=fac[1]=invfac[1]=1LL;
    for (int i=2;i<MAXN;i++)
    {
        fac[i]=(fac[i-1]*i)%MOD;
        inv[i]=(MOD-MOD/i)*inv[MOD%i]%MOD;
        invfac[i]=(invfac[i-1]*inv[i])%MOD;
    }
    return ;
}

inline void read(int &x)
{
    x=0;char ch=getchar();
    while (ch<'0'||ch>'9') ch=getchar();
    while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return ;
}

int main()
{
    PreWork();
    read(T);
    while (T--)
    {
        memset(a,0,sizeof a);memset(b,0,sizeof b);
        read(n);
        for (m=1,L=0;m<=(n<<1);L++,m<<=1);
        for (int i=0;i<m;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
        for (int i=0;i<n;i++) read(A[i]);
        sort(A,A+n,cmp);
        for (int i=0;i<n;i++) a[i]=Quick_MOD(2,n-i)*invfac[i]%MOD;
        for (int i=0;i<n;i++) b[i]=fac[i]*(long long)A[i]%MOD;
        reverse(b,b+n);
        NTT(a,m,1);NTT(b,m,1);
        for (int i=0;i<m;i++) a[i]=a[i]*b[i]%MOD;
        NTT(a,m,-1);
        ans_i=pre_ans_i=0LL;coe=inv[2];
        for (int i=1;i<=n;i++)
        {
            ans_i=coe*invfac[i-1]%MOD*a[n-i]%MOD;
            ans_i=(ans_i+pre_ans_i)%MOD;
            printf("%lld ",ans_i);
            coe=coe*inv[2]%MOD;swap(ans_i,pre_ans_i);
        }
        putchar('\n');
    }
    return 0;
}*/
View Code

二 . 序列统计

题目描述

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的、且满足数列中所有数的乘积 mod M 的值等于x的不同数列有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案 mod 1004535809 的值就可以了。

输入输出格式

输入格式:

第一行,四个整数 N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。

输出格式:

一行,一个整数,表示你求出的种类数mod 1004535809的值。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<cctype>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<stack>
#include<queue>
#include<map>
#define size 3000010
#define ll long long
#define db double
#define il inline
#define rint register int
#define gc getchar()
#define rep(i,s,n) for (register int i=s;i<=n;i++)
#define drep(i,n,s) for (register int i=n;i>=s;--i)
#ifdef WIN32
#else
#define ld "%lld"
#endif
#define Mod 1004535809

using namespace std;

// Reads one decimal integer from stdin, skipping any non-digit characters in
// front of it; a '-' anywhere in that skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of spinning).
inline long long r()
{
    // int, not char: keeps EOF distinguishable from real bytes, and every value
    // getchar() returns is a valid argument for isdigit (a negative char is UB).
    int c = getchar();
    long long f = 1;
    while (c != EOF && !isdigit(c))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    long long x = 0;
    while (c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

// F0: polynomial built by main with F0[discrete log of s] = 1 for each nonzero s in S
//     (Ksm later overwrites it with its repeated squares);
// F1: result polynomial accumulated by Ksm; A: scratch buffer for dft's bit-reversal copy;
// po / pv: forward / inverse roots of unity indexed by transform length (filled by Rader).
// c is not referenced in the code shown.
ll F1[size],F0[size],c[size],A[size],po[size],pv[size];
// Lg / L: log2 of and length of the NTT; rev: bit-reversal table;
// n, m, X, S: the inputs N, M, x, |S|; g: generator returned by G(m);
// pg[v]: exponent i with g^i == v (mod m). l1 and l2 are not referenced in the code shown.
int Lg,L,l1,l2,rev[size],n,m,S,X,g,pg[size];

// Binary exponentiation: returns x^y mod `mod` for y >= 0 and mod >= 1.
// The base is first reduced into [0, mod), so squaring cannot overflow
// long long even when x >= mod or x < 0 (requires mod^2 to fit in long long).
inline long long ksm(long long x,long long y,long long mod)
{
    x %= mod;
    if (x < 0) x += mod;
    long long res = 1 % mod;  // 1 % mod makes the result 0 when mod == 1
    while (y)
    {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return res;
}

// Returns the smallest g >= 2 with g^d != 1 (mod s) for every divisor d of s-1
// satisfying 2 <= d <= s-2; when s is prime this is its smallest primitive root.
// Divisors are visited in pairs (d, (s-1)/d) with d <= sqrt(s-1). This covers
// exactly the same divisor set without a fixed-size buffer (the old q[1010]
// could overflow) and in O(sqrt(s)) per candidate instead of an O(s) scan.
// NOTE(review): never returns if no such g exists; s is presumably prime here.
int G(int s)
{
    const int phi = s - 1;
    for (int cand = 2;; cand++)
    {
        bool ok = true;
        for (int d = 2; ok && (long long)d * d <= phi; d++)
        {
            if (phi % d != 0) continue;
            // d and phi / d are both divisors in [2, s-2].
            if (ksm(cand, d, s) == 1 || ksm(cand, phi / d, s) == 1) ok = false;
        }
        if (ok) return cand;
    }
}

void Rader(int tmp)
{
    Lg=0,L=1; while (L<tmp) L<<=1,Lg++; L<<=1,Lg++;
    rep(i,0,L-1)
        for (int t=i,j=1;j<=Lg;j++)
            rev[i]<<=1,rev[i]|=t&1,t>>=1;
    ll I=ksm(3,Mod-2,Mod);
    for (int i=1;i<=L;i<<=1) po[i]=ksm(3,(Mod-1)/i,Mod),pv[i]=ksm(I,(Mod-1)/i,Mod);
}

// In-place NTT of F[0, L) modulo Mod. sgn == 1: forward transform with the
// roots po[]; sgn == -1: inverse transform with pv[], then every entry is
// multiplied by L^{-1}. Relies on L, rev[], po[] and pv[] prepared by Rader();
// the global A[] is used as scratch space for the bit-reversal permutation.
void dft(ll F[],int sgn)
{
    for (int i = 0; i < L; ++i) A[i] = F[rev[i]];
    for (int i = 0; i < L; ++i) F[i] = A[i];
    for (int len = 2; len <= L; len <<= 1)
    {
        const int half = len / 2;
        const ll step = (sgn == -1) ? pv[len] : po[len];
        for (int base = 0; base < L; base += len)
        {
            ll w = 1;
            for (int j = 0; j < half; ++j)
            {
                const ll u = F[base + j];
                const ll v = w * F[base + j + half] % Mod;
                F[base + j] = (u + v) % Mod;
                F[base + j + half] = (u - v + Mod) % Mod;
                w = w * step % Mod;
            }
        }
    }
    if (sgn == -1)
    {
        const ll invL = ksm(L, Mod - 2, Mod);
        for (int i = 0; i < L; ++i) F[i] = F[i] * invL % Mod;
    }
}

void Ksm(int y)
{
    F1[0]=1;
    while (y)
    {
        dft(F0,1);
        if (y&1)
        {
            dft(F1,1); rep(i,0,L-1) F1[i]=(F1[i]*F0[i])%Mod;
            dft(F1,-1);
            drep(i,L-1,m-1)  F1[i-m+1]=(F1[i-m+1]+F1[i])%Mod,F1[i]=0;
        }
        rep(i,0,L-1) F0[i]=(F0[i]*F0[i])%Mod;
        dft(F0,-1);
        drep(i,L-1,m-1)  F0[i-m+1]=(F0[i-m+1]+F0[i])%Mod,F0[i]=0;
        y>>=1;
    }
}

int main()
{
    n=r(); m=r(); X=r(); S=r(); 
    g=G(m);
    ll Q=1,qx;
    rep(i,1,m-2) Q=Q*g%m,pg[Q]=i;
    rep(i,1,S)
    {
        qx=r(); if (qx) F0[pg[qx]]=1;
    }
    Rader(m);
    Ksm(n);printf("%lld
",F1[pg[X]]);
    return 0;
}
View Code
一步一步,永不停息
原文地址:https://www.cnblogs.com/Willems/p/10984929.html