luogu P5293 [HNOI2019]白兔之舞

传送门

关于这题答案,因为在所有行,往后跳到任意一行的(w_{i,j})都是一样的,所以可以算出跳(x)步的答案然后乘上(inom{l}{x}),也就是枚举跳到了哪些行

如果记跳x步的方案是(f_x),(n=1)时,(f_x={w_{1,1}}^x);(n>1)时,因为n很小,转移可以写成矩阵,然后矩阵快速幂后就是初始矩阵乘转移矩阵的第一行第(y)列的值

后面记(w)为对应的转移矩阵,(a)为初始矩阵(省略下标(_{(1,y)}))

我们把答案式子列出来(ans_x=sum_{i=0}^{l}[imod k=x]aw^iinom{l}{i})

然后可以快乐的推导(ans_x=sum_{i=0}^{l}[k|(i-x)]aw^iinom{l}{i})

那个条件是单位根反演,即([n|m]=>[frac{1}{n}sum_{i=0}^{n-1}omega_{n}^{im}=1]),所以可以得到
(ans_x=sum_{i=0}^{l}frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{j(i-x)}aw^iinom{l}{i})
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-jx}sum_{i=0}^{l}omega_{k}^{ji}aw^iinom{l}{i})
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-jx}sum_{i=0}^{l}a(omega_{k}^jw)^iinom{l}{i})
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-jx}asum_{i=0}^{l}inom{l}{i}(omega_{k}^jw)^iI^{l-i})
二项式定理得
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-jx}a(omega_{k}^jw+I)^l)

后面那个东西可以预处理,第(i)项记为(g_i),然后我们要求
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{-jx}g_j)

可以看到(-jx)比较麻烦,不过(ij=frac{(i+j)^2}{2}-frac{i^2}{2}-frac{j^2}{2}),但是(omega_{k})不一定有二次剩余,所以可以这样(ij=inom{i+j}{2}-inom{i}{2}-inom{j}{2}),然后
(ans_x=frac{1}{k}sum_{j=0}^{k-1}omega_{k}^{inom{j-x}{2}-inom{j}{2}-inom{-x}{2}}g_j)
(ans_x=frac{1}{k}omega_{k}^{-inom{-x}{2}}sum_{j=0}^{k-1}omega_{k}^{inom{j-x}{2}}omega_{k}^{-inom{j}{2}}g_j)

然后是个卷积形式,就可以(MTT)算了

注意优化常数,例如矩乘少一点取模,还有就是(FFT)可以优化一下长度

// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<cmath>
#include<ctime>
#include<queue>
#include<map>
#include<set>
#define LL long long
#define db long double

