【考试总结】20220108

传统题

考虑枚举 \(n\) 最终被划分成了 \(i\) 段,那么直接对其染色附加系数就是 \(m\times(m-1)^{i-1}\)

\(f(k)\) 表示最大连续段长度为 \(k\) 的序列的方案数,不难写出:

\[f(k)=\sum_{i=1}^{n} m(m-1)^{i-1}[x^n]\left(\frac{1-x^{k+1}}{1-x}\right)^i \]

后面一部分可以枚举几个选择的次数大于限制的 \(k\),剩下的是一个插板法,得到:

\[f(k)=\sum_{i=1}^{n} m(m-1)^{i-1}\sum_{j=0}^i(-1)^j\binom{i}j\binom{n-j(k+1)+i-1}{i-1} \]

使用吸收恒等式将这后面的 \(1\) 取代了,并交换枚举顺序,答案可以表达为

\[Ans=n^m-\sum_{i=0}^{n-1}\sum_{k=0}^n (-1)^j\frac{1}{n-ik}\sum_{j=i}^nm(m-1)^{j-1}j\binom jk\binom{n-ik}{j} \]

前面部分已经可以调和级数枚举了,后半部分考虑组合意义:

计算表达式是表示在 \(n-ik\) 里面先选择 \(j\) 个球,选择其中的 \(1\) 个特殊球并让其它球染上 \(m-1\) 种颜色,再在选中的球里面选择 \(k\) 个,分开讨论特殊球有没有被 \(\binom{j}k\) 选中:

  • 如果被 \(\binom jk\) 选中,那么组合意义就是 \(k\) 个球中还有 \(k-1\) 个球要染色,再在 \(n-(i+1)k\) 个球中选择大小为 \(j-k\) 的子集染色,注意 \(j\) 是求和,那么就是在 \(n-(i+1)k\) 个中选择一个子集染色

    \(m-1\) 种颜色扩展为 \(m\) 色,最后一种表示是否被选择了,用二项式定理理解也可以

    这部分表达式为 \(\binom{n-ik}{k}k(m-1)^{k-1}m^{n-(i+1)k}\)

  • 如果没有被选中,那么就在 \(n-(i+1)k\) 个中选择一个特殊球,剩下的和上面类似,表达式写作:\(\binom{n-ik}{k}(n-(i+1)k)(m-1)^{k}m^{n-(i+1)k-1}\)

其实这些都是可以代数推导的,详见大佬ycx的草稿纸

Code Display
const int N=3e5+10;
int pw0[N],pw1[N],inv[N],ifac[N],fac[N];
inline int C(int n,int m){return mul(fac[n],mul(ifac[m],ifac[n-m]));}
signed main(){
    n=read(); m=read(); mod=read();
    fac[0]=inv[0]=pw0[0]=pw1[0]=1; 
    for(int i=1;i<=n;++i) pw1[i]=mul(pw1[i-1],m-1),pw0[i]=mul(pw0[i-1],m);
    rep(i,1,n) fac[i]=mul(fac[i-1],i);
    ifac[n]=ksm(fac[n],mod-2);
    Down(i,n,1) ifac[i-1]=mul(ifac[i],i),inv[i]=mul(ifac[i],fac[i-1]);
    int ans=0;
    for(int i=0;i<n;++i){
        for(int k=0;k<=n;++k) if(n>=(i+1)*k){
            int val=mul(C(n-i*k,k),inv[n-i*k]);
            int v1=k?k*pw1[k-1]%mod*pw0[n-(i+1)*k]%mod:0;
            int v2=n-(i+1)*k-1>=0?pw1[k]*(n-(i+1)*k)%mod*pw0[n-(i+1)*k-1]%mod:0;
            ckmul(val,add(v1,v2));
            if(k&1) ckdel(ans,val); else ckadd(ans,val);
        }else break;
    }
    print(del(mul(n,pw0[n]),mul(ans,m)));
    return 0;
}

生成树

将红色边视为 \(1\),绿色边视为 \(x\),蓝色边视为 \(y\),那么答案是 \(\sum_{i=0}^G\sum_{j=0}^B[x^iy^j]F(x)\)

其中 \(F(x)\) 表示直接做矩阵树定理得到的多项式

不难发现这个多项式有 \(\frac{n(n-1)}2\) 个系数那么带等量点值进去做高斯消元即可跑过去,因为系数数量三方之后作为 \(\frac 18\) 的常数

