uoj#340. 【清华集训2017】小 Y 和恐怖的奴隶主(矩阵加速)

传送门

uoj上的数据太毒了……也可能是我人傻常数大的缘故……

三种血量的奴隶主加起来不超过(8)个,可以枚举每种血量的奴隶主个数,那么总的状态数只有(165)种,设(dp_{t,i,j,k})表示(t)时刻的时候(i)个一血奴隶主,(j)个二血奴隶主,(k)个三血奴隶主的概率,那么转移很明显

if(A)dp[i][A][B][C]+=dp[i-1][A-1][B][C]*A/tot;
if(B)dp[i][A][B][C]+=dp[i-1][A+1][B-1][C+ADD]*B/tot;
if(C)dp[i][A][B][C]+=dp[i-1][A][B+1][C-1+ADD]*C/tot;

(tot)(A+B+C+1)(ADD)代表当前奴隶主总数是否小于(k)

然后期望伤害的话再记一个(E)数组就可以了

上面那个其实就是本题弱化版的题解,本题的话还需要一个矩阵加速

顺带一提,记得卡常

//minamoto
#include<bits/stdc++.h>
#define R register
#define ll long long
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
ll read(){
    R ll res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R int x){
    if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='
';
}
const int N=185,M=15,P=998244353;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
int len,k,m;ll n,y;map<ll,int>mp;
int inv[N],id[M][M][M],ans[N],res[N];
struct Matrix{
    int a[N][N];
    Matrix(){memset(a,0,sizeof(a));}
    inline int* operator [](const int &x){return a[x];}
    friend Matrix operator *(Matrix a,Matrix b){
        Matrix res;
        fp(i,1,len+1)fp(k,1,len+1)if(a[i][k])fp(j,1,len+1){
            res[i][j]=add(res[i][j],1ll*a[i][k]*b[k][j]%P);
        }
        return res;
    }
}dp[N];
#define rep fp(A,0,k)fp(B,0,(m>1?k-A:0))fp(C,0,(m>2?k-A-B:0))
void init(){
    inv[0]=inv[1]=1;
    fp(i,2,10)inv[i]=1ll*inv[P%i]*(P-P/i)%P;
    rep id[A][B][C]=++len;
    rep{
        int i=id[A][B][C],Inv=inv[A+B+C+1],add=(A+B+C<k);
        if(m==1){
            if(A)dp[0][i][id[A-1][B][C]]=1ll*A*Inv%P;
        }else if(m==2){
            if(A)dp[0][i][id[A-1][B][C]]=1ll*A*Inv%P;
            if(B)dp[0][i][id[A+1][B-1+add][C]]=1ll*B*Inv%P;
        }else{
            if(A)dp[0][i][id[A-1][B][C]]=1ll*A*Inv%P;
            if(B)dp[0][i][id[A+1][B-1][C+add]]=1ll*B*Inv%P;
            if(C)dp[0][i][id[A][B+1][C-1+add]]=1ll*C*Inv%P;
        }
        dp[0][i][i]=dp[0][i][len+1]=Inv;
    }
    dp[0][len+1][len+1]=1;
    fp(i,1,60)dp[i]=dp[i-1]*dp[i-1];
}
void Mul(R int *A,R Matrix &B){
    fp(i,1,len+1)res[i]=0;
    fp(i,1,len+1)if(A[i])fp(j,1,len+1)res[j]=add(res[j],1ll*A[i]*B[i][j]%P);
    fp(i,1,len+1)A[i]=res[i];
}
int main(){
//	freopen("testdata.in","r",stdin);
    int T=read();m=read(),k=read();
    init();
    while(T--){
        n=read();
		if(mp.count(n)){print(mp[n]);continue;}
        fp(i,1,len+1)ans[i]=0;
        ans[id[m==1][m==2][m==3]]=1;
        y=n;
        for(R int i=0;y;++i,y>>=1)if(y&1)Mul(ans,dp[i]);
		mp[n]=ans[len+1];
        print(ans[len+1]);
    }
    return Ot(),0;
}
原文地址:https://www.cnblogs.com/bztMinamoto/p/10272438.html