using namespace std;
const int N=1e5+100,M=270000+10;
const db pi=acos(-1);
int rd()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x*w;
}
int n,m,l,x,y,mod,sqtm,wm,wk,f[N];
int fpow(int a,int b){int an=1;while(b){if(b&1) an=1ll*an*a%mod;a=1ll*a*a%mod,b>>=1;} return an;}
int inv(int a){return fpow(a,mod-2);}
LL c2(int a){return 1ll*(a)*(a-1)/2;}
int getw(int mod)
{
    int st[20],tp=0,xx=mod-1,phim=xx,lim=sqrt(xx);
    for(int i=2;xx>1&&i<=lim;++i)
        if(xx%i==0)
        {
            st[++tp]=i;
            while(xx%i==0) xx/=i;
        }
    if(xx>1) st[++tp]=xx;
    for(int g=2;;++g)
    {
        bool o=1;
        for(int j=1;j<=tp&&o;++j)
            o=fpow(g,phim/st[j])>1;
        if(o) return g;
    }
}
struct matrix
{
    int a[3][3];
    matrix(){memset(a,0,sizeof(a));}
    matrix operator + (const matrix &bb) const
        {
            matrix an;
            for(int i=0;i<3;++i)
                for(int j=0;j<3;++j)
                    an.a[i][j]=(a[i][j]+bb.a[i][j])%mod;
            return an;
        }
    matrix operator * (const int &bb) const
        {
            matrix an;
            for(int i=0;i<3;++i)
                for(int j=0;j<3;++j)
                    an.a[i][j]=1ll*a[i][j]*bb%mod;
            return an;
        }
    matrix operator * (const matrix &bb) const
        {
            matrix an;
            for(int i=0;i<3;++i)
                for(int j=0;j<3;++j)
                {
                    LL nw=0;
                    for(int k=0;k<3;++k)
                        nw+=1ll*a[i][k]*bb.a[k][j];
                    an.a[i][j]=nw%mod;
                }
            return an;
        }
    matrix operator ^ (const int &bb) const
        {
            int b=bb;
            matrix an,a=*this;
            for(int i=0;i<3;++i) an.a[i][i]=1;
            while(b)
            {
                if(b&1) an=an*a;
                a=a*a,b>>=1;
            }
            return an;
        }
}maa,mab,me;
int nn,rdr[M];
struct comp
{
    db r,i;
    comp(){}
    comp(db nr,db ni){r=nr,i=ni;}
    comp operator + (const comp &bb) const {return comp(r+bb.r,i+bb.i);}
    comp operator - (const comp &bb) const {return comp(r-bb.r,i-bb.i);}
    comp operator * (const comp &bb) const {return comp(r*bb.r-i*bb.i,r*bb.i+i*bb.r);}
}p1[M],p2[M],p3[M],p4[M],p5[M],p6[M],p7[M];
void fft(comp *a,int op)
{
    comp x,y,w;
    for(int i=0;i<nn;++i)
        if(i<rdr[i]) swap(a[i],a[rdr[i]]);
    for(int i=1;i<nn;i<<=1)
    {
        comp ww=comp(cos(pi/i),op*sin(pi/i));
        for(int j=0;j<nn;j+=i<<1)
        {
            w=comp(1,0);
            for(int k=0;k<i;++k,w=w*ww)
                x=a[j+k],y=a[j+k+i]*w,a[j+k]=x+y,a[j+k+i]=x-y;
        }
    }
    if(op==-1) for(int i=0;i<nn;++i) a[i].r/=nn;
}
void mul(int *a,int *b)
{
    for(int i=0;i<nn;++i)
        p1[i]=comp(a[i]/sqtm,0),p2[i]=comp(a[i]%sqtm,0);
    for(int i=0;i<nn;++i)
        p3[i]=comp(b[i]/sqtm,0),p4[i]=comp(b[i]%sqtm,0);
    fft(p1,1),fft(p2,1),fft(p3,1),fft(p4,1);
    for(int i=0;i<nn;++i) p5[i]=p1[i]*p3[i],p6[i]=p1[i]*p4[i]+p2[i]*p3[i],p7[i]=p2[i]*p4[i];
    fft(p5,-1),fft(p6,-1),fft(p7,-1);
    for(int i=0;i<nn;++i)
        a[i]=((LL)(p5[i].r+0.5)%mod*sqtm%mod*sqtm%mod+(LL)(p6[i].r+0.5)%mod*sqtm%mod+(LL)(p7[i].r+0.5)%mod)%mod;
}
int aa[M],bb[M],an[N];
    
int main()
{
    n=rd(),m=rd(),l=rd(),x=rd()-1,y=rd()-1,mod=rd();
    sqtm=sqrt(mod);
    for(int i=0;i<3;++i) me.a[i][i]=1;
    for(int i=0;i<n;++i)
        for(int j=0;j<n;++j)
            mab.a[i][j]=rd();
    maa.a[0][x]=1;
    wm=getw(mod),wk=fpow(wm,(mod-1)/m);
    for(int i=0,j=1;i<m;++i,j=1ll*j*wk%mod)
        f[i]=(maa*((mab*j+me)^l)).a[0][y];
    int ll=0,invwk=inv(wk);
    nn=1;
    while(nn<(m<<2)) nn<<=1,++ll;
    for(int i=0;i<nn;++i) rdr[i]=(rdr[i>>1]>>1)|((i&1)<<(ll-1));
    for(int i=0;i<=m+m;++i) aa[i]=fpow(wk,(c2(i-m)+mod-1)%(mod-1));
    for(int i=0;i<m;++i) bb[m-i]=1ll*f[i]*fpow(invwk,(c2(i)+mod-1)%(mod-1))%mod;
    mul(aa,bb);
    int invk=inv(m);
    for(int i=0;i<m;++i) an[i]=1ll*invk*fpow(invwk,(c2(-i)+mod-1)%(mod-1))%mod*aa[m+m-i]%mod;
    for(int i=0;i<m;++i) printf("%d
",an[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/smyjr/p/10699867.html