[CF286E] Ladies' shop

Description

给出 (n)(leq m) 且不同的数 (a_1,dots,a_n),现在要求从这 (n) 个数中选出最少的数字,满足这 (n) 个数字都可以由选出的数字组合成(就是做一个完全背包能做出来),并且任意组合出来的数字,只要不超过 (m),就必须让这个数字在给出的 (n) 个数中。问是否可行,如果可行,请求出最少选多少数字。 (n,mleq 10^6)

Sol

先判断是否可行,再看哪些数可以省略。

求出 (a) 数组的生成函数,即构造多项式 (F(x)=sum f_icdot x^i)(f_i)(1) 当且仅当 (a) 数组中出现 (a_*=i)

然后求出 (G(x)=F^2(i)=sum g_icdot x^i)。如果 (g_i>0) 那就说明给出的这 (n) 个数可以合成 (i)

于是就得到了从原来的 (n) 个数中拿出 (0sim 2) 个的结果。

然而最多拿出 (m) 个。

所以还要继续,用快速幂求得 (f^m)。如果多项式快速幂的话,复杂度 (O(nlog^2n)),用多项式ln+多项式exp求的话,复杂度 (O(nlog n))。但是多项式exp常数太大了!

事实上是有只做 (1) 次FFT的方法的。

显然如果 (f_i>0) 的话,(g_i>0)

那我们只要保证满足 (f_i=0,g_i>0,ileq m)(i) 不存在就好了。

如果第一轮不存在这些不合法的,那接下来肯定也不存在。感性理解一下这就相当于构成了一个封闭的集合。

所以只做 (1) 次FFT就行了。

然后考虑一下哪些数可以省略

如果一个数 (i) 可以被其他数表示出来,那 (g_i) 一定 (>2)。所以 (g_i=2)(i) 就是必选的。

时间复杂度 (O(nlog n))

Sol

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e6+5;
const int mod=998244353;

int lim,rev[N];
int n,m,a[N],b[N];

int ksm(int a,int b=mod-2,int ans=1){
    while(b){
        if(b&1) ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;b>>=1;
    } return ans;
}

int getint(){
    int X=0,w=0;char ch=getchar();
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

void ntt(int *f,int g){
    for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int tmp=ksm(g,(mod-1)/(mid<<1));
        for(int R=mid<<1,j=0;j<lim;j+=R){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
                int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
            }
        }
    } if(g>3) 
        for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}

signed main(){
    n=getint(),m=getint();
    for(int i=1;i<=n;i++){
        int x=getint();
        a[x]=b[x]=1;
    } 
    lim=1;while(lim<=m+m) lim<<=1;
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
    a[0]=1; ntt(a,3);
    for(int i=0;i<lim;i++) a[i]=1ll*a[i]*a[i]%mod;
    ntt(a,(mod+1)/3);
    for(int i=1;i<=m;i++)
        if(a[i] and !b[i]) return printf("NO"),0;
    puts("YES"); int tot=0;
    for(int i=1;i<=m;i++)
        if(a[i]==2) tot++;
    printf("%d
",tot);
    for(int i=1;i<=m;i++)
        if(a[i]==2) printf("%d ",i);
    return 0;
}

原文地址:https://www.cnblogs.com/YoungNeal/p/10360660.html