FWT 等总结 题解


FWT可以解决位运算卷积问题。
(h(i)=sumlimits_{j⊕k=i} f(j)*g(k)),其中“⊕”表示位运算。

与卷积:

定义(f)(F)的变换:(F(i)=sumlimits_{j&i==i}^{ }f(j))
这样,若(h(i)=sumlimits_{j and k=i} f(j)*g(k)),则(H(i)=F(i)*G(i))
变换方法:就是按照长度为(2^i)分段,把每段的后半部分加到前半部分(1对0有额外贡献)。
逆变换就是减回去。时间复杂度:(O(nlogn))

代码:

void fwtand(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+k]=(sz[j+k]+sz[j+(i>>1)+k])%md;
		}
	}
}
void ifwtand(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+k]=(sz[j+k]-sz[j+(i>>1)+k]+md)%md;
		}
	}
}

或卷积:

与“与卷积”类似。
定义(f)(F)的变换:(F(i)=sumlimits_{j|i==i}^{ }f(j))
这样,若(h(i)=sumlimits_{j or k=i} f(j)*g(k)),则(H(i)=F(i)*G(i))
变换方法:就是按照长度为(2^i)分段,把每段的前半部分加到后半部分(0对1有额外贡献)。
逆变换就是减回去。时间复杂度:(O(nlogn))

代码:

void fwtor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]+sz[j+k])%md;
		}
	}
}
void ifwtor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]-sz[j+k]+md)%md;
		}
	}
}

这两个其实是高维前/后缀和

异或卷积:

这个比较常用。
定义(f)(F)的变换:(F(i)=sumlimits_{j=0}^{2^n-1}(-1)^{bit(j and i)}f(j))
这样,若(h(i)=sumlimits_{j xor k=i} f(j)*g(k)),则(H(i)=F(i)*G(i))
变换方法:就是按照长度为(2^i)分段,把每段的前半部分变为前半部分加后半部分,
后半部分变为前半部分减后半部分。
逆变换就是相当于已知(a+b=x,a-b=y),则(a=(x+y)/2,b=(x-y)/2)
就是正变换再除以2。
时间复杂度:(O(nlogn))

代码:

void fwtxor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
			{
				int a=sz[j+k],b=sz[j+(i>>1)+k];
				sz[j+k]=(a+b)%md;
				sz[j+(i>>1)+k]=(a-b+md)%md;
			}
		}
	}
}
void ifwtxor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
			{
				int a=sz[j+k],b=sz[j+(i>>1)+k];
				sz[j+k]=1ll*(a+b)*inv%md;
				sz[j+(i>>1)+k]=1ll*(a-b+md)*inv%md;
			}
		}
	}
}

FST:子集卷积

(h(i)=sumlimits_{j or k=i且j and k=0} f(j)*g(k))
比或卷积多了一个限制。
我们发现,设(s(i))表示(i)的二进制表示中1的个数,那么如果(i|j=k,i&j=0),则(s(i)+s(j)=s(k))
利用这个性质,我们可以加一维表示(s),在(F*G)时考虑(s)的限制。
时间复杂度:(O(nlog^2n))

代码:

for(int i=0;i<len;i++)
{
	for(int j=0;j<17;j++)
	{
		if(i&(1<<j))
			sl[i]+=1;
	}
}
for(int i=0;i<len;i++)
	a[sl[i]][i]=sz[i];
for(int i=0;i<18;i++)
	fwtor(a[i],len);
for(int i=0;i<18;i++)
{
	for(int j=0;i+j<18;j++)
	{
		for(int k=0;k<len;k++)
			h1[i+j][k]=(h1[i+j][k]+1ll*a[i][k]*a[j][k])%md;
	}
}
for(int i=0;i<18;i++)
	ifwtor(h1[i],len);
for(int i=0;i<len;i++)
	ab[i]=h1[sl[i]][i];

例题:

CF914G

题意:

