[PKUSC2018]最大前缀和

题目

只会(O(2^nn^2))的暴力子集卷积啊

首先第一反应是算贡献,我们先求出每一个子集的子集和(sum_i),之后考虑(i)这个子集在多少种排列中成为了最大前缀和

由于一个排列的最大前缀和可能有好几个,于是我们强行规定最大前缀和为最大且出现位置最靠前的前缀和

如果我们能求出一个(dp_i)表示(i)这个集合有多少种排列使得(i)就是最大前缀和,我们只需要让剩下的数组成的排列在任何时候前缀和都不大于0就好了

(f_i)表示(i)这个集合有多少种排列在任何时刻前缀和都不大于(0),这样的话答案就是(sum_{isubset S }sum_i imes dp_i imes f_{Sigoplus i})

这个(f)(O(n2^n)) 的时间内就很容易求出来,现在的问题就是求出(dp)

随便胡了一个子集卷积的做法发现会算重,于是我们考虑正难则反,我们算一下(i)有多少个排列使得最大前缀和不是(i),之后那(|i|!)一减就好了

显然我们可以枚举一个子集(t)成为最大前缀和,之后让后面不能有更大的前缀和就好了,于是让剩下的数排列在任何时候前缀和不大于(0)就好了

于是(dp_i=|i|!-sum_{tsubset i}dp_t imes f_{tigoplus i})

显然这个是一个子集卷积的形式我们可以强行上(fwt)优化成(O(2^nn^2)),由于我们中间过程还要卷回来,所以常数很大,卡卡常就好了

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
const int mod=998244353;
const int maxn=(1<<20)+5;
int sum[maxn];int a[maxn];
int g[21][maxn],dp[21][maxn],cnt[maxn];
int n,len,fac[21],st[21][maxn>>1],tp[21];
inline int qm(int a) {return a>=mod?a-mod:a;}
inline int sqm(int a) {return a<0?a+mod:a;}
inline void Fwt(int *f) {
    for(re int ln=1,i=2;i<=len;i<<=1,ln=i>>1)
        for(re int l=0;l<len;l+=i)
            for(re int x=l;x<l+ln;++x)
                f[x+ln]=qm(f[x+ln]+f[x]); 
}
inline void Ifwt(int *f) {
	for(re int ln=1,i=2;i<=len;i<<=1,ln=i>>1)
		for(re int l=0;l<len;l+=i)
			for(re int x=l;x<l+ln;++x)
				f[x+ln]=sqm(f[x+ln]-f[x]);
}
int main() {
    scanf("%d",&n);len=(1<<n);fac[0]=1;
    for(re int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
    for(re int i=0;i<n;i++) scanf("%d",&a[i]);
    for(re int i=1;i<len;i++) 
        for(re int j=0;j<n;j++)
            if((1<<j)&i) sum[i]=(a[j]+sum[i])%mod;
    for(re int i=1;i<len;i++) cnt[i]=cnt[i>>1]+(i&1);
    for(re int i=1;i<len;i++) st[cnt[i]][++tp[cnt[i]]]=i;
    g[0][0]=1;
    for(re int i=0;i<len;i++) {
        if(sum[i]>0) continue;
        for(re int j=0;j<n;j++) {
            if(i&(1<<j)) continue;
            if(sum[i|(1<<j)]<=0) 
                g[cnt[i]+1][i|(1<<j)]=qm(g[cnt[i]+1][i|(1<<j)]+g[cnt[i]][i]);
        }
    }
    for(re int i=0;i<=n;i++) Fwt(g[i]);
    for(re int i=0;i<n;i++) dp[1][1<<i]=1;
    Fwt(dp[1]);
    for(re int i=2;i<=n;i++) {
        for(re int j=1;j<i;j++)
            for(re int k=0;k<len;k++)
                dp[i][k]=qm(dp[i][k]+1ll*g[j][k]*dp[i-j][k]%mod);
        Ifwt(dp[i]);
        for(re int j=1;j<=tp[i];j++) dp[i][st[i][j]]=sqm(fac[i]-dp[i][st[i][j]]);
        Fwt(dp[i]);
    }
    for(re int i=1;i<=n;i++) Ifwt(dp[i]);
    for(re int i=1;i<len;i++) if(sum[i]<0) sum[i]=(sum[i]+mod)%mod;
    int ans=0;len--;
    for(re int i=1;i<=len;i++)
        ans=qm(ans+1ll*sum[i]*dp[cnt[i]][i]%mod*g[n-cnt[i]][len^i]%mod);
    printf("%d
",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/asuldb/p/10970038.html