题解 CF1326F2 Wise Men (Hard Version)

前置知识:快速沃尔什变换(FWT)

or卷积

给出两个序列(a,b),求一个序列(c),使得(c_i=sum_{joperatorname{OR}k=i}a_jb_k)

仿照FFT的思路,我们构造两个序列(FWT(a),FWT(b))(对应了FFT里的点值),使得(FWT(c)[i]=FWT(a)[i]cdot FWT(b)[i])。然后再对(FWT(c))做逆变换,得到(c)

FWT算法的结论是:对于or卷积,(FWT(a)[i]=sum_{joperatorname{OR}i=i}a_j)。可以发现,(joperatorname{OR}i=i)就等价于“(j)(i)的一个子集”。

值得一提的是,根据这个定义,FWT-or就相当于是做高维前缀和;FWT-or的逆变换(IFWT-or)就相当于是高维前缀和的逆变换(差分)

在实现时,对于一个最高次项为(2^n)的多项式(a),把它分成(a_0,a_1)两部分,分别表示前面的(2^{n-1})项和后面的(2^{n-1})项,则:

[FWT(a)=egin{cases} (FWT(a_0),FWT(a_0+a_1))&&n>0\ a&&n=0 end{cases} ]

这个逗号是啥意思?因为(FWT(a))是一个长度为(2^n)的序列,因此逗号左边就是序列的前(2^{n-1})项,右边就是序列的后(2^{n-1})项。

而逆变换就把这个过程反过来即可,即:

[IFWT(a)=egin{cases} (IFWT(a_0),IFWT(a_1-a_0))&& n>0\ a&&n=0 end{cases} ]

and卷积

给出两个序列(a,b),求一个序列(c),使得(c_i=sum_{joperatorname{AND}k=i}a_jb_k)

对于and卷积,(FWT(a)[i]=sum_{joperatorname{AND}i=i}a_j)。可以发现,(joperatorname{AND}i=i)就等价于“(i)(j)的一个子集”,和or卷积恰好相反。

同样可以看出,根据这个定义,FWT-and就相当于是做高维后缀和;FWT-and的逆变换(IFWT-and)就相当于是高维后缀和的逆变换(差分)

在实现时,

[FWT(a)=egin{cases} (FWT(a_0+a_1),FWT(a_1))&&n>0\ a&&n=0 end{cases} ]

同理可以做逆变换:

[IFWT(a)=egin{cases} (IFWT(a_0-a_1),IFWT(a_1))&&n>0\ a&&n=0 end{cases} ]

xor卷积

与本题无关。只是顺带提一下做法:

[FWT(a)=egin{cases} (FWT(a_0+a_1),FWT(a_0-a_1))&&n>0\ a&&n=0 end{cases} ]

于是可知,逆变换为:

[IFWT(a)=egin{cases} (IFWT(frac{a_0+a_1}{2}),IFWT(frac{a_0-a_1}{2}))&&n>0\ a&&n=0 end{cases} ]

本题题解

我们设(ans(s))表示串(s)的答案。直接求(ans(s))不好求,考虑集合中至少包含(s)的答案,即所有(sin S)(ans(S))之和,记为(ans'(s))。然后我们对(ans')数组做IFWT-and卷积,就可以求出所有(ans(s))

把朋友之间的关系看做一张无向图。我们定义一条链的长度为它经过的节点数

那么对于一个长度为(n-1)的01串(s),它代表的其实是图中的若干条链。具体来讲,如果在串(s)后面补上一个(0),那么:

  • 串中每段连续的(1)是一条链。如果有(x)(1),则链的长度为(x+1)
  • 每个(0)是单独的一个节点(也就是一条长度为(1)的链)。特别地:一段连续的(1)之后的第一个(0)除外,它这个位置上的节点已经被计入了上一条连续的(1)组成的链中。