给你一个长度为(n)的数组(s).定义五元组((a,b,c,d,e))是合法的当且仅当:

①.(1le a,b,c,d,ele n)

②.((s_a|s_b)&s_c&(s_d)^(s_e)=2^i,iin Z)

③.(s_a&s_b=0)

对于所有合法的五元组((a,b,c,d,e))

(sum f(s_a|s_b)*f(s_c)*f(s_d)^(s_e)mod 10^9+7)

(f_0=0,f_1=1,f_i=f_{i-1}+f_{i-2})

(1le nle10^6,0le s_ilt2^{17})

模板题。

先考虑((s_a|s_b)),发现是FST卷积。
再考虑((s_d)^(s_e)),是异或卷积。
然后就是与卷积了。

代码:

#include <stdio.h>
#define md 1000000007
#define inv 500000004
#define len 131072
int sz[132000],sl[132000];
void fwtor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]+sz[j+k])%md;
		}
	}
}
void ifwtor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]-sz[j+k]+md)%md;
		}
	}
}
void fwtand(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+k]=(sz[j+k]+sz[j+(i>>1)+k])%md;
		}
	}
}
void ifwtand(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
				sz[j+k]=(sz[j+k]-sz[j+(i>>1)+k]+md)%md;
		}
	}
}
void fwtxor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
			{
				int a=sz[j+k],b=sz[j+(i>>1)+k];
				sz[j+k]=(a+b)%md;
				sz[j+(i>>1)+k]=(a-b+md)%md;
			}
		}
	}
}
void ifwtxor(int sz[132000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		for(int j=0;j<n;j+=i)
		{
			for(int k=0;k<(i>>1);k++)
			{
				int a=sz[j+k],b=sz[j+(i>>1)+k];
				sz[j+k]=1ll*(a+b)*inv%md;
				sz[j+(i>>1)+k]=1ll*(a-b+md)*inv%md;
			}
		}
	}
}
int a[18][132000],h1[18][132000],fib[132000],h2[132000];
int ab[132000],de[132000],ans[132000];
int main()
{
	int n;
	scanf("%d",&n);
	for(int i=0;i<n;i++)
	{
		int a;
		scanf("%d",&a);
		sz[a]+=1;
	}
	for(int i=0;i<len;i++)
	{
		for(int j=0;j<17;j++)
		{
			if(i&(1<<j))
				sl[i]+=1;
		}
	}
	fib[0]=0;fib[1]=1;
	for(int i=2;i<len;i++)
		fib[i]=(fib[i-1]+fib[i-2])%md;
	for(int i=0;i<len;i++)
		a[sl[i]][i]=sz[i];
	for(int i=0;i<18;i++)
		fwtor(a[i],len);
	for(int i=0;i<18;i++)
	{
		for(int j=0;i+j<18;j++)
		{
			for(int k=0;k<len;k++)
				h1[i+j][k]=(h1[i+j][k]+1ll*a[i][k]*a[j][k])%md;
		}
	}
	for(int i=0;i<18;i++)
		ifwtor(h1[i],len);
	for(int i=0;i<len;i++)
		ab[i]=h1[sl[i]][i];
	for(int i=0;i<len;i++)
		de[i]=sz[i];
	fwtxor(de,len);
	for(int i=0;i<len;i++)
		de[i]=1ll*de[i]*de[i]%md;
	ifwtxor(de,len);
	for(int i=0;i<len;i++)
	{
		ab[i]=1ll*ab[i]*fib[i]%md;
		sz[i]=1ll*sz[i]*fib[i]%md;
		de[i]=1ll*de[i]*fib[i]%md;
	}
	fwtand(ab,len);
	fwtand(sz,len);
	fwtand(de,len);
	for(int i=0;i<len;i++)
		ans[i]=1ll*ab[i]*sz[i]%md*de[i]%md;
	ifwtand(ans,len);
	int jg=0;
	for(int i=1;i<=len;i=(i<<1))
		jg=(jg+ans[i])%md;
	printf("%d",jg);
	return 0;
}

