[题解] LuoguP3784 [SDOI2017]遗忘的集合

要mtt的题都是......

多补了几项就被卡了一整页......果然还是太菜了......


不说了......来看100分的做法吧......

如果做过付公主的背包,前面几步应该不难想,所以我们再来写一遍柿子。

首先令(c_i = [0,1])表示数(i)是否在集合中,那么(f)的生成函数就是

[F(x) = prodlimits_{i=1}^n (frac{1}{1-x^i})^{c_i} ]

乘法不太好做,我们两边(ln)一下,转化成加法

[ln F(x) = sumlimits_{i=1}^n c_i ln(frac{1}{1-x^i}) ]

我们想要右边的(ln)变得好看一点,这个柿子在付公主的背包里好像推过了......这里就不写了。

柿子是

[ln(frac{1}{1-x^i}) = sumlimits_{j=1}^{infty}frac{1}{j}x^{ij} ]

带上去再做一些变化

[ln F(x) = sumlimits_{i=1}^n c_i sumlimits_{j ge 1} frac{1}{j} x^{ij} ]

我们枚举(k = ij)

[ln F(x) = sumlimits_{k=1}^n x^k sumlimits_{i mid k} c_icdot frac{i}{k} ]

我们令(f'_k)表示(ln F(x))(i)次项系数( imes k),知道了(f'_k = sumlimits_{i mid k} c_i i)

那么我们要构造一组(c),使得答案的字典序最小。字典序这个东西有很强的可贪性......

我们肯定会先考虑较小的(i),然后我们又知道对于能被(i)整除的(j),也就是(i mid j),有(f'_j ge f'_i),因为能对(i)产生贡献的(c_k),对(j)也会产生贡献。

所以当我们把(f')求出来的时候,做一个类似筛法的东西,(forall j, imid j),让(f'_j)(f'_i)

最后剩下类非(0)(f'_i)(i)就是在一个集合内的数,这样复杂度是对的。

然而需要MTT,异常duliu......

得到了一个教训......代码中的limit不可以瞎开......

(Code:)(只会写辣鸡版本的MTT......)

#include <bits/stdc++.h>
using namespace std;
typedef long double db;
typedef long long ll;
const db PI=acos(-1.0);
const int N=8e5+10;
int mod;
inline int fpow(int x,int y){
	int ret=1; for(x%=mod;y;y>>=1,x=1ll*x*x%mod)
		if(y&1) ret=1ll*ret*x%mod;
	return ret;
}
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int sub(int x,int y){return x-y<0?x-y+mod:x-y;}
namespace Poly{
	struct cpl{
		db x,y;
		cpl operator + (cpl k1)const{return (cpl){x+k1.x,y+k1.y};}
		cpl operator - (cpl k1)const{return (cpl){x-k1.x,y-k1.y};}
		cpl operator * (cpl k1)const{return (cpl){x*k1.x-y*k1.y,x*k1.y+y*k1.x};}
	};
	int rev[N];
	void init(int n){
		for(int i=0;i<n;i++)
			rev[i]=rev[i>>1]>>1|((i&1)?n>>1:0);
	}
	void fft(cpl *f,int n,int flg){
		for(int i=0;i<n;i++) if(rev[i]<i) swap(f[i],f[rev[i]]);
		for(int len=2,k=1;len<=n;len<<=1,k<<=1){
			cpl wn=(cpl){cos(2*PI/len),flg*sin(2*PI/len)};
			for(int i=0;i<n;i+=len){
				cpl w=(cpl){1,0};
				for(int j=i;j<i+k;j++,w=w*wn){
					cpl tmp=w*f[j+k];
					f[j+k]=f[j]-tmp,f[j]=f[j]+tmp;
				}
			}
		}
		if(flg==-1) for(int i=0;i<n;i++)
			f[i].x/=n;
	}
	void mtt(int *a,int *b,int *c,int n){
		static cpl f[2][N],g[2][N],ans[3][N];
		for(int i=0;i<n;i++){
			f[0][i]=(cpl){(db)(a[i]>>15),0};
			f[1][i]=(cpl){(db)(a[i]&0x7fff),0};
			g[0][i]=(cpl){(db)(b[i]>>15),0};
			g[1][i]=(cpl){(db)(b[i]&0x7fff),0};
		}
		fft(f[0],n,1),fft(f[1],n,1),fft(g[0],n,1),fft(g[1],n,1);
		for(int i=0;i<n;i++){
			ans[0][i]=f[0][i]*g[0][i];
			ans[1][i]=f[0][i]*g[1][i]+f[1][i]*g[0][i];
			ans[2][i]=f[1][i]*g[1][i];
		}
		fft(ans[0],n,-1),fft(ans[1],n,-1),fft(ans[2],n,-1);
		#define normal(x) (((ll)((x)+0.5)%mod+mod)%mod)
		for(int i=0;i<n;i++){
			ll t1=(normal(ans[0][i].x)<<30ll)%mod;
			ll t2=(normal(ans[1][i].x)<<15ll)%mod,t3=normal(ans[2][i].x);
			c[i]=((t1+t2)%mod+t3)%mod;
		}
	}
	void dao(int *f,int n,int *G){
		static int F[N]; for(int i=0;i<=n;i++) F[i]=f[i];
		for(int i=1;i<=n;i++) G[i-1]=1ll*F[i]*i%mod; G[n]=0;
	}
	void jifen(int *f,int n,int *G){
		static int F[N]; for(int i=0;i<=n;i++) F[i]=f[i];
		for(int i=0;i<=n;i++) G[i+1]=1ll*F[i]*fpow(i+1,mod-2)%mod; G[0]=0;
	}
	void getinv(int *f,int n,int *G){
		if(n==1){G[0]=fpow(f[0],mod-2);return;}
		getinv(f,(n+1)>>1,G);
		static int F[N],H[N],H1[N];
		int limit=1; while(limit<=(n-1)*2)limit<<=1; init(limit);
		for(int i=0;i<n;i++) H[i]=G[i],F[i]=f[i];
		for(int i=n;i<limit;i++) H[i]=F[i]=G[i]=0;
		mtt(F,G,H1,limit);
		H1[0]=sub(2,H1[0]);
		for(int i=1;i<limit;i++) H1[i]=i<n?mod-H1[i]:0;
		for(int i=n;i<limit;i++) H1[i]=0;
		mtt(H,H1,G,limit);
		for(int i=n;i<limit;i++) G[i]=0;
	}
	void getln(int *f,int n,int *G){
		static int F[N],iF[N]; for(int i=0;i<n;i++) F[i]=f[i];
		getinv(F,n,iF),dao(F,n-1,F);
		int limit=1; while(limit<=(n-1)*2)limit<<=1; init(limit);
		mtt(F,iF,G,limit);
		jifen(G,n-1,G); for(int i=n;i<limit;i++) G[i]=0;
	}
}
int n,f[N],ans[N];
int main(){
	scanf("%d%d",&n,&mod);
	f[0]=1; for(int i=1;i<=n;i++) scanf("%d",&f[i]);
	Poly::getln(f,n+1,ans);
	for(int i=1;i<=n;i++) ans[i]=1ll*ans[i]*i%mod;
	for(int i=1;i<=n;i++)
		for(int j=i*2;j<=n;j+=i) ans[j]=sub(ans[j],ans[i]);
	int cnt=0;
	for(int i=1;i<=n;i++) if (ans[i]) ++cnt;
	printf("%d
",cnt);
	for(int i=1;i<=n;i++) if(ans[i]) printf("%d ",i);
	return 0;
}
原文地址:https://www.cnblogs.com/wxq1229/p/12288496.html