LOJ#575. 「LibreOJ NOI Round #2」不等关系 容斥+分治NTT

容斥+分治NTT.    

令 $dp[i]$ 表示以 $i$ 结尾的方案数.  

如果只有小于号的话 $dp[i]$ 是非常好求的:$frac{n!}{prod a_{i}}$ 即总阶乘除以每一个小于号连续段.    

有大于号的时候考虑容斥:   

遇到第一个大于号的时候先不考虑当前位置关系,方案数就是 $dp[j] imes inom{i}{i-j}$.   

那么我们多加了当前位置是小于号的情况,需要在下一次减掉.  

遇到第二个大于号的时候也不考虑当前位置关系,减掉 $dp[j] imes inom{i}{i-j}$,这时将上面多加的减掉了,但是又多减了两个位置都是小于号的方案数.   

所以我们就得到了一个容斥式子:$dp[i]=sum_{j=0}^{i-1} [s_{j}='>'] (-1)^{c[i-1]-c[j]}dp[j] imes inom{i}{i-j}$     

这个式子可以用分治 NTT 优化到 $O(n log^2 n)$.  

code:  

#include <vector>
#include <cstdio> 
#include <cstring> 
#include <algorithm>   
#define N 100007    
#define ll long long 
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;  
char str[N];  
int A[N<<2],B[N<<2],bu[N];    
int c[N],fac[N],inv[N],dp[N],g[N],n;     
int qpow(int x,int y) {  
    int tmp=1;  
    for(;y;y>>=1,x=(ll)x*x%mod)  { 
        if(y&1) tmp=(ll)tmp*x%mod;  
    }  
    return tmp;  
}   
int get_inv(int x) { 
    return qpow(x,mod-2);  
}   
void NTT(int *a,int len,int op) {  
    for(int i=0,k=0;i<len;++i) { 
        if(i>k) swap(a[i],a[k]);  
        for(int j=len>>1;(k^=j)<j;j>>=1);  
    }   
    for(int l=1;l<len;l<<=1) {  
        int wn=qpow(3,(mod-1)/(l<<1));   
        if(op==-1) { 
            wn=get_inv(wn); 
        }       
        for(int i=0;i<len;i+=l<<1) { 
            int w=1,x,y;  
            for(int j=0;j<l;++j) { 
                x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod;  
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod;  
            }
        }
    }    
    if(op==-1) {      
        int iv=get_inv(len); 
        for(int i=0;i<len;++i) {  
            a[i]=(ll)a[i]*iv%mod;  
        }
    }
}
void init() {  
    fac[0]=inv[1]=1;  
    for(int i=1;i<N;++i) { 
        fac[i]=(ll)fac[i-1]*i%mod;  
    }  
    for(int i=2;i<N;++i) { 
        inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;  
    }    
    inv[0]=1;  
    for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod;  
}         
void solve(int l,int r) {   
    if(l==r) { 
        return;  
    }   
    int mid=(l+r)>>1,s1=0,s2=0,lim;  
    solve(l,mid);  
    for(int i=l;i<=mid;++i) {      
        if(str[i]=='<') A[s1++]=0;  
        else A[s1++]=(ll)dp[i]*bu[i+1]%mod;   
    }
    for(int i=0;i<=r-l;++i) {   
        B[s2++]=g[i];  
    }      
    for(lim=1;lim<(s1+s2);lim<<=1);  
    for(int i=s1;i<lim;++i) A[i]=0;  
    for(int i=s2;i<lim;++i) B[i]=0;  
    NTT(A,lim,1),NTT(B,lim,1);  
    for(int i=0;i<lim;++i) { 
        A[i]=(ll)A[i]*B[i]%mod;  
    }  
    NTT(A,lim,-1);   
    for(int i=mid+1;i<=r;++i) {    
        (dp[i]+=(ll)bu[i]*A[i-l]%mod)%=mod;  
    }
    for(int i=0;i<lim;++i) A[i]=B[i]=0;  
    solve(mid+1,r);   
}
int main() { 
    // setIO("input");  
    init();   
    scanf("%s",str+1);          
    n=strlen(str+1)+1;    
    for(int i=1;i<n;++i) {   
        c[i]=c[i-1]+(str[i]=='>');  
    }   
    for(int i=1;i<=n;++i) {    
        if(c[i-1]&1) bu[i]=mod-1;  
        else bu[i]=1;  
    }
    dp[0]=1;       
    for(int i=1;i<=n;++i) 
        g[i]=inv[i];       
    solve(0,n);  
    printf("%d
",(ll)dp[n]*fac[n]%mod);  
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/13356438.html