[2018.6.22集训]admirable-拆系数FFT-多项式相关

题目大意

给出一棵树,现在需要将$k$条不相同的路径覆盖到这棵树上。
定义一种合法的路径覆盖方案为,能使得树上的每条边的被覆盖次数$t in {0,1,k}$的方案。
求合法方案的数量,对$10^9 +9$取模。

$n,k leq 10^5$。

题解

很容易想到一个计算答案的方法:
被覆盖$k$次的部分一定是一条链,枚举这条链,并对每条枚举的链,计算可行的方案数并累加。

对于每条覆盖$k$次的链,可以发现,如果分别计算出两个端点处的方案,将其贡献乘起来,即可得到这条链的方案数。
可以发现,对于一个端点处的合法方案,每一条路径在这一侧的端点均在这条链的端点的不同子树中或恰好在这个端点上,因为若有两条路径在同一子树中,会经过同一条边多次,方案会不合法。

考虑如何计算这个方案数,可以使用多项式乘法解决,构造形如$size[v]*x+1$的多项式,并将每个子树的多项式乘起来,再取出前$k$项的系数,采用全排列和组合数计算贡献。

因此,每个点所$u$需要的多项式,为$prodlimits_{(u,v) in E}size[v]*x+1$,除掉与当前枚举的链相交的一侧的儿子的多项式得到。
乘或除掉一个二项式是$O(n)$的,同时如果预处理,乘除操作的次数将为$O(n)$,因此预处理一个节点处的多项式再枚举,复杂度$O(n^2)$。

考虑优化。
对于计算初始每个节点处所有儿子乘起来得到的多项式,考虑分治FFT,$O(nlog^2n)$。
考虑除掉二项式的复杂度,可以发现对于两个$size$相同的多项式,得到的结果相同,而一个节点的子树的不同的$size$的数量为$O(sqrt{n})$,于是改为对子树中每种出现了的$size$预处理,复杂度降至$O(nsqrt{n})$。
最后,对于一个节点,同一棵子树内的节点与它组成的链,得到的这个点的贡献相同,因此可以一起算,复杂度降至$O(n)$,具体可见代码。

于是最后复杂度为$O(nlog^2n+nsqrt{n})$。
由于模数很恶心,需要使用拆系数FFT,这里使用了三模数+中国剩余定理合并的版本。

代码:

#include<map>
#include<cstdio>
#include<algorithm>
using namespace std;

typedef long long ll;
const int N=4e5+9;
const int md=1e9+9;
const int K=33;

inline int read()
{
	int x=0;char ch=getchar();
	while(ch<'0' || '9'<ch)ch=getchar();
	while('0'<=ch && ch<='9')x=x*10+(ch^48),ch=getchar();
	return x;
}

inline void chk(ll &x){if(x>=md)x-=md;}

inline ll qpow(ll a,ll b)
{
	ll ret=1;
	while(b)
	{
		if(b&1)ret=ret*a%md;
		a=a*a%md;b>>=1;
	}
	return ret;
}

int n,k;
int to[N<<1],nxt[N<<1],beg[N],tot;
int fa[N],deg[N],siz[N],len[N];
ll f[N],t[N],sumf[N];
map<int,ll> g[N];
ll fac[N],inv[N],ans;

inline ll c(ll a,ll b)
{
	if(a<b)return 0;
	return fac[a]*inv[b]%md*inv[a-b]%md;
}

inline void init()
{
	fac[0]=1;
	for(ll i=1;i<N;i++)
		fac[i]=fac[i-1]*i%md;
	inv[N-1]=qpow(fac[N-1],md-2);
	for(ll i=N-1;i>=1;i--)
		inv[i-1]=inv[i]*i%md;
	for(ll i=0;i<=k;i++)
		t[i]=fac[i]*c(k,i)%md;
}

inline void add(int u,int v)
{
	to[++tot]=v;
	nxt[tot]=beg[u];
	beg[u]=tot;
	deg[v]++;
}

