牛客挑战赛39 D 牛牛的数学题 NTT FMT FWT

LINK:牛牛的数学题

题目看起来很不可做的样子。

但是 不难分析一下i,j之间的关系。

对于x=i|j且i&j==0, i,j一定是x的子集 我们可以暴力枚举子集来处理x这个数组。

考虑 x+k这个东西 对于一个y来说 x(0->y) k(0->y)容易发现这是一个NTT.

对于最外层^h 显然是FWT_xor 注意 FWT_xor 是 a0=a0+a1 a1=a0-a1.IFWT_xor a0=(a0+a1)>>1.a1=(a0-a1)>>1.

这点要熟记。

值得一提的是 题解在第一部中使用的是FMT不过我不会 所以暴力枚举子集了。

这里有一个trik 子集枚举是从大到小的 可以只枚举一半 这样可以优化一下复杂度。

3^17/2什么的 跑的还挺快。重要的一点:该取模的地方一定要取模 不然要调好久才能看出来。

const int MAXN=1<<17,G=3;
int n,lim,INV2,Q;
int mark[MAXN],rev[MAXN<<1];
int A[MAXN],B[MAXN],C[MAXN<<1],D[MAXN<<1],S[MAXN<<1];
inline void solve_AB()
{
	S[0]=(ll)A[0]*B[0]%mod;
	for(int i=1;i<lim;++i)
	{
		for(int j=i;j;j=i&(j-1))
		{
			if(mark[j]==i)break;
			mark[j]=mark[i^j]=i;
			S[i]=(S[i]+(ll)A[j]*B[i^j]+(ll)A[i^j]*B[j])%mod;
		}
	}
}
inline int ksm(int b,int p)
{
	int cnt=1;
	while(p)
	{
		if(p&1)cnt=(ll)cnt*b%mod;
		b=(ll)b*b%mod;p=p>>1;
	}
	return cnt;
}
inline void NTT(int *a,int op)
{
	rep(1,lim-1,i)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int len=2;len<=lim;len=len<<1)
	{
		int mid=len>>1;
		int wn=ksm(G,op==1?(mod-1)/len:mod-1-(mod-1)/len);
		for(int j=0;j<lim;j+=len)
		{
			ll d=1;
			for(int i=0;i<mid;++i)
			{
				int x=a[i+j],y=a[i+j+mid]*d%mod;
				a[i+j]=(x+y)%mod;a[i+j+mid]=(x-y+mod)%mod;
				d=d*wn%mod;
			}
		}
	}
	if(op==-1)
	{
		int INV=ksm(lim,mod-2);
		rep(0,lim-1,i)a[i]=(ll)a[i]*INV%mod;
	}
}
inline void solve_SC()
{
	while(lim<=n+n)lim=lim<<1;
	rep(1,lim-1,i)rev[i]=rev[i>>1]>>1|((i&1)?lim>>1:0);
	NTT(S,1);NTT(C,1);
	rep(0,lim-1,i)C[i]=(ll)C[i]*S[i]%mod;
	NTT(C,-1);
}
inline void FWT_xor(int *a,int op)
{
	for(int len=2;len<=lim;len=len<<1)
	{
		int mid=len>>1;
		for(int j=0;j<lim;j+=len)
		{
			for(int i=0;i<mid;++i)
			{
				int x=a[i+j],y=a[i+j+mid];
				if(op==1)a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
				else a[i+j]=(ll)(x+y)*INV2%mod,a[i+j+mid]=(ll)(x-y+mod)*INV2%mod;
			}
		}
	}
}
inline void solve_CD()
{
	FWT_xor(C,1);FWT_xor(D,1);
	rep(0,lim-1,i)D[i]=(ll)D[i]*C[i]%mod;
	FWT_xor(D,-1);
}
int main()
{
	freopen("1.in","r",stdin);
	get(n);INV2=ksm(2,mod-2);
	rep(0,n,i)get(A[i]);
	rep(0,n,i)get(B[i]);
	rep(0,n,i)get(C[i]);
	rep(0,n,i)get(D[i]);
	lim=1;while(lim<=n)lim=lim<<1;
	solve_AB();
	solve_SC();
	solve_CD();
	get(Q);
	rep(1,Q,i)put(D[read()]);
	return 0;
}
原文地址:https://www.cnblogs.com/chdy/p/12726618.html