联考20200722 T1 集合划分


分析:
首先是一个(O(n^2))的DP,设(f_{i,j,0/1})表示做了前(i)个,用了(j)(A),最后一个是(A/B)的方案数
然后我们不看最后一位,发现(f_{i,j})两个状态可以用(2*2)的转移矩阵DP
发现转移矩阵与(j)没有关系,把(j)去掉,维护(f_i=sum_{j=0}a_jx^j)的生成函数,(x^j)项系数就是(f_{i,j})
如果加一位(A)相当于乘一个(x),否则乘一个(1)
分治维护矩阵上的多项式
复杂度(O(nlog^2n)),我的常数巨大2333

#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<string>

#define maxn 200005
#define INF 0x3f3f3f3f
#define MOD 998244353
#define Poly vector<int>

using namespace std;

inline long long getint()
{
	long long num=0,flag=1;char c;
	while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
	while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
	return num*flag;
}

int n;
int A[maxn],B[maxn];
struct node{
	Poly a[2][2];
}P[maxn];
int rev[maxn];

inline int upd(int x){return x<MOD?x:x-MOD;}
inline int ksm(int num,int k)
{
	int ret=1;
	for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD;
	return ret;
}
inline Poly add(Poly x,Poly y)
{
	int mx=max(x.size(),y.size());
	x.resize(mx),y.resize(mx);
	for(int i=0;i<mx;i++)x[i]=upd(x[i]+y[i]);
	return x;
}

inline void NTT(Poly &a,int N,int opt)
{
	for(int i=0;i<N;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int i=1;i<N;i<<=1)
	{
		int wn=ksm(3,(MOD-1)/(i<<1));
		if(!~opt)wn=ksm(wn,MOD-2);
		for(int j=0;j<N;j+=i<<1)for(int k=0,w=1;k<i;k++,w=1ll*w*wn%MOD)
		{
			int x=a[j+k],y=1ll*a[i+j+k]*w%MOD;
			a[j+k]=upd(x+y),a[i+j+k]=upd(x-y+MOD);
		}
	}
	if(!~opt)for(int i=0,Inv=ksm(N,MOD-2);i<N;i++)a[i]=1ll*a[i]*Inv%MOD;
}

inline node mul(node y,node x)
{
	int N=x.a[0][0].size(),M=y.a[0][0].size(),len=1;
	while(len<N+M)len<<=1;
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|(i&1?len>>1:0);
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)
	{
		x.a[i][j].resize(len),y.a[i][j].resize(len);
		NTT(x.a[i][j],len,1),NTT(y.a[i][j],len,1);
	}
	node z;
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)
	{
		z.a[i][j].resize(len);
		for(int k=0;k<2;k++)for(int l=0;l<len;l++)z.a[i][j][l]=(z.a[i][j][l]+1ll*x.a[i][k][l]*y.a[k][j][l])%MOD;
	}
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)NTT(z.a[i][j],len,-1),z.a[i][j].resize(N+M-1);
	return z;
}

inline node solve(int l,int r)
{
	if(l==r)return P[l];
	int mid=(l+r)>>1;
	return mul(solve(l,mid),solve(mid+1,r));
}

int main()
{
	n=getint(),getint();
	for(int i=1;i<=2*n;i++)A[i]=getint();
	for(int i=1;i<=2*n;i++)B[i]=getint();
	for(int i=1;i<=2*n;i++)
	{
		for(int j=0;j<2;j++)for(int k=0;k<2;k++)P[i].a[j][k].resize(2);
		if(A[i-1]<=A[i])P[i].a[0][0][1]=1;
		if(B[i-1]<=A[i])P[i].a[0][1][1]=1;
		if(A[i-1]<=B[i])P[i].a[1][0][0]=1;
		if(B[i-1]<=B[i])P[i].a[1][1][0]=1;
	}
	node Ans=solve(1,2*n);
	printf("%d
",upd(Ans.a[0][0][n]+Ans.a[1][0][n]));
}

原文地址:https://www.cnblogs.com/Darknesses/p/13362597.html