为什么这样子消元一定有解呢?

简记 \(1\) 维范德蒙德行列式的求值过程:

从底向上将每行行乘 \(x_0\) 减到上一行,再取代数余子式(\(A_{0,0}\) 总能是 \(1\)),提出来的公因式就是每组 1\(x_i-x_0\left(i\in [1,n-1]\right)\)

后面是子问题,可以递归求解,结论是一阶范德蒙德行列式值为 \(\prod_{0\le i<j<n}x_j-x_i\)

二维 \([n\times n\times n\times n]\) 范德蒙德行列式的形状是 \(M_{i,j}=x_{i\%n}^{j/n}y_{i\%n}^{j/n}\left(i,j\in[0,n-1]\right)\)

消元的过程是类似的,将 \(n^4\) 的矩阵分成 \(n^2\)\(n\times n\) 的小块,以这个为单位消,还是让后面的列乘上前面的列乘 \(x_0\),考虑到这里提出公因式的时候每块有 \(n\) 个行,所以结论是

\[\rm Det=\prod_{0\le i<j<n}(x_j-x_i)^n(y_j-y_i)^n \]

这时候根据克拉默法则,系数矩阵是满秩的(行列式不为 \(0\) 的),直接选择一些个互不相同的点值就能得到系数了

关于范德蒙德行列式扩展到更高维度 \(d\) 的式值,考察提取公因子的时候块的个数知道是类似

\[\prod_{0\le i<j<n}\prod_{d=0}^{k-1} (x_{d\to j}-x_{d\to i})^{n^{d-1}} \]

状物

更严谨应该将每个维度的上界设为互不相同,但是这并不是问题所在

Code Display
int a[50][50],n,poly[1100];
inline int Det(){
    int ans=1;
    for(int i=1;i<n;++i){
        if(!a[i][i]){
            for(int j=i+1;j<n;++j) if(a[j][i]){
                swap(a[i],a[j]);
                ans=del(0,ans); 
                break;
            }
        }
        if(!a[i][i]) return 0;
        int inv=ksm(a[i][i],mod-2);
        ckmul(ans,a[i][i]);
        for(int j=i+1;j<n;++j){
            int tmp=mul(a[j][i],inv); if(!tmp) continue;
            for(int k=i;k<n;++k) ckdel(a[j][k],mul(a[i][k],tmp));
        }
    }
    return ans;
}
int tot,id[50][50],m,B,G,col[4][50][50],mat[1000][1000],pw[210][210];
signed main(){
    pw[0][0]=1;
    rep(i,1,200){
        pw[i][0]=1;
        rep(j,1,200) pw[i][j]=mul(pw[i][j-1],i);
    }
    n=read(); m=read(); G=read(); B=read();
    for(int i=1;i<=m;++i){
        int a=read(),b=read(),c=read();
        col[c][a][b]++; col[c][b][a]++;
    }
    for(int x=0;x<n;++x){
        for(int y=0;x+y<n;++y) id[x][y]=++tot;
    }
    tot=0;
    rep(vx,0,n-1) rep(vy,0,n-1-vx){
        memset(a,0,sizeof(a)); ++tot;
        int ex=vx+10,ey=vy+n+11;
        for(int x=0;x<n;++x){
            for(int y=0;x+y<n;++y){
                mat[tot][id[x][y]]=mul(pw[ex][x],pw[ey][y]);
            }
        }
        for(int i=1;i<=n;++i){
            for(int j=1;j<=n;++j) if(col[1][i][j]){
                a[i][j]-=col[1][i][j];
                a[i][i]+=col[1][i][j];
            }
        }
        for(int i=1;i<=n;++i){
            for(int j=1;j<=n;++j) if(col[2][i][j]){
                int tmp=ex*col[2][i][j];
                a[i][j]-=tmp;
                a[i][i]+=tmp;
            }
        }
        for(int i=1;i<=n;++i){
            for(int j=1;j<=n;++j) if(col[3][i][j]){
                int tmp=ey*col[3][i][j];
                a[i][j]-=tmp;
                a[i][i]+=tmp;
            }
        }
        rep(i,1,n) rep(j,1,n) a[i][j]=(a[i][j]%mod+mod)%mod;
        mat[tot][id[n-1][0]+1]=Det();
    }
    for(int i=1;i<=tot;++i){
        if(!mat[i][i]){
            for(int j=i+1;j<=tot;++j) if(mat[j][i]){
                swap(mat[i],mat[j]);
                break;
            }
        }
        if(!mat[i][i]) continue;
        int inv=ksm(mat[i][i],mod-2);
        for(int j=i+1;j<=tot;++j){
            int tmp=mul(mat[j][i],inv); if(!tmp) continue;
            for(int k=i;k<=tot+1;++k) ckdel(mat[j][k],mul(mat[i][k],tmp));
        }
    }
    poly[tot]=mul(mat[tot][tot+1],ksm(mat[tot][tot],mod-2));
    Down(i,tot-1,1) if(mat[i][i]){
        for(int j=i+1;j<=tot;++j) ckdel(mat[i][tot+1],mul(mat[i][j],poly[j]));
        poly[i]=mul(mat[i][tot+1],ksm(mat[i][i],mod-2));
    }
    int ans=0;
    for(int x=0;x<=G;++x){
        for(int y=0;y<=B;++y) ckadd(ans,poly[id[x][y]]);
    }
    print(ans);
    return 0;
}

