BZOJ 5322: [Jxoi2018]排序问题 模拟+贪心

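细节比较多，好多地方容易写挂。

思路：答案为 $\frac{(n+m)!}{\prod_v c_v!} \bmod 998244353$（$c_v$ 为值 $v$ 的出现次数），要使其最大就要让分母最小。因此贪心地把新加入的 $m$ 个数每次放到区间 $[l,r]$ 内当前出现次数最少的值上；当区间内所有值的出现次数都被拉平后，剩下的数再平均分配给它们。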

code: 

#include <cstdio>
#include <map>    
#include <string> 
#include <algorithm> 
#define N 200005            
#define ll long long              
#define MAXN 11000000   
#define mod 998244353    
using namespace std; 
namespace IO
{
    // Buffered stdin reader. nc() yields the next input byte as a value in
    // 0..255, or EOF once fread() reports that the input is exhausted.
    char buf[100000],*p1,*p2;
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:(unsigned char)*p1++)

    // Reads the next non-negative decimal integer from stdin.
    // Any non-digit characters before the number are skipped, and reading
    // stops at the first non-digit after it (that character is consumed).
    // Returns 0 if the input ends before any digit is seen, instead of
    // spinning forever on EOF. A '-' sign is skipped, not interpreted.
    int rd()
    {
        int x=0;
        int s=nc();  // int, not char, so EOF stays distinguishable
        while(s!=EOF&&(s<'0'||s>'9')) s=nc();
        while(s>='0'&&s<='9') x=x*10+(s-'0'),s=nc();
        return x;
    }

    // Writes a non-negative integer to stdout in decimal (no separator).
    void print(int x) {if(x>=10) print(x/10);putchar(x%10+'0');}

    // Debug helper: redirects stdin/stdout to "<s>.in" / "<s>.out".
    void setIO(std::string s)
    {
        std::string in=s+".in";
        std::string out=s+".out";
        freopen(in.c_str(),"r",stdin);
        freopen(out.c_str(),"w",stdout);
    }
}
// Global tables shared by main() and solve():
//   fac[k] = k! mod p and inv[k] = (k!)^-1 mod p, for 0 <= k < MAXN;
//   a[1..n] holds the current test case's existing values;
//   bu[c]   counts values in [l, r] that currently occur c times (c <= n+1);
//   in[k]   is the modular inverse of the integer k, filled for k < N.
// bu and in are never indexed at or beyond n+2 <= N, so they are sized N
// rather than MAXN (saves roughly 86 MB).
int fac[MAXN],inv[MAXN],a[N],bu[N],in[N];
// Binary (square-and-multiply) exponentiation: returns x^y modulo `mod`.
// Intended for y >= 0; each bit of y decides whether the current power of
// x is folded into the result.
int qpow(int x,int y)
{
    long long base=x;
    long long result=1;
    while(y)
    {
        if(y&1)
            result=result*base%mod;
        base=base*base%mod;
        y>>=1;
    }
    return (int)result;
}
void solve() 
{
    using namespace IO;  
    int n,m,l,r,ty=0,len,tot,i,j=0,k; 
    n=rd(),m=rd(),l=rd(),r=rd(),len=r-l+1,tot=fac[n+m];        
    for(i=1;i<=n;++i) a[i]=rd();  
    sort(a+1,a+1+n);   
    de2+=n;   
    for(i=1;i<=n;i=k) 
    {
        k=i; 
        while(a[k]==a[i]&&k<=n) ++k;    
        int cur=k-i;                    
        if(a[i]>=l&&a[i]<=r) 
        { 
            ++bu[cur];    
            ++ty; 
            j=max(j,cur);    
            tot=(ll)tot*inv[cur]%mod;    
        }
        else
        {
            tot=(ll)tot*inv[cur]%mod;  
        }
    }        
    bu[0]=len-ty;             
    for(i=0;i<=j;++i) 
    {
        tot=(ll)tot*qpow(in[i+1],min(bu[i],m))%mod;         
        m-=min(bu[i],m);                 
        if(!m) break;   
        bu[i+1]+=bu[i];    
    }                 
    if(i==j+1&&m) 
    {            
        int t=bu[j+1];       
        int st=j+2;     
        int ed=j+1+(m/t);        
        int remain=m-(m/t)*t;       
        int de=qpow((ll)fac[ed]*inv[st-1]%mod,t);     
        tot=(ll)tot*qpow(de,mod-2)%mod;    
        if(remain) tot=(ll)tot*qpow(qpow(ed+1,remain),mod-2)%mod;                  
    } 
    printf("%d
",tot);               
    for(i=0;i<=j+1;++i) bu[i]=0;         
}
int main() 
{
    using namespace IO; 
    // setIO("input");   
    int T,i,j;  
    T=rd();   
    fac[0]=1;  
    for(i=1;i<MAXN;++i) fac[i]=(ll)fac[i-1]*i%mod;   
    inv[MAXN-1]=qpow(fac[MAXN-1],mod-2);  
    for(i=MAXN-1;i;--i) inv[i-1]=(ll)inv[i]*i%mod; 
    in[0]=in[1]=1;             
    for(i=2;i<N;++i)  in[i]=(ll)(mod-mod/i)*in[mod%i]%mod;           
    while(T--)solve();    
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/12245944.html