codeforces1096G Lucky Tickets

题目链接:https://codeforces.com/problemset/problem/1096/G

大意:给出(k)个数码(d_1,d_2,cdots,d_k),构造一个由这(k)个数码组成的(n)位数(可重复使用数码),使得该数的前(frac{n}{2})位数码之和等于后(frac{n}{2})位数码之和,求方案数

分析:转化一下题意就是说构造(frac{n}{2})位数,求构成数的各位数字之和的方案数,最后乘法原理乘一下即可

如果满足(nleq1000)的话跑完全背包即可,然而数据放到了(2·10^5)

我们考虑一下如下的生成函数(其实不能称作标准的生成函数):((x^{d_1}+x^{d_2}+cdots+x^{d_k})^{frac{n}{2}})

我们将它展开,每一项的次数就是表示出来的数的各位上的数码之和,系数就表示方案数(不就是完全背包么

由于项数最大也就是(2·10^6),直接NTT+快速幂即可

#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
using namespace std;
#define rep(i,a,b) for (i=a;i<=b;i++)
typedef long long ll;
#define maxd 998244353
#define pi acos(-1.0)
#define N 2000000
#define int long long 
ll a[5004000],b[5000400];
int n,r[5004000],k,lim=1,cnt=0;

int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

int qpow(int x,int y)
{
    int ans=1,sum=x;
    while (y)
    {
        int tmp=y%2;y/=2;
        if (tmp) ans=(1ll*ans*sum)%maxd;
        sum=(1ll*sum*sum)%maxd;
    }
    return ans;
}

void ntt(int lim,ll *a,int typ)
{
    int i;
    for (i=0;i<lim;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    int mid;
    for (mid=1;mid<lim;mid<<=1)
    {
        int gn=qpow(3,(maxd-1)/(mid<<1));
        int sta,len=mid<<1,j;
        for (sta=0;sta<lim;sta+=len)
        {
            int g=1;
            for (j=0;j<mid;j++,g=(g*gn)%maxd)
            {
                int x1=a[j+sta],y1=(g*a[j+sta+mid])%maxd;
                a[j+sta]=(x1+y1)%maxd;
                a[j+sta+mid]=(x1-y1+maxd)%maxd;
            }
        }
    }
    if (typ==-1) reverse(&a[1],&a[lim]);
}

void init()
{
    n=read();k=read();
    //memset(a,0,sizeof(a));
    int i;n/=2;
    for (i=1;i<=k;i++)
        {int x=read();a[x]=1;}
    while (lim<=N) {lim<<=1;cnt++;}
    for (i=0;i<=lim;i++)
        r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
}

void work()
{
    ntt(lim,a,1);int i;
    for (i=0;i<lim;i++) a[i]=qpow(a[i],n);
    ntt(lim,a,-1);
    ll ans=0,tmp=qpow(lim,maxd-2);
    for (i=0;i<=N;i++) 
    {
        a[i]=(a[i]*tmp)%maxd;
        ans=(ans+1ll*a[i]*a[i])%maxd;
    }
    printf("%lld",ans);
}

signed main()
{
    init();
    work();
    return 0;
}
原文地址:https://www.cnblogs.com/encodetalker/p/10299366.html