题目大意#
给 (n) 块白木板, 高度为 (a_i), (k) 块红木板, 高度为 (b_i), 问在周长为 (q_i) 的情况下, 有多少种取法使得任取任意数量的白木板和一块红木板组成两个严格递增数列, 要求红木板的高度为最大.
思路#
其实从组合的方向上去看还算简单
先统计每个不同高度的白木板的个数, 记为 (cnt[a_i]).
然后从 (k) 个红木板中选定一块红木板, 高度为 (L).
统计高度比 (L) 小的白木板的个数情况:
如果 (cnt[i]==1), 计入出现一次的个数 (na), 当然 (na++)
如果 (cnt[i]>=2), 计入出现两次的个数 (nb), 当然 (nb+=2), 比 (2) 多直接当 (2) 来看, 根据题意不可能有 (3) 个同样高度的木板出现
如果要从 (na) 中取 (i) 块, 那么形成的方案数是 (a_i=C^{i}_{na}cdot 2^{i}) ((na) 块中取 (i) 块, 每块都可以放在 (2) 个数列中一个)
如果要从 (nb) 中取 (i) 块, 那么形成的方案数是 (b_i=C^{i}_{nb})
那么如果总共要取 (m) 块, 总共就会形成 (c_{m}=sum ^{1n}_{i=0}a_{i}b_{m-i}) 种方案.
(m) 块白板 + (1) 块高 (L) 的红板, 形成的多边形周长就是 (2cdot(L+1+m)) 啦,
跑完每一个红木板, 将周长对应的答案预处理一下, 直接回答每一个问题就行了.
这样就可以写出一个 (O(n^2)) 的算法:肯定会超时的啦
LL calc(int m,int na,int nb){
LL res=0;
for(int i=0;i<=m;i++) (res+=C(na,i)*p2[i]%mod*C(nb,m-i))%=mod;
return res;
}
int main(){
n=read();k=read();
fac[0]=invi[0]=1;invi[1]=1;p2[0]=1;
for(int i=2;i<=n;i++) invi[i]=invi[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=n;i++)
fac[i]=fac[i-1]*i%mod,
p2[i]=p2[i-1]*2%mod,
invi[i]=invi[i]*invi[i-1]%mod;
for(int i=1;i<=n;i++) cnt[read()]++;
for(int i=1;i<=k;i++){
int l=read();
int na=0,nb=0;
for(int i=l-1;i>0;i--)
if(cnt[i]>=2) nb+=2;
else if(cnt[i]==1) na++;
for(int i=0;i<=na+nb;i++)
(ans[(l+1+i)*2]+=calc(i,na,nb))%=mod;
}
q=read();
while(q--) printf("%lld
",ans[read()]);
return 0;
}
(NTT)#
其实, (c_{m}=sum ^{1n}_{i=0}a_{i}b_{m-i}) 这个式子是一个多项式乘法的一部分,
也就是说这其实是一个卷积式.
我们可以通过计算两个多项式的乘积求出所有 (c_i), 令:
那么 (C(x)) 的各项系数就是要求的 (c_m) 了
卷积的快速算法就是 (FFT) 和 (NTT)
我也是为了过这道题大晚上的肝完了这两个算法
这个题要求取模并且这个模数就是 (NTT) 支持的模数
上 (NTT) 就可以过了
完整的 (AC) 代码
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAXN 1200010
using namespace std;
typedef long long LL;
const int mod=998244353;
int n,k,q,len,bit,rev[MAXN];
LL A[MAXN],B[MAXN],cnt[MAXN];
LL fac[MAXN],invi[MAXN],p2[MAXN];
LL ans[MAXN];
int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
return x*f;
}
LL C(int n,int m) {
if(m>n) return 0;
return fac[n]*invi[m]%mod*invi[n-m]%mod;
}
LL qpow(LL x,LL k){
LL res=1;
while(k){
if(k&1) res=res*x%mod;
x=x*x%mod;
k>>=1;
}
return res;
}
void NTT(LL *a,int opt){
for(int i=0;i<len;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<len;mid<<=1){
LL wn=qpow(3,(mod-1)/(mid*2));
if(opt==-1) wn=qpow(wn,mod-2);
for(int i=0;i<len;i+=mid*2){
LL w=1;
for(int j=0;j<mid;j++,w=w*wn%mod){
LL x=a[i+j],y=w*a[i+j+mid]%mod;
a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
}
}
}
}
int main(){
n=read();k=read();
fac[0]=invi[0]=1;invi[1]=1;p2[0]=1;
for(int i=2;i<=n;i++) invi[i]=invi[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=n;i++)
fac[i]=fac[i-1]*i%mod,
p2[i]=p2[i-1]*2%mod,
invi[i]=invi[i]*invi[i-1]%mod;
for(int i=1;i<=n;i++) cnt[read()]++;
for(int i=1;i<=k;i++){
int l=read();
int na=0,nb=0;
memset(A,0,sizeof(A));
memset(B,0,sizeof(B));
for(int i=l-1;i>0;i--)
if(cnt[i]>=2) nb+=2;
else if(cnt[i]==1) na++;
for(int i=0;i<=na;i++) A[i]=C(na,i)*p2[i];
for(int i=0;i<=nb;i++) B[i]=C(nb,i);
len=1,bit=0;
while(len<=na+nb) len<<=1,bit++;
for(int i=0;i<len;i++) rev[i]=((rev[i>>1]>>1) | ((i&1)<<(bit-1)));
NTT(A,1);NTT(B,1);
for(int i=0;i<len;i++) A[i]=A[i]*B[i]%mod;
NTT(A,-1);
LL inv=qpow(len,mod-2);
for(int i=0;i<=na+nb;i++) A[i]=A[i]*inv%mod;
for(int i=0;i<=na+nb;i++)
(ans[(l+i+1)*2]+=A[i])%=mod;
}
q=read();
while(q--) printf("%lld
",ans[read()]);
return 0;
}
最后感谢大佬的题解 http://www.cnblogs.com/NaVi-Awson/