inline ll mul(ll *f,ll x,int &l)
{
	f[++l]=0;
	for(int i=l-1;i>=0;i--)
		chk(f[i+1]+=f[i]*x%md);
}

inline ll imul(ll *f,ll x,int &l)
{
	ll invv=qpow(x,md-2);
	for(int i=l-1;i>=0;i--)
	{
		f[i+1]=f[i+1]*invv%md;
		chk(f[i]=f[i]+md-f[i+1]);
	}
	for(int i=0;i<l;i++)
		f[i]=f[i+1];
	l--;
}

inline ll calc(ll *f,int l)
{
	ll ret=0;
	for(int i=0;i<=l;i++)
		chk(ret+=f[i]*t[i]%md);
	return ret;
}

namespace ntt
{
	const ll md1=998244353;
	const ll md2=1004535809;
	const ll md3=469762049;
	const ll M=md1*md2;

	int rev[N];
	ll b[K][N],len[K],seg[N],top;

	const ll muls(ll a,ll b,ll p)
	{
		return ((a*b-(ll)((long double)a*b/p)*p)%p+p)%p;
	}

	inline void initrev(int n)
	{
		for(int i=0;i<n;i++)
			rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
	}

	inline ll qpow(ll a,ll b,ll p)
	{
		ll ret=1;a%=p;
		while(b)
		{
			if(b&1)ret=ret*a%p;
			a=a*a%p;b>>=1;
		}
		return ret;
	}

	inline void ntt(ll *a,int n,int f,ll md)
	{
		for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
		for(int h=2;h<=n;h<<=1)
		{
			ll w=qpow(3,(md-1)/h,md);
			if(f)w=qpow(w,md-2,md);
			for(int j=0;j<n;j+=h)
			{
				ll wn=1ll,x,y;
				for(int k=j;k<j+(h>>1);k++)
				{
					x=a[k],y=a[k+(h>>1)]*wn%md;wn=wn*w%md;
					a[k]=(x+y)%md;a[k+(h>>1)]=(x-y+md)%md;
				}
			}
		}
		if(f)
			for(ll i=0,invn=qpow(n,md-2,md);i<n;i++)
				a[i]=a[i]*invn%md;
	}

	inline void cp(ll *a,ll *b,ll n,ll m)
	{
		for(int i=0;i<n;i++)
			b[i]=a[i];
		for(int i=n;i<m;i++)
			b[i]=0;
	}

	inline void bfmul(ll *a,int n,ll *b,int m,ll *c)
	{
		static ll d[N];
		for(int i=0;i<n+m-1;i++)
			d[i]=0;
		for(int i=0;i<n;i++)
			for(int j=0;j<m;j++)
				chk(d[i+j]+=a[i]*b[j]%md);
		for(int i=0;i<n+m-1;i++)
			c[i]=d[i];
	}

	inline void mul(ll *a,int n,ll *b,int m,ll *c)
	{
		static ll d[N],e[N],f[N],g[N],l;

		if(n+m-1<=1000)
		{
			bfmul(a,n,b,m,c);
			return;
		}
		for(l=1;l<=n+m;l<<=1);initrev(l);

		cp(a,d,n,l);cp(b,e,m,l);
		ntt(d,l,0,md1);ntt(e,l,0,md1);
		for(int i=0;i<l;i++)
			f[i]=d[i]*e[i]%md1;
		ntt(f,l,1,md1);

		cp(a,d,n,l);cp(b,e,m,l);
		ntt(d,l,0,md2);ntt(e,l,0,md2);
		for(int i=0;i<l;i++)
			g[i]=d[i]*e[i]%md2;
		ntt(g,l,1,md2);

		ll inv2=qpow(md2,md1-2,md1);
		ll inv1=qpow(md1,md2-2,md2);
		for(int i=0;i<n+m-1;i++)
		{
			f[i]=muls(f[i]*md2%M,inv2,M);
			g[i]=muls(g[i]*md1%M,inv1,M);
			f[i]=(f[i]+g[i])%M;
		}

		cp(a,d,n,l);cp(b,e,m,l);
		ntt(d,l,0,md3);ntt(e,l,0,md3);
		for(int i=0;i<l;i++)
			g[i]=d[i]*e[i]%md3;
		ntt(g,l,1,md3);

		for(int i=0;i<n+m-1;i++)
		{
			ll k=((g[i]-f[i])%md3+md3)%md3*qpow(M,md3-2,md3)%md3;
			f[i]=(k%md*(M%md)%md+f[i]%md)%md;
		}

		for(int i=0;i<n+m-1;i++)
			c[i]=f[i];
	}

