题解 P5860 【「SWTR-03」Counting Trees】

题意

给定 $n$ 个点,每个点的度数 $v_i$ 已知。

求选出若干个点、使它们恰好能组成一棵树(度数与给定值一致)的方案数。

题解

对于一棵有 $n$ 个点的树,有 $n-1$ 条边,故 $\sum_{i=1}^n deg_i=2(n-1)$,即 $\sum_{i=1}^n(deg_i-2)=-2$。所以对选出的点,$\sum_i(v_i-2)=-2$ 是一个必要条件。(当所有 $v_i\ge 1$ 时它也是充分条件,证明留给读者自证。)

看作生成函数:对于 $v_i$,其生成函数就是 $1+x^{v_i-2}$,两项分别对应不选和选这个点。

$$F(x)=\prod_{i=1}^n\left(1+x^{v_i-2}\right),\qquad ans=[x^{-2}]F(x)$$

可是负数次幂很讨厌,没法直接套用多项式算法计算,因此需要乱搞(做一些变形)。

$$F(x)=\overbrace{\prod_{v_i=1}\left(1+x^{-1}\right)}^{m\text{ 个}}\;\overbrace{\prod_{v_i=2}(1+1)}^{k\text{ 个}}\;\overbrace{\prod_{v_i>2}\left(1+x^{v_i-2}\right)}^{n-m-k\text{ 个}}$$

其中 $m$、$k$ 分别为度数为 $1$、$2$ 的点数。

那么给所有 $1+x^{-1}$ 乘上 $x$,它们就变成了 $1+x$。记多重集 $a=\{\,|v_i-2| \;:\; v_i=1 \text{ 或 } v_i>2\,\}$,则有

$$
\begin{aligned}
x^mF(x)&=2^k\prod_{i}\left(1+x^{a_i}\right)\\
&=2^k\exp\left(\sum_{i}\ln\left(1+x^{a_i}\right)\right)\\
&=2^k\exp\left(\sum_{i}\int\left(\ln\left(1+x^{a_i}\right)\right)'\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\int\frac{a_ix^{a_i-1}}{1+x^{a_i}}\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\int a_ix^{a_i-1}\sum_{j=0}^\infty\left(-x^{a_i}\right)^j\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\int a_ix^{a_i-1}\sum_{j=0}^\infty(-1)^jx^{a_ij}\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\int\sum_{j=0}^\infty(-1)^ja_ix^{a_i(j+1)-1}\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\sum_{j=0}^\infty(-1)^j\int a_ix^{a_i(j+1)-1}\,\mathrm{d}x\right)\\
&=2^k\exp\left(\sum_{i}\sum_{j=0}^\infty(-1)^j\frac{a_ix^{a_i(j+1)}}{a_i(j+1)}\right)\\
&=2^k\exp\left(\sum_{i}\sum_{j=1}^\infty(-1)^{j-1}\frac{x^{a_ij}}{j}\right)
\end{aligned}
$$

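这就是一个小套路:$\ln\left(1+x^{a}\right)=\sum_{j\ge1}(-1)^{j-1}\frac{x^{aj}}{j}$。把相同的 $a_i$ 合并计数后,由于只需要 $x^0,\dots,x^{m-1}$ 的系数,按调和级数枚举每个 $a_i$ 的倍数,即可在 $O(m\log m)$ 内求出指数里的多项式,再做一次多项式 $\exp$。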

最后 $ans=[x^{-2}]F(x)=[x^{m-2}]\,x^mF(x)$ 就行了。(若 $m<2$,不可能组成树,答案为 $0$。)

代码

最后送上萌新慢的要死的代码

#include<bits/stdc++.h>
//#define faster
namespace in{
	// Integer reader for stdin.  Define `faster` (the commented-out macro above)
	// to read through a 2 MiB fread buffer instead of getchar().
	#ifdef faster
	char buf[1<<21],*p1=buf,*p2=buf;
	// Returns the next input byte as a non-negative value, or EOF.  The cast to
	// unsigned char keeps a 0xFF byte from being mistaken for EOF.
	inline int getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:static_cast<unsigned char>(*p1++);}
	#else
	inline int getc(){return getchar();}
	#endif
	// Reads one decimal integer into t.  Non-digit characters before the number
	// are skipped, and a '-' among them makes the result negative.  The result
	// is kept in an int (not a char) so EOF is detected.  If input ends before
	// any digit, t is left as 0; the old loop spun forever in that case.
	template <typename T>inline void read(T& t){
		t=0;bool neg=false;int ch=getc();
		while(ch!=EOF&&!isdigit(ch)){if(ch=='-')neg=true;ch=getc();}
		while(ch!=EOF&&isdigit(ch)){t=t*10+(ch-'0');ch=getc();}
		if(neg)t=-t;
	}
	// Reads several integers in order.
	template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
}
namespace out{
	// Buffered stdout writer.  Bytes accumulate in `buffer`, where p1 is the
	// index of the last byte written (-1 when empty).  Nothing reaches stdout
	// until flush() is called, so callers must flush before exiting.
	char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;
	// Writes the pending bytes to stdout and empties the buffer.
	inline void flush(){fwrite(buffer,1,p1+1,stdout),p1=-1;}
	// Appends one byte, flushing first if the buffer is full.
	inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}
	// Appends the decimal form of integer x.  Digits are collected least-significant
	// first in a local array large enough for any 128-bit value; the old static
	// 15-byte array overflowed for 64-bit inputs.  For negative x each digit is
	// negated separately, so the minimum value of T never overflows.
	template <typename T>void write(T x) {
		char digits[40];int len=0;
		if(x<0){
			putc('-');
			do{digits[len++]=static_cast<char>('0'-x%10);x/=10;}while(x);
		}else{
			do{digits[len++]=static_cast<char>('0'+x%10);x/=10;}while(x);
		}
		while(len>0)putc(digits[--len]);
	}
}
using namespace std;
template<const int mod>
struct modint{
    // Residue modulo `mod`, kept in [0, mod).  The additive operators rely on
    // mod < 2^30 so that x + o.x cannot overflow int.
    int x;
    // Reduces any int into [0, mod).  The previous constructor stored o
    // verbatim, so a negative or >= mod argument produced an invalid residue.
    // It is also spelled without `<mod>`, because a template-id on a constructor
    // declaration is ill-formed from C++20 on.
    modint(int o=0){x=o%mod;if(x<0)x+=mod;}
    modint<mod> &operator = (int o){return *this=modint<mod>(o);}
    modint<mod> &operator +=(modint<mod> o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
    modint<mod> &operator -=(modint<mod> o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
    modint<mod> &operator *=(modint<mod> o){return x=1ll*x*o.x%mod,*this;}
    // Raises to the power b by binary exponentiation.  b must be >= 0: a
    // negative b never reaches 0 under >>= and would loop forever.
    modint<mod> &operator ^=(int b){
        modint<mod> a=*this,c=1;
        for(;b;b>>=1,a*=a)if(b&1)c*=a;
        return x=c.x,*this;
    }
    // Division via Fermat's little theorem, so mod must be prime.  Dividing
    // by 0 yields 0 rather than failing.
    modint<mod> &operator /=(modint<mod> o){return *this *=o^=mod-2;}
    // Plain-int overloads.  The operand is reduced into [0, mod) first, so
    // negative or large ints give the correct residue.
    modint<mod> &operator +=(int o){return *this+=modint<mod>(o);}
    modint<mod> &operator -=(int o){return *this-=modint<mod>(o);}
    modint<mod> &operator *=(int o){return *this*=modint<mod>(o);}
    modint<mod> &operator /=(int o){return *this/=modint<mod>(o);}
	template<class I>friend modint<mod> operator +(modint<mod> a,I b){return a+=b;}
    template<class I>friend modint<mod> operator -(modint<mod> a,I b){return a-=b;}
    template<class I>friend modint<mod> operator *(modint<mod> a,I b){return a*=b;}
    template<class I>friend modint<mod> operator /(modint<mod> a,I b){return a/=b;}
    friend modint<mod> operator ^(modint<mod> a,int b){return a^=b;}
    friend bool operator ==(modint<mod> a,int b){return a.x==b;}
    friend bool operator !=(modint<mod> a,int b){return a.x!=b;}
    bool operator ! () const {return !x;}
    modint<mod> operator - () const {return x?mod-x:0;}
	// NOTE: unlike the usual postfix convention, this returns the incremented
	// object itself, not the previous value.
	modint<mod> &operator++(int){return *this+=1;}
};
// Capacity of the global arrays rev[], v[] and cnt[].
const int N=4000005;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int mod=998244353;
// Generator used by ntt() for the forward transform, and its inverse for the
// inverse transform.
const modint<mod> GG=3;
const modint<mod> Ginv=modint<mod>(1)/3;
// Not referenced in the visible code; presumably a square root of -1 mod p
// (unverified).
const modint<mod> I=86583718;
// Dense polynomial over Z_mod: a[i] holds the coefficient of x^i.
struct poly{
	vector<modint<mod>>a;
	// Unchecked coefficient access (forwards to vector::operator[]).
	modint<mod>&operator[](int idx){return a[idx];}
	// Number of stored coefficients.
	int size(){return static_cast<int>(a.size());}
	// Grows (zero-filling) or truncates the coefficient list.
	void resize(int len){a.resize(len);}
	// Reverses the coefficient order in place.
	void reverse(){std::reverse(a.begin(),a.end());}
};
// Bit-reversal permutation for the current transform length.  init(k) fills
// rev[0 .. 2^k) and ntt() reads it, so init must run with the matching k first.
int rev[N];
// Returns the smallest k >= 0 with 2^k >= n, i.e. ceil(log2(n)); 0 when n <= 1.
inline int ext(int n){
	int k=0;
	for(int len=1;len<n;len<<=1)++k;
	return k;
}
// Fills rev[0 .. 2^k) with the k-bit bit-reversal permutation used by ntt().
// k == 0 (a length-1 transform) is handled explicitly, because the general
// formula would shift by k-1 == -1, which is undefined behaviour.
inline void init(int k){
	if(k==0){rev[0]=0;return;}
	int n=1<<k;
	for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
// In-place iterative NTT of length 2^k over Z_mod.
//   typ > 0: forward transform using root GG.
//   typ < 0: inverse transform using root Ginv, then every coefficient is
//            multiplied by 1/2^k.
// Requires a.size() >= 2^k and rev[] filled by init(k).
inline void ntt(poly&a,int k,int typ){
	int n=1<<k;
	for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int mid=1;mid<n;mid<<=1){
		// Primitive (2*mid)-th root of unity for this butterfly level.
		modint<mod> wn=(typ>0?GG:Ginv)^((mod-1)/(mid<<1));
		for(int r=mid<<1,j=0;j<n;j+=r){
			modint<mod> w=1;
			// Loop index is `t` so that it no longer shadows the parameter k.
			for(int t=0;t<mid;t++,w=w*wn){
				modint<mod> x=a[j+t],y=w*a[j+t+mid];
				a[j+t]=x+y,a[j+t+mid]=x-y;
			}
		}
	}
	if(typ<0){
		modint<mod> scale=modint<mod>(1)/n;
		// Scale all n coefficients.  The old bound `n-1` left a[n-1] unscaled,
		// which corrupted the top coefficient whenever a caller kept it.
		for(int i=0;i<n;i++)a[i]*=scale;
	}
}
// Returns the constant polynomial 1.
inline poly one(){
	poly unit;
	unit.resize(1);
	unit[0]=1;
	return unit;
}
// Coefficient-wise sum; the result has max(a.size(), b.size()) terms.
poly operator +(poly a,poly b){
	if(a.size()<b.size())a.resize(b.size());
	for(int i=0,len=b.size();i<len;++i)a[i]+=b[i];
	return a;
}
// Coefficient-wise difference a - b; the result has max(a.size(), b.size()) terms.
poly operator -(poly a,poly b){
	if(a.size()<b.size())a.resize(b.size());
	for(int i=0,len=b.size();i<len;++i)a[i]-=b[i];
	return a;
}
// Product of two polynomials via NTT; the result has a.size()+b.size()-1 terms.
// An empty operand yields an empty product.  Previously the negative length
// was passed to resize(), which throws or aborts.
inline poly operator*(poly a,poly b){
	if(a.size()==0||b.size()==0)return poly();
	const int n=a.size()+b.size()-1;
	// Use at least a length-2 transform so init() is never asked for k == 0.
	const int k=std::max(ext(n),1);
	a.resize(1<<k),b.resize(1<<k),init(k);
	ntt(a,k,1);ntt(b,k,1);
	for(int i=0;i<(1<<k);i++)a[i]*=b[i];
	ntt(a,k,-1),a.resize(n);return a;
}
// Multiplies every coefficient by the scalar b.
inline poly operator*(poly a,modint<mod> b){for(int i=0;i<a.size();i++)a[i]*=b;return a; }
// Divides every coefficient by the scalar b.  The modular inverse of b is
// computed once instead of once per coefficient (each inverse costs an
// O(log mod) power).  As before, b == 0 maps every coefficient to 0.
inline poly operator/(poly a,modint<mod> b){
	const modint<mod> inv_b=modint<mod>(1)/b;
	for(int i=0;i<a.size();i++)a[i]*=inv_b;
	return a;
}
// Negates every coefficient.
inline poly operator-(poly a){for(int i=0;i<a.size();i++)a[i]=-a[i];return a; }
// Power-series inverse by Newton iteration: returns G with F*G == 1 (mod x^(2^k)).
// Requires F[0] != 0.  If F[0] == 0, the base case yields 0 and the result is
// not an inverse.  Each level lifts H = F^{-1} mod x^(n/2) to
// G = 2H - H^2 F mod x^n.
poly inv(poly F,int k){
	int n=1<<k;F.resize(n);
	if(n==1){F[0]=modint<mod>(1)/F[0];return F;}
	poly G,H=inv(F,k-1);
	G.resize(n),H.resize(n<<1),F.resize(n<<1);
	// G = 2H (only the first n/2 terms of H are meaningful).
	for(int i=0;i<n/2;i++)G[i]=H[i]*2;
	// Length-2n transforms, so H*H*F (degree < 2n) does not wrap around.
	init(k+1),ntt(H,k+1,1),ntt(F,k+1,1);
	for(int i=0;i<(n<<1);i++)H[i]=H[i]*H[i]*F[i];
	ntt(H,k+1,-1),H.resize(n);
	// G = 2H - H^2 F, truncated to n terms.
	for(int i=0;i<n;i++)G[i]-=H[i];return G;
}
// Power-series inverse of a, truncated to a.size() terms.
inline poly inv(poly a){
	const int len=a.size();
	poly res=inv(a,ext(len));
	res.resize(len);
	return res;
}
// Formal derivative: result[i] = (i+1) * a[i+1], giving a.size()-1 terms.
// An empty or constant input yields an empty result.  Previously an empty
// input reached resize(-1).
inline poly deriv(poly a){
	if(a.size()<=1)return poly();
	int n=a.size()-1;
	for(int i=0;i<n;i++)a[i]=a[i+1]*(i+1);
	a.resize(n);return a;
}
// Formal integral with zero constant term: result[0] = 0 and
// result[i] = a[i-1] / i, giving a.size()+1 terms.
inline poly inter(poly a){
	int n=a.size()+1;a.resize(n);
	// Start at n-1.  The old loop began at i = n and wrote a[n], one element
	// past the end of the resized vector.
	for(int i=n-1;i>=1;i--)a[i]=a[i-1]/i;
	a[0]=0;return a;
}
// Power-series logarithm ln(a) = integral(a' / a), truncated to a.size() terms.
// The result's constant term is always 0.  It is mathematically meaningful
// only when a[0] == 1, which is not checked here.
inline poly ln(poly a){
	const int len=a.size();
	poly quotient=deriv(a)*inv(a);
	poly res=inter(quotient);
	res.resize(len);
	return res;
}
// Power-series exponential on 2^k terms by Newton iteration:
// f = f0 * (1 + a - ln f0), where f0 = exp(a) mod x^(2^(k-1)).
// The base case returns 1 without looking at a[0], so a[0] == 0 is assumed.
// The returned polynomial is not truncated; callers resize it.
poly exp(poly a,int k){
	const int len=1<<k;
	a.resize(len);
	if(len==1)return one();
	poly half=exp(a,k-1);
	half.resize(len);
	poly correction=one()+a-ln(half);
	return half*correction;
}
// Power-series exponential of a, truncated to a.size() terms.
poly exp(poly a){
	const int len=a.size();
	poly res=exp(a,ext(len));
	res.resize(len);
	return res;
}
// n: number of vertices.  m: number with degree 1.  k: number with degree 2.
// v[i]: degree of vertex i (1-indexed).
// cnt[a]: main only reads entries 1 <= a < m, which hold the number of vertices
// with v_i != 2 and |v_i - 2| == a.
int n,m,k,v[N],cnt[N];
// Built in main as the series sum_a cnt[a] * sum_j (-1)^(j-1) x^(a*j) / j,
// then replaced by its exp.
poly F;
// Reads n and the n degrees, then prints the number of vertex subsets that can
// form a tree, modulo 998244353:
//   answer = 2^k * [x^(m-2)] prod_{v_i != 2} (1 + x^{|v_i - 2|})
// The product is computed as exp(sum_a cnt[a] * sum_{j>=1} (-1)^(j-1) x^(a*j) / j).
// NOTE(review): assumes every v_i >= 1 and n < N, as the method above needs.
// A degree of 0 would be treated like |0-2| == 2; confirm against the statement.
signed main(){
	in::read(n);
	for(int i=1;i<=n;i++){
		in::read(v[i]);
		if(v[i]==1)m++;
		else if(v[i]==2)k++;
	}
	// A tree with at least two vertices has at least two leaves, so with fewer
	// than two degree-1 vertices no subset works.  This also keeps F[m-2] in
	// bounds; the old code read F[-2] or F[-1] here.
	if(m<2){out::write(0);out::flush();return 0;}
	// Only exponents a < m can contribute to x^(m-2).  Skipping larger ones also
	// stops huge degrees from indexing past the end of cnt[].
	for(int i=1;i<=n;i++){
		if(v[i]==2)continue;
		const int a=abs(v[i]-2);
		if(a<m)cnt[a]++;
	}
	F.resize(m);
	// Harmonic-sum enumeration: O(m log m) terms in total.
	for(int a=1;a<m;a++){
		if(!cnt[a])continue;
		for(int j=1;a*j<m;j++)
			F[a*j]+=((j&1)?modint<mod>(cnt[a]):-modint<mod>(cnt[a]))/j;
	}
	F=exp(F);
	out::write((F[m-2]*(modint<mod>(2)^k)).x);
	out::flush();
	return 0;
}
原文地址:https://www.cnblogs.com/juruo-cjl/p/14319253.html