题解 P5293 【[HNOI2019]白兔之舞】

题意

有一张顶点数为 ((L+1) imes n) 的有向图,定点用二元组((u,v))(0le ule L,vle 1le n))表示。

在这个矩形中,前面的((u_1,v_1))可以到后面的((u_2,v_2)),且有(w[v_1][v_2])条路。

((0,x))出发,到第二维为(y)的顶点,对于(xin[0,k-1]),有多少条走了(m(mmod k=x))步的路线。

(感觉越解释越乱)

题解

这个屑没有学过矩阵表述不严谨请见谅。

(g_{i,j})为走了(i)步到第二维为(j)的顶点的方案数,有:

[g_{i,j}=sum_{k=1}^ng_{i-1,k}w_{k,j} ]

不难想到矩阵优化。

我们若记

[G_i=(g_{i,1},g_{i,2},ldots,g_{i,n}) ]

[G_0=(underbrace{0,ldots,0}_{x-1},1,underbrace{0,ldots,0}_{n-x}) ]

[S= egin{pmatrix} w_{1,1}&cdots&w_{1,n}\ vdots&ddots&vdots\ w_{n,1}&cdots&w_{n,n} end{pmatrix} ]

有:(G_i=G_{i-1} imes S),即(G_i=G_0 imes S^i)

但此时我们没有考虑白兔兔每次走了几步。所以我们还需要考虑(i)步走了那些列。实际上就是在(L)个位置上选了(i)个位置跳,也就是( binom{L}{i}g_{i,j}= binom{L}i(G_0S^i)_{1,y})

那么此时我们已经表示出跳了(i)的方案数。来考虑答案,列出式子开始乱推:

[ans_t=sum_{i=0}^L[imod k=t] binom{L}{i}g_{i,y} ]

([imod k=t])等价于(i=t+k imes a),即([k|(i-t)])

[ans_t=sum_{i=0}^L[k|(i-t)] binom{L}{i}g_{i,y} ]

然后就是神奇的单位根反演。先上公式:

[forall k,[n|k]=frac{1}{n}sum_{i=0}^{n-1}omega_n^{ik} ]

简单证明一下:

  • ([n|k])(omega_n^{ik}=omega^0=1),显然原式等于(1)
  • 否则,就是一个等比数列求和(dfrac{1}{n} imes dfrac{omega_n^{nk}-omega_n^0}{omega_n^k-1}=0)
    把这个简单的柿子带进去:

[ans_t=sum_{i=0}^Lfrac{1}{k} binom{L}{i}g_{i,y}sum_{j=0}^{k-1}omega_{k}^{(i-t)j} ]

有一说一括号看着很难受,因此拆开来。

[ans_t=frac{1}{k}sum_{i=0}^L binom{L}{i}g_{i,y}sum_{j=0}^{k-1}omega_{k}^{ij}omega_{k}^{-tj} ]

交换求和顺序就是传统艺能了。

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}sum_{i=0}^L binom{L}{i}g_{i,y}omega_{k}^{ij} ]

把我们推出来的(g_{i,y})带进去。

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}sum_{i=0}^L binom{L}{i}omega_{k}^{ij} imes(G_0S^i)_{1,y} ]

矩阵那部分事实上也可以先提取公因式再求和再取下标,即:

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}(G_0sum_{i=0}^L binom{L}{i}omega_{k}^{ij} imes S^i)_{1,y} ]

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}(G_0sum_{i=0}^L binom{L}{i}(omega_{k}^{j} imes S)^i)_{1,y} ]

矩阵这一串柿子不禁让人联想到二项式定理,我们如果记单位矩阵为(I)有:

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}(G_0sum_{i=0}^L binom{L}{i}(omega_{k}^{j} imes S)^iI^{L-i})_{1,y} ]

于是此时的幂形式就显而易见了。

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}(G_0(omega_k^jS+I)^L)_{1,y} ]

如果我们记(f_i=(G_0(omega_k^iS+I)^L)_{1,y}),就可以吧答案写成比较简单的形式。

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-tj}f_j ]

这个(omega_{k}^{-tj})就要用到Bluestein's Algorithm这个听起来非常高大上实际上并不难的东西。核心思想就是把(ij)拆掉,变成与(i,j)一次关系,然后就可以了。

但这题(ij=frac{(i+j)^2}{2}-frac{i^2}{2}-frac{j^2}{2})本题似乎不太行,因为可能会出现(frac{1}{2})次但单位根可能没有二次剩余,因此我们需要另一种奇怪的方法:
(ij= binom{i+j}{2}- binom{i}{2}- binom{j}{2}),正确性拆开来就行了。这样就全部是整数从而避免了二次剩余。

那么带进去:

[ans_t=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{- binom{t+j}{2}+ binom{t}{2}+ binom{j}{2}}f_j ]

整理一下。

[ans_t=frac{omega_{k}^ binom{t}{2}}{k}sum_{j=0}^{k-1}omega_{k}^{- binom{t+j}{2}} imes omega_{k}^ binom{j}{2}f_j ]