	inline void work(int l,int r)
	{
		if(l==r)
		{
			len[++top]=2;
			b[top][0]=1;
			b[top][1]=seg[l];
			return;
		}

		int mid=l+r>>1;
		work(l,mid);
		work(mid+1,r);
		mul(b[top-1],len[top-1],b[top],len[top],b[top-1]);
		len[top-1]=len[top-1]+len[top]-1;
		top--;
	}
}

inline void dfs_pre(int u)
{
	siz[u]=1;sumf[u]=0;int cnt=0;
	for(int i=beg[u];i;i=nxt[i])
		if(to[i]!=fa[u])
		{
			fa[to[i]]=u;
			dfs_pre(to[i]);
			siz[u]+=siz[to[i]];
			chk(sumf[u]+=sumf[to[i]]);
		}

	len[u]=0;f[0]=1;
	for(int i=beg[u];i;i=nxt[i])
		if(to[i]!=fa[u])
			ntt::seg[++cnt]=siz[to[i]];
	if(u)
		ntt::seg[++cnt]=n-siz[u];
	ntt::top=0;
	ntt::work(1,cnt);
	len[u]=ntt::len[1]-1;
	for(int i=0;i<=len[u];i++)
		f[i]=ntt::b[1][i];

	for(int i=beg[u];i;i=nxt[i])
		if(to[i]!=fa[u] && !g[u].count(siz[to[i]]))
		{
			imul(f,siz[to[i]],len[u]);
			g[u][siz[to[i]]]=calc(f,len[u]);
			mul(f,siz[to[i]],len[u]);
		}
	if(!g[u].count(n-siz[u]))
	{
		imul(f,n-siz[u],len[u]);
		g[u][n-siz[u]]=calc(f,len[u]);
		mul(f,n-siz[u],len[u]);
	}
	chk(sumf[u]+=g[u][n-siz[u]]);
}

namespace chain
{
	int mina()
	{
		ans=c(n,2);
		for(int i=1;i<=n;i++)
		{
			chk(ans+=c(k,1)*(i-1)%md*(n-i)%md);
			chk(ans+=c(k,1)*c(n-i,2)%md*(1+c(k,1)*(i-1)%md)%md);
		}
		printf("%lld
",ans);
		return 0;
	}
}

inline void dfs_calc(int u,ll sum)
{
	if(fa[u])
		chk(ans+=g[u][n-siz[u]]*sum%md);
	for(int i=beg[u];i;i=nxt[i])
		if(to[i]!=fa[u])
		{
			chk(ans+=(g[u][siz[to[i]]]*sumf[to[i]])%md);
			sum+=sumf[to[i]];
		}
	for(int i=beg[u];i;i=nxt[i])
		if(to[i]!=fa[u])
			dfs_calc(to[i],(sum-sumf[to[i]]+g[u][siz[to[i]]]+md)%md);
}

int main()
{
	n=read();k=read();
	for(int i=2,u,v;i<=n;i++)
	{
		u=read();v=read();
		add(u,v);add(v,u);
	}

	init();
	for(int i=1;i<=n;i++)
		if(deg[i]>2)
			goto hell;
	return chain::mina();
	hell:;

	 dfs_pre(1);
	 dfs_calc(1,ans=0);

	 printf("%lld
",ans*qpow(2,md-2)%md);

	 return 0;
}
原文地址:https://www.cnblogs.com/zltttt/p/9215964.html