uoj310【UNR #2】黎明前的巧克力

题意:有一个集合,选出两个不相交的子集,使其异或和相等,问方案数。

考虑dp:设(dp(i,j))表示考虑到i,两人异或为j的方案数。
(dp(i,j)=dp(i-1,j)+2*dp(i-1,j)^(a(i)))
考虑FWT:对每个i构造A,使(A(0)=1,A(a(i))=2)
对每个A做FWT,乘起来后再IFWT。但是复杂度太高。
根据公式,可以发现,FWT(A)的每位只能是3或-1。
那么,只要知道FWT后,A的每个对应位置之和,就能解出3和-1的数量了,之后快速幂即可。
根据加法的运算律,可以得知若干个长度相等的序列FWT后对应位置求和,等于先求和,再FWT。
所以求和后,FWT,之后快速幂算出每个位置的值,再IFWT,最后位置0的值减1就是答案。
时间复杂度:(O(mlogm))
思路非常巧妙。

代码:

#include <stdio.h>
#define md 998244353
#define inv 499122177
#define len 1048576
void fwt(int sz[1050000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		int t=(i>>1);
		for(int j=0;j<n;j+=i)
		{
			for(int k=j;k<j+t;k++)
			{
				int a=sz[k+t];
				sz[k+t]=sz[k]-a;
				sz[k]=sz[k]+a;
			}
		}
	}
}
void ifwt(int sz[1050000],int n)
{
	for(int i=2;i<=n;i=(i<<1))
	{
		int t=(i>>1);
		for(int j=0;j<n;j+=i)
		{
			for(int k=j;k<j+t;k++)
			{
				int a=sz[k+t];
				sz[k+t]=1ll*(sz[k]-a+md)*inv%md;
				sz[k]=1ll*(sz[k]+a)*inv%md;
			}
		}
	}
}
int sz[1050010],mi[1050000],m3[1050010],sl[1050000];
int main()
{
	int n;
	scanf("%d",&n);
	for(int i=0;i<n;i++)
		scanf("%d",&sz[i]);
	sl[0]=n;
	for(int i=0;i<n;i++)
		sl[sz[i]]+=2;
	fwt(sl,len);
	for(int i=0;i<len;i++)
		mi[i]=(n+sl[i])/4;
	m3[0]=1;
	for(int i=1;i<=n;i++)
		m3[i]=3ll*m3[i-1]%md;
	for(int i=0;i<len;i++)
	{
		if((n-mi[i])%2==0)
			mi[i]=m3[mi[i]];
		else
			mi[i]=md-m3[mi[i]];
	}
	ifwt(mi,len);
	printf("%d",(mi[0]-1+md)%md);
	return 0;
}

扩展

这个技巧还可以扩展:
就是说当权值有k个时,我们先将其中一个权值变为0,这样总共可能的贡献有(2^{k-1})种。(每个系数是1或-1),0的贡献一定是1。
为了把这(2^{k-1})种的数量分别求出来,我们需要找(2^{k-1})个等式。
可以枚举剩余k-1个元素的子集,只将其异或的位置+1,做FWT。
这样,每个位置,会得到(2^{k-1})个数。将这些数做FWT后,就可以的得到(2^{k-1})种可能分别的数量,快速幂即可。
证明略。
(核心)代码:

