dtoj#4138. 染色(ranse)

 

题目描述:

算法标签:斯特林树,分治ntt

思路:

对于固定一个点,讨论他有几个叶子的情况,观察规律发现他的系数斯特林数。于是我们可以求出对于每一个节点建成多少个联通快的方案树,再把每一个节点的方案书卷积起来,用分治ntt维护。

对于求斯特林数,可以先求出小的一部分,对于大的一部分,再用斯特林数的通式求法。

以下代码:

#include<bits/stdc++.h>
#define il inline
#define LL long long
#define vet vector<int>
#define _(d) while(d(isdigit(ch=getchar())))
using namespace std;
const int N=3e5+5,p=998244353;
int G[2],jc[N],ny[N],in[N];
int n,c,k,sz,s[318][318],a[N],b[N],num,sum[N];
il int read(){
    int x,f=1;char ch;
    _(!)ch=='-'?f=-1:f;x=ch^48;
    _()x=(x<<1)+(x<<3)+(ch^48);
    return f*x;
}
il int ksm(LL a,int y){
    LL b=1;
    while(y){
        if(y&1)b=b*a%p;
        a=a*a%p;y>>=1;
    }
    return b;
}
il int mu(int x,int y){
    if(x+y>=p)return x+y-p;
    return x+y;
}
class ntt{
    int v[N],t,l;
    il void init(int x){
        t=1;l=0;
        while(t<=x)t<<=1,l++;
        for(int i=0;i<t;i++)v[i]=(v[i>>1]>>1)|((i&1)<<l-1);
    }
    il void dft(int *x,int op){
        for(int i=0;i<t;i++)if(i<v[i])swap(x[i],x[v[i]]);
        for(int i=1;i<t;i<<=1){
            int wn=ksm(G[op],(p-1)/(i<<1));
            for(int j=0;j<t;j+=i<<1){
                for(int k=0,w=1;k<i;k++,w=1ll*w*wn%p){
                    int A=x[j+k],B=1ll*x[i+j+k]*w%p;
                    x[j+k]=mu(A,B);x[i+j+k]=mu(A,p-B);
                }
            }
        }
        if(op){
            int kk=ksm(t,p-2);
            for(int i=0;i<t;i++)x[i]=1ll*x[i]*kk%p;    
        }
    }
public:
    il void mult(int x){
        init(x);
        dft(a,0);dft(b,0);
        for(int i=0;i<t;i++)a[i]=1ll*a[i]*b[i]%p;
        dft(a,1);
    }
    il void clear(){
        for(int i=0;i<t;i++)a[i]=b[i]=0;
    }
}T;
il vet work(int x){
    vet res;res.resize(x+1);
    if(x<=sz){
        for(int i=1;i<=x;i++)res[i]=s[x][i];
    }
    else{
        for(int i=0;i<=x;i++){
            if(i&1)a[i]=p-ny[i];else a[i]=ny[i];
            b[i]=1ll*ksm(i,x)*ny[i]%p;
        }
        T.mult(x<<1);
        for(int i=1;i<=x;i++)res[i]=a[i];
        T.clear();
    }
    if(num^x){
        for(int i=0;i<x;i++)res[i]=res[i+1];
        int tmp=1;res.resize(x);
        for(int i=1;i<x;i++)tmp=1ll*(c-i)*tmp%p,res[i]=1ll*res[i]*tmp%p;
    }
    else{
        num=0;
        int tmp=1;
        for(int i=1;i<=x;i++)tmp=1ll*(c-i+1)*tmp%p,res[i]=1ll*res[i]*tmp%p;
    }
    return res;
}
il vet Solve(int l,int r){
    if(l==r)return work(in[l]);
    int mid=upper_bound(sum+l,sum+r+1,sum[l-1]+((sum[r]-sum[l-1])>>1))-sum-1;
    vet res1=Solve(l,mid),res2=Solve(mid+1,r),res;
    int s1=res1.size(),s2=res2.size(),s=s1+s2-1;
    res.resize(s);
    for(int i=0;i<s1;i++)a[i]=res1[i];
    for(int i=0;i<s2;i++)b[i]=res2[i];
    T.mult(s);
    for(int i=0;i<s;i++)res[i]=a[i];
    T.clear();
    return res;
}
int main()
{
    n=read();c=read();k=read();sz=(int)sqrt(n);
    jc[0]=1;for(int i=1;i<=n;i++)jc[i]=1ll*i*jc[i-1]%p;
    ny[n]=ksm(jc[n],p-2);for(int i=n;i;i--)ny[i-1]=1ll*i*ny[i]%p;
    s[0][0]=1;G[0]=3;G[1]=ksm(3,p-2);
    for(int i=1;i<=sz;i++){
        for(int j=1;j<=i;j++)
            s[i][j]=mu(s[i-1][j-1],1ll*j*s[i-1][j]%p);
    }
    for(int i=1;i<n;i++)in[read()]++,in[read()]++;
    num=in[1];sort(in+1,in+1+n);
    for(int i=1;i<=n;i++)sum[i]=sum[i-1]+in[i];
    vet res=Solve(1,n);
    int ans=0;
    for(int i=1;i<n;i++)ans=mu(ans,1ll*ksm(i,k)*res[i]%p);
    printf("%d
",ans);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/Jessie-/p/10416244.html