如果我们记(c_{k+t}=dfrac{ans_tk}{omega_{k}^ binom{t}{2}})(a_i=omega_{k}^{- binom{i}{2}}),(b_{t-i}=omega_{k}^ binom{i}{2}f_i)

[c_{k+t}=sum_{i=0}^{k-1}a_{t+i}b_{k-i} ]

MTT随便卷一下就好了。不过注意(a)需要倍长QwQ

代码

#include<bits/stdc++.h>
namespace in{
	char buf[1<<21],*p1=buf,*p2=buf;
	inline int getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
	template <typename T>inline void read(T& t){
		t=0;int f=0;char ch=getc();while (!isdigit(ch)){if(ch=='-')f = 1;ch=getc();}
	    while(isdigit(ch)){t=t*10+ch-48;ch = getc();}if(f)t=-t;
	}
	template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
}
namespace out{
	char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;
	inline void flush(){fwrite(buffer,1,p1+1,stdout),p1=-1;}
	inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}
	template <typename T>void write(T x) {
		static char buf[15];static int len=-1;if(x>=0){do{buf[++len]=x%10+48,x/=10;}while (x);}else{putc('-');do {buf[++len]=-(x%10)+48,x/=10;}while(x);}
		while (len>=0)putc(buf[len]),--len;
	}
}
typedef std::complex<double>complex;
const int N=4e6+10;const double PI=acos(-1);const complex I=complex(0,1);
int rev[N];complex Wn[N];int M,mod;
int ksm(int x,int y) {
	int re=1;
	for(;y;y>>=1,x=1LL*x*x%mod)if(y&1)re=1LL*re*x%mod;
	return re;
}
inline long long num(complex x){double d=x.real();return d<0?(long long)(d-0.5)%mod:(long long)(d+0.5)%mod;}
struct poly{
	std::vector<complex>a0,a1;
	int size(){return a0.size();}
    void resize(int n){a0.resize(n);a1.resize(n);}
	void set(int x,long long y){
		y%=mod;
		a0[x]=y/M;
		a1[x]=y%M;
	}
	long long get(int x){return (M*M*num(a0[x].real())%mod +
				M*(num(a0[x].imag())+num(a1[x].real()))%mod+num(a1[x].imag()))%mod;}
	long long val(int x){return (long long)(M*a0[x].real()+a1[x].real()+mod)%mod;}
};
poly operator+(poly a,poly b){
    int n=std::max(a.size(),b.size());a.resize(n),b.resize(n);
    for(int i=0;i<n;i++)a.set(i,a.val(i)+b.val(i));return a;
}
poly operator-(poly a,poly b){
    int n=std::max(a.size(),b.size());a.resize(n),b.resize(n);
    for(int i=0;i<n;i++)a.set(i,a.val(i)-b.val(i));return a;
}
inline poly one(){poly a;a.resize(1);a.set(0,1);return a;}
inline int ext(int n){int k=0;while((1<<k)<n)k++;return k;}
inline void init(int k){
	int n=1<<k;
	for(int i=0;i<n;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	for(int i=0;i<n;i++)
		Wn[i]={cos(PI/n*i),sin(PI/n*i)};
}
void FFT(std::vector<complex>&A,int n,int t){
	if(t<0)for(int i=1;i<n;i++)if(i<(n-i))std::swap(A[i],A[n-i]);
	for(int i=0;i<n;i++)
		if(i<rev[i])std::swap(A[i],A[rev[i]]);
	for(int m=1;m<n;m<<=1)
		for(int i=0;i<n;i+=m<<1)
			for(int k=i;k<i+m;k++){
				complex W=Wn[1ll*(k-i)*n/m];
				complex a0=A[k],a1=A[k+m]*W;
				A[k]=a0+a1;A[k+m]=a0-a1; 
			}
	if(t<0)for(int i=0;i<n;i++)A[i]/=n;
}
void MTT(poly &A,int n,int t){
	for(int i=0;i<n;i++)A.a0[i]=A.a0[i]+I*A.a1[i];
	FFT(A.a0,n,t);
	for(int i=0;i<n;i++)A.a1[i]=std::conj(A.a0[i?n-i:0]);
	for(int i=0;i<n;i++){
		complex p=A.a0[i],q=A.a1[i];
		A.a0[i]=(p+q)*0.5;A.a1[i]=(q-p)*0.5*I;
	}
}
inline poly operator*(poly a,poly b){
    int n=a.size()+b.size()-1,k=ext(n);
    a.resize(1<<k),b.resize(1<<k),init(k);
    MTT(a,1<<k,1);MTT(b,1<<k,1);
	for(int i=0;i<(1<<k);i++){
		complex p=a.a0[i]*b.a0[i]+I*a.a1[i]*b.a0[i];
		complex q=a.a0[i]*b.a1[i]+I*a.a1[i]*b.a1[i];
		a.a0[i]=p,a.a1[i]=q;
	}
    FFT(a.a0,1<<k,-1);FFT(a.a1,1<<k,-1);a.resize(n);
	long long tmp;for(int i=0;i<n;i++)
		tmp=a.get(i),a.set(i,tmp);
	return a;
}
struct modint{
    int x;
    modint(int o=0){x=o;}
    modint &operator = (int o){return x=o,*this;}
    modint &operator +=(modint o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
    modint &operator -=(modint o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
    modint &operator *=(modint o){return x=1ll*x*o.x%mod,*this;}
    modint &operator ^=(int b){
        modint a=*this,c=1;
        for(;b;b>>=1,a*=a)if(b&1)c*=a;
        return x=c.x,*this;
    }
    modint &operator /=(modint o){return *this *=o^=mod-2;}
    modint &operator +=(int o){return x=x+o>=mod?x+o-mod:x+o,*this;}
    modint &operator -=(int o){return x=x-o<0?x-o+mod:x-o,*this;}
    modint &operator *=(int o){return x=1ll*x*o%mod,*this;}
    modint &operator /=(int o){return *this *= ((modint(o))^=mod-2);}
	template<class I>friend modint operator +(modint a,I b){return a+=b;}
    template<class I>friend modint operator -(modint a,I b){return a-=b;}
    template<class I>friend modint operator *(modint a,I b){return a*=b;}
    template<class I>friend modint operator /(modint a,I b){return a/=b;}
    friend modint operator ^(modint a,int b){return a^=b;}
    friend bool operator ==(modint a,int b){return a.x==b;}
    friend bool operator !=(modint a,int b){return a.x!=b;}
    bool operator ! () {return !x;}
    modint operator - () {return x?mod-x:0;}
	modint &operator++(int){return *this+=1;}
};
struct Mat{
	//矩阵
	int n,m;modint val[5][5]; 
	modint*operator[](const int x){return val[x];}
	friend Mat operator*(Mat A,modint B){
		for(int i=1;i<=A.n;i++)for(int j=1;j<=A.m;j++)A[i][j]*=B;
		return A;
	}
	friend Mat operator*(Mat A,Mat B){
		static Mat C;if(A.m!=B.n)exit(11);
		C.n=A.n,C.m=B.m;
		for(int i=1;i<=C.n;i++)
			for(int j=1;j<=C.m;j++){
				C[i][j]=0;
				for(int k=1;k<=A.m;k++)
					C[i][j]+=A[i][k]*B[k][j];
			}
		return C;
	}
	friend Mat operator+(Mat A,Mat B){
		if(A.n!=B.n||A.m!=B.m)exit(12);
		for(int i=1;i<=A.n;i++)for(int j=1;j<=A.m;j++)A[i][j]+=B[i][j];
		return A;
	}
	
};
int n,k,L,x,y;
Mat _1,S,G0;
modint w[N];
Mat operator^(Mat A,int B){
	Mat tmp=_1;
	while(B){
		if(B&1)tmp=tmp*A;
		A=A*A;B>>=1;
	}return tmp;
}
modint getG(int x){
	static int p[50],o=0;
	for(int i=2,y=x-1;i<=y;i++)
		if(y%i==0){
			p[++o]=i;
			while(y%i==0)y/=i;
		}
	for(int i=2;;i++){
		bool flag=true;
		for(int j=1;j<=o;j++)if((modint(i)^((mod-1)/p[j]))==1){flag=false;break;}
		if(flag)return i;
	}
}
#define C2(x) (1ll*(x)*(x-1)/2) 
signed main(){
	//freopen("1.in","r",stdin);
	in::read(n,k,L,x,y,mod);M=sqrt(mod);
	_1.n=_1.m=n;for(int i=1;i<=n;i++)_1[i][i]=1;
	S.n=S.m=n;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)in::read(S[i][j]);
	G0.n=1;G0.m=n;G0[1][x]=1;
	w[0]=1;w[1]=getG(mod)^((mod-1)/k);
	for(int i=2;i<=k;i++)w[i]=w[i-1]*w[1];
	//out::write(getG(mod).x);
	poly a,b;a.resize(2*k);b.resize(k+1);
	for(int i=0;i<2*k;i++)a.set(i,w[(k-C2(i)%k)%k].x);
	//for(int i=0;i<2*k;i++)std::cout<<a.val(i)<<" ";std::cout<<std::endl;
	for(int i=0;i<k;i++)b.set(k-i,(w[C2(i)%k]*(G0*((S*w[i]+_1)^L))[1][y]).x);
	//for(int i=0;i<k;i++)std::cout<<b.val(i)<<" ";std::cout<<std::endl;
	poly c=a*b;
	for(int i=0;i<k;i++)out::write((c.val(k+i)*w[C2(i)%k]/k).x),out::putc('
');
	out::flush();
	return 0;
}
原文地址:https://www.cnblogs.com/juruo-cjl/p/14243889.html