[SDOI2019]染色(DP)

好神的题啊!

看了这题只会第一个subtask,又参考了HN-CJ鸽王zsy的题解,实在太菜了。

暴力转移是O(nc2),很显然没有分。考虑子任务1,2,只需要转移包含已染色格子的列,然后状态数只有O(nc),对于关键两列(即有染色的列)间,只有5种状态。而这个可以初始化转移,转移讨论有点复杂,而且我不会用数学公式,就不打出吧。转移后即可直接DP。然后对于子任务3,4,把它们分割即可,把两边方案乘起来就行了,于是可以做到O(nc),得到96分的好成绩。然后听Claris所述,DP所有转移操作即为T1的操作,于是可以做到O(n+c)

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7,mod=1e9+9;
int n,m,tot,ans=1,a[N],b[N],c[N],p[N],g[N][5];
int qpow(int a,int b)
{
    int ret=1;
    while(b)
    {
        if(b&1)ret=1ll*ret*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return ret;
}
struct array{
    int mul,add,inv,sum,top,s[N<<2],f[N];
    array(){mul=inv=1;}
    void modify(int x,int v)
    {
        sum=(sum+1ll*(mod-f[x])*mul+mod-add)%mod;
        f[x]=1ll*(v-add+mod)*inv%mod,s[++top]=x,sum=(sum+v)%mod;
    }
    void plus(int v){sum=(sum+1ll*m*v)%mod,add=(add+v)%mod;}
    void cover(int v)
    {
        while(top)f[s[top--]]=0;
        mul=inv=1,add=v,sum=1ll*m*v%mod;
    }
    void mult(int v)
    {
        if(v)
        sum=1ll*sum*v%mod,mul=1ll*mul*v%mod,add=1ll*add*v%mod,inv=1ll*inv*qpow(v,mod-2)%mod;
        else cover(0);
    }
    int query(int x)
    {
        if(x)return(1ll*f[x]*mul+add)%mod;
        return sum;
    }
}F;
void trans(int x,int y,int z,int w)
{
    if(x!=y)
    {
        int fy=F.query(y),sum=(F.query(0)-fy+mod)%mod;
        F.mult((g[z][2^w]-g[z][4]+mod)%mod);
        F.plus((1ll*sum*g[z][4]+1ll*fy*g[z][3^w])%mod);
        F.modify(x,(1ll*sum*g[z][3^w]+1ll*fy*g[z][1^w])%mod);
        F.modify(y,0);
    }
    else{
        int sum=F.query(0);
        F.mult((g[z][0^w]-g[z][2^w]+mod)%mod);
        F.plus(1ll*sum*g[z][2^w]%mod);
        F.modify(x,0);
    }
}
void build(int x,int y,int z,int w)
{
    if(w==x)F.cover(g[z][2]),F.modify(x,0),F.modify(y,g[z][0]);
    else if(w==y)F.cover(g[z][3]),F.modify(x,g[z][1]),F.modify(y,0);
    else F.cover(g[z][4]),F.modify(x,g[z][3]),F.modify(y,g[z][2]),F.modify(w,0);
}
void solve(int x,int y,int z,int w)
{
    int ret=0;
    if(w==x)
    {
        int fy=F.query(y),sum=(F.query(0)-fy+mod)%mod;
        ret=(1ll*sum*g[z][2]+1ll*fy*g[z][0])%mod;
    }
    else if(w==y)
    {
        int fx=F.query(x),sum=(F.query(0)-fx+mod)%mod;
        ret=(1ll*sum*g[z][3]+1ll*fx*g[z][1])%mod;
    }
    else{
        int fx=F.query(x),fy=F.query(y),sum=(1ll*F.query(0)-fx-fy+2*mod)%mod;
        ret=(1ll*sum*g[z][4]+1ll*fx*g[z][3]+1ll*fy*g[z][2])%mod;
    }
    ans=1ll*ans*ret%mod;
}
int cal(int x,int y,int z,int u,int v)
{
    if(x==u)return y==v?g[z][0]:g[z][2];
    if(x==v)return y==u?g[z][1]:g[z][3];
    if(y==u)return g[z][3];
    if(y==v)return g[z][2];
    return g[z][4];
}
int main()
{
    scanf("%d%d",&n,&m);
    int mp[5][5]={{0,1,0,m-2,1ll*(m-2)*(m-3)%mod},
    {1,0,m-2,0,1ll*(m-2)*(m-3)%mod},{0,2,m-2,2*m-5,2ll*(m-3)*(m-3)%mod},
    {2,0,2*m-5,m-2,2ll*(m-3)*(m-3)%mod},{1,1,m-3,m-3,1ll*(m-3)*(m-4)%mod+1}};
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        if(a[i]&&a[i]==a[i-1]){puts("0");return 0;}
    }
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&b[i]);
        if(b[i]&&(a[i]==b[i]||b[i]==b[i-1])){puts("0");return 0;}
    }
    g[0][0]=1;
    for(int i=1;i<=n;i++)
    for(int j=0;j<5;j++)
    for(int k=0;k<5;k++)
    g[i][k]=(g[i][k]+1ll*g[i-1][j]*mp[j][k])%mod;
    int v1=qpow(m-2,mod-2),v2=qpow(1ll*(m-2)*(m-3)%mod,mod-2);
    for(int i=0;i<=n;i++)
    g[i][2]=1ll*g[i][2]*v1%mod,g[i][3]=1ll*g[i][3]*v1%mod,g[i][4]=1ll*g[i][4]*v2%mod;
    for(int i=1;i<=n;i++)
    if(a[i]||b[i])
    {
        p[++tot]=i;
        if(b[i]&&!a[i])swap(a[i],b[i]),c[i]=1;
        else if(b[i])c[i]=2;
    }
    int pw=qpow(1ll*(m-1)*(m-2)%mod+1,p[1]-1);
    if(c[p[1]]<2)F.cover(pw),F.modify(a[p[1]],0);else ans=pw;
    for(int i=2;i<=tot;i++)
    if(c[p[i-1]]==2)
        if(c[p[i]]==2)ans=1ll*ans*cal(a[p[i-1]],b[p[i-1]],p[i]-p[i-1],a[p[i]],b[p[i]])%mod;
        else{
            if(c[p[i]])swap(a[p[i-1]],b[p[i-1]]);
            build(a[p[i-1]],b[p[i-1]],p[i]-p[i-1],a[p[i]]);
            if(c[p[i]])swap(a[p[i-1]],b[p[i-1]]);
        }
    else if(c[p[i]]==2)
    {
        if(c[p[i-1]])swap(a[p[i]],b[p[i]]);
        solve(a[p[i]],b[p[i]],p[i]-p[i-1],a[p[i-1]]);
        if(c[p[i-1]])swap(a[p[i]],b[p[i]]);
    }
    else trans(a[p[i-1]],a[p[i]],p[i]-p[i-1],c[p[i]]^c[p[i-1]]);
    if(c[p[tot]]<2)ans=1ll*ans*F.query(0)%mod;
    ans=1ll*ans*qpow(1ll*(m-1)*(m-2)%mod+1,n-p[tot])%mod;
    printf("%d",ans);
}
View Code
原文地址:https://www.cnblogs.com/hfctf0210/p/10834784.html