BZOJ3771: Triple

n<=40000个<=40000的不同数字,问选一个或两个或三个,凑成每个值的方案数。

选东西,总数加起来为某值的方案数--生成函数,$f$表示选一个的,$g$表示两个一样的,$h$表示三个一样的(等会去重要用)。

选一个:$f$

选两个:$frac{f^2-g}{2}$

选三个:$frac{f^3-3*(g*f-h)-h}{6}$。

直接用点值算完再换成系数。

 1 //#include<iostream>
 2 #include<cstring>
 3 #include<cstdlib>
 4 #include<cstdio>
 5 //#include<map>
 6 #include<math.h>
 7 //#include<time.h>
 8 //#include<complex>
 9 #include<algorithm>
10 using namespace std;
11  
12 int n,m,wei;
13 #define maxn 300011
14 const int mod=998244353,G=3;
15  
16 int powmod(int a,int b)
17 {
18     int ans=1;
19     while (b)
20     {
21         if (b&1) ans=1ll*ans*a%mod;
22         a=1ll*a*a%mod; b>>=1;
23     }
24     return ans;
25 }
26  
27 int rev[maxn];
28 void dft(int *a,int n,int type)
29 {
30     if (!rev[1]) for (int i=0;i<n;i++)
31     {
32         rev[i]=0;
33         for (int j=0;j<wei;j++) rev[i]|=((i>>j)&1)<<(wei-j-1);
34     }
35     for (int i=0;i<n;i++) if (i<rev[i]) {int t=a[i]; a[i]=a[rev[i]]; a[rev[i]]=t;}
36     for (int i=1;i<n;i<<=1)
37     {
38         int w=powmod(G,(mod-1)/(i<<1));
39         if (type==-1) w=powmod(w,mod-2);
40         for (int j=0,p=i<<1;j<n;j+=p)
41         {
42             int t=1;
43             for (int k=0;k<i;k++,t=1ll*t*w%mod)
44             {
45                 int tmp=1ll*t*a[j+k+i]%mod;
46                 a[j+k+i]=(a[j+k]-tmp+mod)%mod;
47                 a[j+k]=(a[j+k]+tmp)%mod;
48             }
49         }
50     }
51     if (type==-1)
52     {
53         int inv=powmod(n,mod-2);
54         for (int i=0;i<n;i++) a[i]=1ll*a[i]*inv%mod;
55     }
56 }
57  
58 void ntt(int *a,int *b,int *c)
59 {
60     dft(a,n,1); dft(b,n,1);
61     for (int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%mod;
62     dft(c,n,-1); 
63 }
64  
65 int f[maxn],g[maxn],s[maxn],h[maxn],ans[maxn];
66 int main()
67 {
68     scanf("%d",&n); int Max=0;
69     for (int i=1,x;i<=n;i++) scanf("%d",&x),Max=x,f[x]=1,g[x+x]=1,s[x+x+x]=1;
70     m=Max*3; for (n=1,wei=0;n<=m;n<<=1,wei++);
71      
72     dft(f,n,1); dft(g,n,1); dft(s,n,1);
73     for (int i=0,tmp6=((mod+1)/6),tmp2=((mod+1)>>1);i<n;i++)
74     ans[i]=(((1ll*f[i]*f[i]%mod*f[i]%mod-3ll*g[i]*f[i]%mod+2*s[i])%mod*tmp6%mod
75     +(1ll*f[i]*f[i]%mod-g[i])%mod*tmp2%mod+f[i])%mod+mod)%mod;
76     dft(ans,n,-1);
77     for (int i=0;i<=m;i++) if (ans[i]>0) printf("%d %d
",i,ans[i]);
78     return 0;
79 }
View Code
原文地址:https://www.cnblogs.com/Blue233333/p/8476162.html