XJOI NOI训练2 传送

NTT循环卷积

30分:

可以发现这是一个很明显的分层$DP$,设$dp[i][j]$表示当前走了j步走到i号节点的方案数。如果当前走的步数对节点有限制就直接将这个点的$DP$值赋成$0$

#include <bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;
const int N=1e5+100,M=21;
int n,l,m,k,x[M],y[M],a[N];
inline void add(ll &x,ll y){x=(x+y)%mod;}
inline void mul(ll &x,ll y){x=(x*y)%mod;}
inline void del(ll &x,ll y){x=(x-y+mod)%mod;}
namespace subtask1
{
    ll dp[210][210];
    int vi[210][210];
    void solve()
    {
        dp[0][0]=1;
        for (int i=1;i<=m;i++) vi[y[i]][x[i]]=1;
        for (int i=0;i<l;i++)
        {
            for (int j=0;j<n;j++)
            {
                if (dp[j][i]==0 || vi[j][i]) continue;
                for (int p=1;p<=k;p++) add(dp[(j+a[p])%n][i+1],dp[j][i]);
            }
        }
        printf("%lld
",dp[0][l]);
    }
}
int main()
{
    scanf("%d%d",&n,&l);
    scanf("%d",&m);
    for (int i=1;i<=m;i++) scanf("%d%d",&x[i],&y[i]);
    scanf("%d",&k);
    for (int i=1;i<=k;i++) scanf("%d",&a[i]);
    subtask1::solve();
}

45分:

这个$DP$方程很明显可以用矩阵快速幂优化,因为有限制的点只有$m$个,数量很小,那么最两个限制之间用矩阵快速幂加速递推,当遇到一个限制的时候就停下来,将有限制的点在矩阵中的数改成$0$。不断重复这个过程直到递推到$l$

时间复杂度$O(mn^{3}logL)$

然而这个复杂度和暴力在分数上没有区别

for (int i=0;i<n;i++)
{
     for (int j=1;j<=k;j++) tr.a[i+1][(i+a[j])%n+1]++;
}

这个转移的矩阵其实是一个循环矩阵,一个向量也可以看作一个循环矩阵,那么初始矩阵可以跟转移矩阵直接循环乘积

循环矩阵相乘只要记录第一行的数字相乘

$m_{x}=sum_{(i+j-2)\%n+1=x} m_{i}m_{j}$

利用上面的方法计算即可

时间复杂度$O(mn^{2}logL)$

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;
const int N=1e5+100,M=21;
int n,l,m,k,a[N];
struct node
{
    int x,y;
}sh[M];
struct matrix
{
    ll a[600][600],n;
    inline void clear(){memset(a,0,sizeof(a));}
    inline void init(){for(int i=1;i<=n;i++)a[i][i]=1;}
}tr;
matrix st;
inline int read()
{
    int f=1,x=0;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline void add(ll &x,ll y){x=(x+y)%mod;}
inline void mul(ll &x,ll y){x=(x*y)%mod;}
inline void del(ll &x,ll y){x=(x-y+mod)%mod;}
matrix operator *(matrix a,matrix b)
{
    matrix ans;
    ans.n=a.n;
    ans.clear();
    for (int i=1;i<=a.n;i++)
    {
        for (int j=1;j<=a.n;j++)
        {
            for (int k=1;k<=a.n;k++)
              add(ans.a[i][j],(a.a[i][k]*b.a[k][j])%mod);
        }
    }
    return ans;
}
matrix m_pow(matrix a,int b)
{
    matrix ans;
    ans.n=a.n;
    ans.clear();
    ans.init();
    while (b)
    {
        if (b&1) ans=ans*a;
        b>>=1;
        a=a*a;
    }
    return ans;
}
bool cmp(node a,node b)
{
    return a.x<b.x;
}
namespace subtask2
{
    void solve()
    {
        st.n=tr.n=n;
        for (int i=0;i<n;i++)
        {
            for (int j=1;j<=k;j++) tr.a[i+1][(i+a[j])%n+1]++;
        }
        st.a[1][1]=1;
        sort(sh+1,sh+1+m,cmp);
        sh[0].x=0;
        for (int i=1;i<=m;i++)
        {
            st=st*m_pow(tr,sh[i].x-sh[i-1].x);
            st.a[1][sh[i].y+1]=0;
        }
        st=st*m_pow(tr,l-sh[m].x);
        printf("%lld
",st.a[1][1]);
    }
}
int main()
{
    scanf("%d%d",&n,&l);
    scanf("%d",&m);
    for (int i=1;i<=m;i++) scanf("%d%d",&sh[i].x,&sh[i].y);
    scanf("%d",&k);
    for (int i=1;i<=k;i++) scanf("%d",&a[i]);
    subtask2::solve();
}

65分:

对于循环矩阵的乘积可以发现这是一个循环卷积的形式,直接用$NTT$优化

细节:要将下标减一做$NTT$

时间复杂度$O(nmlognlogL)$

其实到这里离正解只差了一步

jzy就此与$AC$失之交臂

太棒了

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define mk make_pair
const int N=260;
const int LEN=66000;
const int STEP=505;
const int MOD=998244353;
int ADD(int x,int y){return x+y>=MOD ? x+y-MOD : x+y;}
int MUL(int x,int y){return 1ll*x*y%MOD;}
ll dp[STEP][N],bl[STEP][N];
pii limit[35];
   
int n,l,m,K,aa[100005];
void init()
{
    scanf("%d%d",&n,&l);
    scanf("%d",&m);
    for(int i=1;i<=m;i++) scanf("%d%d",&limit[i].first,&limit[i].second);
    scanf("%d",&K);
    for(int i=1;i<=K;i++) scanf("%d",&aa[i]);
}
   
void subtask1()
{
    for(int i=1;i<=m;i++)
        bl[limit[i].first][limit[i].second]=1;
    dp[0][0]=1;
    for(int i=1;i<=l;i++)
    {
        for(int j=0;j<n;j++)
        {
            for(int t=1;t<=K;t++)
            {
                int pos=(j+aa[t])%n;
                if(bl[i][pos]) continue;
                dp[i][pos]=(dp[i][pos]+dp[i-1][j])%MOD;
            }
        }
    }
    printf("%lld
",dp[l][0]);
}
   
int Qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1) ret=MUL(ret,x);
        x=MUL(x,x);
        y>>=1;
    }
    return ret;
}
   
struct matrix{
    int n,a[LEN];
    matrix(){
        memset(a,0,sizeof(a));
    }
    matrix(int n):n(n){
        memset(a,0,sizeof(a));
    }
};
   
int rev[LEN*2],len,k;
void change(int len,int k)
{
    rev[0]=0; rev[len-1]=len-1;
    for(int i=1;i<len-1;i++)
    {
        rev[i]=rev[i>>1]>>1;
        if(i&1) rev[i]+=(1<<(k-1));
    }
}
 
int Wn[LEN*2],Wn_1[LEN*2];
int inv_len;
void ntt(int a[],int len,int flag)
{
    for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int h=1;h<len;h<<=1)
    {
        int wn=Wn[h<<1];
        if(flag==-1) wn=Wn_1[h<<1];
        int tmp1,tmp2;
        for(int i=0;i<len;i+=h*2)
        {
            int w=1;
            for(int j=i;j<i+h;j++)
            {
                //w=w*wn;
                tmp1=a[j],tmp2=1LL*a[j+h]*w%MOD;
                a[j]=(tmp1+tmp2)%MOD;
                a[j+h]=(tmp1-tmp2+MOD)%MOD;
                w=1LL*w*wn%MOD;
            }
        }
    }
    if(flag==-1)
    {
        for(int i=0;i<=len;i++) a[i]=1LL*a[i]*inv_len%MOD;
    }
}
   