最短路径

树上的部分和 BZOJ Normal 一题完全相同,不再赘述,而基环树中不跨过环的部分也要同样计算

设环长为 \(m\) 并将环写作一个序列,那么前 \([1,\frac m2],[\frac m2+1,m]\) 中各点向区间里面的点的路径是固定的,使用朴素的分治统计:设若处理区间 \([l,r]\) 中点 \(mid\)\([l,mid]\)\([mid+1,r]\) 递归处理,现在只算跨过分治中点的点对的贡献

那么处理左边每个点到 \(mid\) 的距离,右边每个点到 \(mid+1\) 的距离,卷起来再平移一位就可以了

对于 \([1,\frac m2],[\frac m2+1,m]\) 两个区间内各取一点的距离和,由于最短路不能确定,那么左边对半划开记作区间 \(1,2\),右边对半划开记作区间 \(3,4\)

这时候不难发现 区间 2 到 区间 3 的最短路是不走 \(\rm edge(1,m)\) 而区间 \(1,4\) 的最短路是走该边,这两部分贡献被确定,递归求解 区间 \(1,3\) 和 区间 \(2,4\) 即可

复杂度不好分析的是系数平移的部分,但是看起来并不能让计算量超过 \(4\times 10^9\) 所以用了个 vector 配 O2 还是比较快的