按照上述规则,不难发现,所有链的长度之和恰好为(n)。而对于一个01串(s)来说,(ans'(s))只取决于它划分出的链的长度的可重集。例如:(ans'(0111011)=ans'(1100111)),因为它们的这个可重集都是({1,3,4})

又因为所有链的长度之和恰好为(n),故本质不同的可重集数量只有(P(n))种,其中(P(n))表示(n)的划分数。(P(18)=385)。于是我们只需要对这(P(n))个“链的长度的可重集”,分别求答案。

(f_{i,mask})表示对于一个大小为(i)的节点集合(mask),图中有多少条链,恰好经过(mask)中的这些节点。

如果我们求出了(f_{i,mask})数组,那么对于一个“链的长度的可重集”(T),它的答案就是(displaystylesum_{m_1,dots,m_{|T|}} prod_{i=1}^{|T|}f_{len(T_i),m_i})。其中(len(T_i))表示(T)中第(i)条链的长度。前面的(sum)枚举的是一个(m_i)数组,表示对每个(i)各取一个大小为(len(T_i))的点集(m_i),要求这些(m_i)的并为([1,n])且互相不交。容易发现只要并为([1,n])就必然互相不交,因为它们的(len(T_i))之和为(n)。所以我们可以做一个FMT-or卷积。把所有(f_{len(T_i)})(|T|)个序列卷起来。卷积结果的(2^n-1)项前的系数即为(T)这个可重集的答案。

现在最后的问题是(f_{i,mask})数组怎么求。可以做简单的状压DP。设(dp[mask][j])表示经过了(mask)中的这些节点,最后一个经过的节点为(j)的链的数量。转移时枚举一个不在(mask)中切与(j)有连边的点作为下一个点即可。则(f_{i,mask}=sum_{j=1}^{n}dp[mask][j])

DP求(f_{i,mask})的复杂度为(O(2^nn^2)),之后枚举每个可重集,求答案的复杂度为(O(P(n)2^nn)),其中(P(18)=385)

参考代码:

//problem:CF1326F2
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline int readbit(){
	char ch=getchar();
	while(ch<'0'||ch>'1')ch=getchar();
	return ch-'0';
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=18;
int n,a[MAXN+5][MAXN+5];
ll dp[1<<MAXN][MAXN+5],f[MAXN+5][1<<MAXN],h[400],ans[1<<MAXN];

int cnt;
map<vector<int>,int>mp;
vector<int>vec[400],tmp;
void dfs(int cur,int lst){
	if(cur==n+1){
		mp[tmp]=++cnt;
		vec[cnt]=tmp;
		return;
	}
	if(n-cur+1<lst)return;
	for(int i=lst;cur+i-1<=n;++i){
		tmp.pb(i);
		dfs(cur+i,i);
		tmp.pop_back();
	}
}

int bitcnt(uint x){
	int res=0;
	for(int j=0;j<=31;++j)res+=((x>>j)&1u);
	return res;
}
void fwt_or(ll *f,uint n,int flag){
	// FWT_or(A)[i] = sum_{j|i=i} A[j]
	//即:j是i的一个子集
	for(uint i=1;i<n;i<<=1){
		for(uint j=0;j<n;j+=(i<<1)){
			for(uint k=0;k<i;++k){
				f[i+j+k]+=f[j+k]*flag;
			}
		}
	}
}
void fwt_and(ll *f,uint n,int flag){
	// FWT_and(A)[i] = sum_{j&i=i} A[j]
	//即:i是j的一个子集
	for(uint i=1;i<n;i<<=1){
		for(uint j=0;j<n;j+=(i<<1)){
			for(uint k=0;k<i;++k){
				f[j+k]+=f[i+j+k]*flag;
			}
		}
	}
}

int main() {
	n=read();
	dfs(1,1);//搜出所有划分数
	for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)a[i][j]=readbit();
	
	//dp[mask][j] 表示经过了mask中这些点,以j结尾的链有多少.
	//用来求出 f[i][mask] 表示经过了大小为i的集合mask的链的数量
	for(int i=1;i<=n;++i)dp[1u<<(i-1)][i]=1;
	for(uint i=1;i<(1u<<n);++i){
		for(int j=1;j<=n;++j)if((i>>(j-1))&1u){
			for(int k=1;k<=n;++k)if(a[j][k]&&!((i>>(k-1))&1u)){
				dp[i|(1u<<(k-1))][k]+=dp[i][j];
			}
		}
		int t=bitcnt(i);
		for(int j=1;j<=n;++j)f[t][i]+=dp[i][j];
	}
	
	for(int i=1;i<=n;++i)fwt_or(f[i],1u<<n,1);
	static ll IE[1<<MAXN],tmp[1<<MAXN];
	IE[0]=1;
	fwt_or(IE,1u<<n,1);
	for(int i=1;i<=cnt;++i){
		for(uint j=0;j<(1u<<n);++j)tmp[j]=IE[j];
		for(uint j=0;j<vec[i].size();++j){
			for(uint k=0;k<(1u<<n);++k)tmp[k]=(ll)tmp[k]*f[vec[i][j]][k];
		}
		fwt_or(tmp,1u<<n,-1);
		h[i]=tmp[(1u<<n)-1];
	}
	for(uint i=0;i<(1u<<(n-1));++i){
		vector<int>tmp;
		for(int j=0;j<=n-1;){
			int jj=j+1;
			while(jj-1<=n-2 && ((i>>(jj-1))&1u))++jj;
			tmp.pb(jj-j);
			j=jj;
		}
		sort(tmp.begin(),tmp.end());
		//for(uint j=0;j<tmp.size();++j)cout<<tmp[j]<<" ";cout<<endl;
		assert(mp.count(tmp));
		ans[i]=h[mp[tmp]];
	}
	fwt_and(ans,(1u<<(n-1)),-1);
	for(uint i=0;i<(1u<<(n-1));++i)printf("%lld ",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/dysyn1314/p/12534771.html