int a[LEN*2],b[LEN*2];
matrix operator * (matrix A,matrix B)
{
    memset(a,0,sizeof(a));
    memset(b,0,sizeof(b));
    for(int i=0;i<n;i++) a[i]=A.a[i];
    for(int i=0;i<n;i++) b[i]=B.a[i];
    ntt(a,len,1); ntt(b,len,1);
    for(int i=0;i<len;i++) a[i]=MUL(a[i],b[i]);
    ntt(a,len,-1);
    for(int i=0;i<n;i++) A.a[i]=ADD(a[i],a[i+n]);
    return A;
}
   
matrix qpow(matrix A,int y)
{
    matrix C(A.n); C.a[0]=1;
    while(y)
    {
        if(y&1) C=C*A;
        A=A*A;
        y>>=1;
    }
    return C;
}
   
matrix ans,Base;
int exi_step[LEN];
void build()
{
    memset(exi_step,0,sizeof(exi_step));
    for(int i=1;i<=K;i++) exi_step[aa[i]]++;
    for(int i=0;i<n;i++) Base.a[i]=exi_step[i];
}
   
void subtask2()
{
    int now_step=0;
    ans.n=n; Base.n=n; ans.a[0]=1;
    build();
    sort(limit+1,limit+m+1);
    matrix tmp(n);
    for(int i=1,j;i<=m;i=j+1)
    {
        j=i;
        while(j<m&&limit[j+1].first==limit[i].first) j++;
        tmp=qpow(Base,limit[i].first-now_step); now_step=limit[i].first;
        ans=ans*tmp;
        for(int t=i;t<=j;t++) ans.a[limit[t].second]=0;
    }
    Base=qpow(Base,l-now_step);
    ans=ans*Base;
    printf("%d
",ans.a[0]);
}
   
signed main()
{
    init();
    k=0,len=1;
    while(len<n+n) len<<=1,k++;
    change(len,k);
    for(int h=1;h<=len;h<<=1)
    {
        Wn[h]=Qpow(3,(MOD-1)/h);
        Wn_1[h]=Qpow(Wn[h],MOD-2);
    }
    inv_len=Qpow(len,MOD-2);
    subtask2();
}

100分:

其实65分的那个做法是在每一次矩阵的乘法中的时候都要做一遍$NTT$

其实并不需要,可以把一个循环矩阵看作一个多项式,其实就是一个多项式的快速幂(循环卷积)

将多项式$DFT$后转为点值表示形式后,直接对每一个点的点值做快速幂,然后$IDFT$还原回去

然后有一个细节,其实$DFT$实现的就是$len$长度的循环卷积,平时使用的$FFT$,$NTT$都是通过补$0$,来用循环卷积实现线性卷积

这道题中保证了$n$是$2$的次幂,直接$DFT$后做快速幂就可以了

$P.S.$对于任意长度循环卷积$CZT$,利用Bluestein’s Algorithm,网址

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include <bits/stdc++.h>
#define mod 998244353
#define ll long long
#define re register int
using namespace std;
const int N=1e5+100,M=21;
int n,l,m,k,a[N],cnt,rev[N];
ll st[N],tr[N];
struct node
{
    int x,y;
}sh[M];
inline int read()
{
    int f=1,x=0;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline void add(ll &x,ll y){x=(x+y)%mod;}
inline void mul(ll &x,ll y){x=(x*y)%mod;}
inline void del(ll &x,ll y){x=(x-y+mod)%mod;}
inline ll m_pow(ll a,int b)
{
    ll ans=1;
    while (b)
    {
        if (b&1) ans=(ans*a)%mod;
        b>>=1;
        a=(a*a)%mod;
    }
    return ans;
}
bool cmp(node a,node b)
{
    return a.x<b.x;
}
inline void change(int len)
{
    for (int i=0;i<len;i++)
    {
        rev[i]=rev[i>>1]>>1;
        if (i&1) rev[i]|=len>>1;
    }
}
inline void ntt(ll y[],int len,int v)
{
    for (int i=0;i<len;i++) if (i<rev[i]) swap(y[i],y[rev[i]]);
    for (int i=2;i<=len;i<<=1)
    {
        ll step=m_pow(3,(mod-1)/i);
        if (v==-1) step=m_pow(step,mod-2);
        for (int j=0;j<len;j+=i)
        {
            ll x=1;
            for (int k=j;k<j+i/2;k++)
            {
                ll a=y[k],b=(x*y[k+i/2])%mod;
                y[k]=(a+b)%mod;
                y[k+i/2]=(a-b+mod)%mod;
                x=(x*step)%mod;
            }
        }
    }
    if (v==-1)
    {
        int invlen=m_pow(len,mod-2);
        for (int i=0;i<len;i++) y[i]=(y[i]*invlen)%mod;
    }
}
int main()
{
    scanf("%d%d",&n,&l);
    scanf("%d",&m);
    for (re i=1;i<=m;++i) scanf("%d%d",&sh[i].x,&sh[i].y);
    scanf("%d",&k);
    for (re i=1;i<=k;++i) scanf("%d",&a[i]);
    for (re i=1;i<=k;++i) tr[a[i]%n]++;
    st[0]=1;
    sort(sh+1,sh+1+m,cmp);
    sh[0].x=0;
    change(n);
    ntt(tr,n,1);
    for (re i=1;i<=m;++i)
    {
        ntt(st,n,1);
        for (re j=0;j<n;j++) st[j]=(st[j]*m_pow(tr[j],sh[i].x-sh[i-1].x))%mod;
        ntt(st,n,-1);
        st[sh[i].y]=0;
    }
    ntt(st,n,1);
    for (re i=0;i<n;i++) st[i]=(st[i]*m_pow(tr[i],l-sh[m].x))%mod;
    ntt(st,n,-1);
    printf("%lld
",st[0]);
}
原文地址:https://www.cnblogs.com/huangchenyan/p/13387214.html