int xo=0,mi=1;
for(int i=0;i<n;i++)
{
	for(int j=0;j<k;j++)
		scanf("%d",&p[i][j]);
	xo^=p[i][0];
	for(int j=1;j<k;j++)
		p[i][j]^=p[i][0];
}
ans[xo]=1;
fwtxor(ans,(1<<m));
for(int s=0;s<(1<<(k-1));s++)
{
	for(int i=0;i<n;i++)
	{
		int z=0;
		for(int j=1;j<k;j++)
		{
			if(s&(1<<(j-1)))
				z^=p[i][j];
		}
		
		sz[z]+=1;
	}
	fwtxor(sz,(1<<m));
	for(int i=0;i<(1<<m);i++)
	{
		
		nf[(i<<(k-1))|s]=sz[i];
		sz[i]=0;
	}
}
for(int i=0;i<(1<<m);i++)
{
	for(int s=0;s<(1<<(k-1));s++)
		zz[s]=nf[(i<<(k-1))|s];
	ifwtxor(zz,1<<(k-1));
	for(int s=0;s<(1<<(k-1));s++)
	{
		int he=sl[0];
		for(int j=1;j<k;j++)
		{
			if(s&(1<<(j-1)))
				he=(he-sl[j]+md)%md;
			else
				he=(he+sl[j])%md;
		}
		ans[i]=1ll*ans[i]*ksm(he,zz[s])%md;
	}
}
ifwtxor(ans,(1<<m));

CF662C Binary Table

题意:
有一个 n 行 m 列的表格,每个元素都是 0/1 ,每次操作可以选择一行或一列,把 0/1 翻转,即把 0 换为 1 ,把 1 换为 0 。请问经过若干次操作后,表格中最少有多少个 1。((1leq n leq 20,1leq m leq 10^5))
首先,我们可以枚举行的交换,共(2^n)种。
然后,对每一列,考虑它是否交换。复杂度为(O(nm2^n))
考虑优化:
首先,我们发现,如果记翻转为1,那么翻转就是异或。
记B数组表示状态压缩后的每列的出现次数。
记A数组表示一列为这个状态的1的最少个数。
那么,设(C_i=A_{ixorj}*B_j)之和,那么C的最小值就是答案。
反一下,将A和B做异或卷积,即可得到C。时间复杂度(O(nm+n2^n))

代码:

#include <stdio.h>
#define ll long long
void fwt(ll sz[1048576],int n)
{
	for(int h=2;h<=n;h=(h<<1))
	{
		for(int i=0;i<n;i+=h)
		{
			for(int j=0;j<(h>>1);j++)
			{
				ll a=sz[i+j],b=sz[i+j+(h>>1)];
				sz[i+j]=a+b;
				sz[i+j+(h>>1)]=a-b;
			}
		}
	}
}
void ifwt(ll sz[1048576],int n)
{
	for(int h=2;h<=n;h=(h<<1))
	{
		for(int i=0;i<n;i+=h)
		{
			for(int j=0;j<(h>>1);j++)
			{
				ll a=sz[i+j],b=sz[i+j+(h>>1)];
				sz[i+j]=(a+b)/2;
				sz[i+j+(h>>1)]=(a-b)/2;
			}
		}
	}
}
int sz[20][100005];char zf[100005];
ll sa[1048576],sb[1048576];
int main()
{
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=0;i<n;i++)
	{
		scanf("%s",zf);
		for(int j=0;j<m;j++)
			sz[i][j]=zf[j]-'0';
	}
	for(int i=0;i<(1<<n);i++)
	{
		int s=0;
		for(int j=0;j<n;j++)
		{
			if(i&(1<<j))
				s+=1;
		}
		sa[i]=n-s;
		if(s<sa[i])sa[i]=s;
	}
	for(int i=0;i<m;i++)
	{
		int s=0;
		for(int j=0;j<n;j++)
		{
			if(sz[j][i])
				s|=(1<<j);
		}
		sb[s]+=1;
	}
	fwt(sa,1<<n);fwt(sb,1<<n);
	for(int i=0;i<(1<<n);i++)
		sa[i]*=sb[i];
	ifwt(sa,1<<n);
	int ans=99999999;
	for(int i=0;i<(1<<n);i++)
	{
		if(sa[i]<ans)
			ans=sa[i];
	}
	printf("%d",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/lnzwz/p/11257691.html