Code Display
inline poly Mul(poly a,poly b){
    int n=a.size(),m=b.size(),lim=1;
    while(lim<=(n+m)) lim<<=1;
    NTT(a,lim,1); NTT(b,lim,1);
    for(int i=0;i<lim;++i) ckmul(a[i],b[i]);
    NTT(a,lim,-1);
    a.resize(n+m-1);
    return a;
}
int u[N],v[N];
bool vis[N],ins[N],onc[N];
int f[N],ans,siz[N],rt,nsum,nrt;
inline void findrt(int x,int fat){
    f[x]=0; siz[x]=1;
    for(auto t:g[x]) if(!vis[t]&&(!onc[t]||t==nrt)&&t!=fat){
        findrt(t,x); siz[x]+=siz[t];
        ckmax(f[x],siz[t]);
    }
    ckmax(f[x],nsum-siz[x]);
    if(f[x]<f[rt]) rt=x;
    return ;
}
int ar[N],dmx;
inline void get_node(int x,int fat,int ndep){
    ar[ndep]++; ckmax(dmx,ndep);
    for(auto t:g[x]) if(!vis[t]&&t!=fat&&(t==nrt||!onc[t])) get_node(t,x,ndep+1);
    return ;
}
inline vector<int> turn(){
    vector<int> res; res.resize(dmx+1);
    for(int i=0;i<=dmx;++i) res[i]=ar[i],ar[i]=0;
    return res;
}    
int all[N];
inline void upd(vector<int> p,int fl){
    for(int i=0;i<p.size();++i) all[i]+=p[i]*fl;
    return ;
}
inline void solve(int x){
    vis[x]=1;
    int amx=0;
    vector<int> now={1};
    for(auto t:g[x]) if(!vis[t]&&(!onc[t]||t==nrt)){
        get_node(t,x,1);
        vector<int> p=turn();
        upd(Mul(p,p),-1);
        now.resize(max(now.size(),p.size()));
        for(int i=0;i<p.size();++i) now[i]+=p[i];
        dmx=0;
    }
    upd(Mul(now,now),1); now.clear(); now.shrink_to_fit();
    for(auto t:g[x]) if(!vis[t]&&(!onc[t]||t==nrt)){
        nsum=siz[t]; rt=0; findrt(t,0);
        solve(rt);
    }
    return ;
}
vector<int> G[N];
int stk[N],top,cir[N],num;
inline void findcir(int x,int pre){
    if(ins[x]){
        do cir[++num]=stk[top]; while(stk[top--]!=x);
        return ;
    }
    ins[x]=1; stk[++top]=x;
    for(auto e:G[x]) if(e^pre){
        int to=u[e]==x?v[e]:u[e];
        findcir(to,e);
        if(num) return ;
    }
    ins[x]=0; top--; return ;
}
poly dis[N];
inline void getpoly(int x,int fat,int ndep){
    nsum++;
    if(dis[nrt].size()<=ndep) dis[nrt].pb(1);
    else dis[nrt][ndep]++;
    for(int t:g[x]) if(t!=fat&&!onc[t]){
        getpoly(t,x,ndep+1);
    } return ;
}
inline void Merge(int l1,int r1,int l2,int r2,int shift){
    if(r1<l1||r2<l2) return ;
    poly lef,rig;
    for(int i=r1;i>=l1;--i){
        int siz=dis[cir[i]].size();
        for(int j=0;j<siz;++j){
            if(lef.size()<=j+r1-i) lef.pb(dis[cir[i]][j]);
            else lef[j+r1-i]+=dis[cir[i]][j];
        }
    }
    for(int i=l2;i<=r2;++i){
        int siz=dis[cir[i]].size();
        for(int j=0;j<siz;++j){
            if(rig.size()<=j+i-l2) rig.pb(dis[cir[i]][j]);
            else rig[j+i-l2]+=dis[cir[i]][j];
        }
    }
    vector<int> ee=Mul(lef,rig);
    for(int i=0;i<ee.size();++i){
        all[i+shift]+=ee[i];
    }
}
inline void solve_poly(int l,int r){
    if(l==r) return ;
    int mid=(l+r)>>1;
    solve_poly(l,mid); solve_poly(mid+1,r);
    Merge(l,mid,mid+1,r,1);
    return ;
}
inline void conquer(int l1,int r1,int l2,int r2){
    if(r1<l1||r2<l2) return ; 
    if(l1==r1&&l2==r2){
        return Merge(l1,r1,l2,r2,min(r2-l1,l1+num-r2));
    }
    int mid1=(l1+r1)>>1,mid2=(l2+r2)>>1;
    Merge(mid2+1,r2,l1,mid1,l1+num-r2); Merge(mid1+1,r1,l2,mid2,l2-r1);
    conquer(l1,mid1,l2,mid2); conquer(mid1+1,r1,mid2+1,r2);
}
signed main(){
    n=read(); k=read(); int invn=ksm(n*(n-1)/2%mod,mod-2);
    bool tree=0;
    for(int i=1;i<=n;++i){
        u[i]=read(),v[i]=read();
        if(u[i]==v[i]) tree=1;
        else g[u[i]].pb(v[i]),g[v[i]].pb(u[i]); 
    }
    f[0]=2e9;
    if(tree){
        nsum=n; findrt(1,0);
        solve(rt);
        int ans=0;
        for(int i=1;i<n;++i) ckadd(ans,mul(all[i],ksm(i,k)));
        print(mul((mod+1)/2,mul(ans,invn)));
        exit(0);
    }
    rep(i,1,n) if(u[i]!=v[i]) G[u[i]].pb(i),G[v[i]].pb(i);
    findcir(1,0);
    rep(i,1,num) onc[cir[i]]=1;
    rep(i,1,num){
        rt=nsum=0; getpoly(nrt=cir[i],0,0);
        findrt(cir[i],0);
        solve(rt);
    }
    for(auto &t:all) t/=2;
    int mid=(num+1)>>1;
    solve_poly(1,mid); solve_poly(mid+1,num);// clear distance pairs
    conquer(1,mid,mid+1,num);
    int ans=0;
    for(int i=1;i<=n;++i) ckadd(ans,mul(all[i],ksm(i,k)));
    print(mul(ans,invn));
    return 0;
}
原文地址:https://www.cnblogs.com/yspm/p/15779395.html