CF Gym102538A Airplane Cliques

cf

一个自然的想法是在一个点集里选出一个特定的点,在该点处计入点集贡献.由于点集中所有两点间路径的并是个连通块

一个想法就是枚举连通块中深度最浅的点,然后认为在它子树内的距离(le x)的点都可以在点集内.不过这是错的,因为你很轻松就可以找到这个点两棵不同子树内到他距离为(x)的点,而这两个点距离为(2x)

所以现在就是要选择一个点集中的特定点,满足所有到它距离(le x)的点满足两两距离(le x).反过来考虑(),我们枚举点集中深度最深的点,深度相同就按照编号排序(其实就是找bfs序最大的点),这时候,所有满足bfs序更小的,到这个点距离(le x)的点都是满足两两距离限制的.假设当前枚举的bfs序最大的点为(a),现在考虑(p,q)两点,路径([a,p])和路径([a,q])的分叉点为(b).可以发现(p,q)之中最多有一个在分叉点上方

qwq

  • 如果两个点都在(b)下方,由于(a)为当前深度最深的点,那么一定有(max(dis(b,p),dis(b,q))le dis(a,b)),所以(dis(b,p)+dis(b,q)=max(dis(b,p),dis(b,q))+min(dis(b,p),dis(b,q))le dis(a,b)+min(dis(b,p),dis(b,q))=min(dis(a,p),dis(a,q))le x)

  • 如果有一个点都在(b)上方(假设为(q)),因为(dis(b,q)le dis(a,b)),所以(dis(b,p)+dis(b,q)le dis(a,b)+dis(b,q)=dis(a,q)le x)

所以对于每个点(a),如果统计出bfs序比它小的,到它距离(le x)的点个数(cn_a),那对于(ans_i)(inom{cn_a-1}{i-1})的贡献,这个可以把组合数拆开后ntt计算卷积的值

至于(cn_a)的计算可以一个log或两个log,如果是一个log,那么可以先算出(f_i)表示以某个点(或一条边上的中点)(i)为中点,半径为(lfloorfrac{x}{2} floor)的连通块内点数,然后按照bfs序的逆序枚举点(u),到(u)距离(le x)且深度不大于(u)的连通块点数就是(u)往上跳(lfloorfrac{x}{2} floor)距离到的点(v)(f_v)的值,再考虑bfs序要(le u)的bfs序的话,就每找到一个(v)就给(f_v)减掉1即可,这样在后面就不会统计到bfs序更大的点了

#include<bits/stdc++.h>
#define LL long long

using namespace std;
const int N=6e5+10,M=(1<<20)+10,mod=998244353;
int rd()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
    return x*w;
}
void ad(int &x,int y){x+=y,x-=x>=mod?mod:0;}
int fpow(int a,int b){int an=1;while(b){if(b&1) an=1ll*an*a%mod;a=1ll*a*a%mod,b>>=1;}return an;}
int ginv(int a){return fpow(a,mod-2);}
int to[N<<1],nt[N<<1],hd[N],tot=1;
void adde(int x,int y)
{
	++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
	++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int n,m,lm,sz[N],f[N],g[N],mx,nsz,rt;
bool ban[N];
void fdrt(int x,int ffa)
{
	sz[x]=1;
	int nx=0;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		fdrt(y,x),sz[x]+=sz[y],nx=max(nx,sz[y]);
	}
	nx=max(nx,nsz-sz[x]);
	if(mx>nx) mx=nx,rt=x;
}
void d1(int x,int ffa,int de)
{
	m=max(m,de),g[de]+=x<=n;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		d1(y,x,de+1);
	}
}
void d2(int x,int ffa,int de)
{
	if(de>lm) return;
	f[x]+=g[min(m,lm-de)];
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		d2(y,x,de+1);
	}
}
void wk(int x)
{
	mx=nsz+1,fdrt(x,0);
	x=rt,ban[x]=1,fdrt(x,0);
	d1(x,0,0);
	for(int i=1;i<=m;++i) g[i]+=g[i-1];
	f[x]+=g[min(lm,m)];
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		d2(y,x,1);
	}
	memset(g,0,sizeof(int)*(m+1));
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		m=0,d1(y,x,1);
		for(int i=1;i<=m;++i) g[i]+=g[i-1];
		for(int i=0;i<=m;++i) g[i]=-g[i];
		d2(y,x,1),memset(g,0,sizeof(int)*(m+1));
	}
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		nsz=sz[y],wk(y);
	}
}
int st[N],tp,dp[N],ff[N],sq[N];
void d3(int x,int ffa)
{
	st[++tp]=x,ff[x]=st[max(1,tp-lm)];
	dp[x]=tp;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(y==ffa) continue;
		d3(y,x);
	}
	--tp;
}
int fac[N],iac[N],W[21],iW[21],rdr[M],aa[M],bb[M];
void ntt(int *a,int n,bool op)
{
	int l=0,y;
	while((1<<l)<n) ++l;
	for(int i=0;i<n;++i)
	{
		rdr[i]=(rdr[i>>1]>>1)|((i&1)<<(l-1));
		if(i<rdr[i]) swap(a[i],a[rdr[i]]);
	}
	for(int i=1,p=0;i<n;i<<=1,++p)
	{
		int ww=op?W[p]:iW[p];
		for(int j=0;j<n;j+=i<<1)
			for(int k=0,w=1;k<i;++k,w=1ll*w*ww%mod)
			{
				y=1ll*a[j+k+i]*w%mod;
				a[j+k+i]=(a[j+k]-y+mod)%mod;
				a[j+k]=(a[j+k]+y)%mod;
			}
	}
	if(!op) for(int i=0,w=ginv(n);i<n;++i) a[i]=1ll*a[i]*w%mod;
}

int main()
{
	freopen("1.in","r",stdin); 
	freopen("1.out","w",stdout);
	for(int i=1,p=0;p<=20;i<<=1,++p)
		W[p]=fpow(3,(mod-1)/(i<<1)),iW[p]=ginv(W[p]);
	fac[0]=1;
	for(int i=1;i<=N-5;++i) fac[i]=1ll*fac[i-1]*i%mod;
	iac[N-5]=ginv(fac[N-5]);
	for(int i=N-5;i;--i) iac[i-1]=1ll*iac[i]*i%mod;
	n=rd(),lm=rd();
	for(int i=1;i<n;++i)
	{
		int x=rd(),y=rd();
		adde(x,i+n),adde(y,i+n);
	}
	nsz=n+n-1,wk(1);
	d3(1,0);
	for(int i=1;i<=n;++i) sq[i]=i;
	sort(sq+1,sq+n+1,[&](int aa,int bb){return dp[aa]>dp[bb];});
	for(int i=1;i<=n;++i)
	{
		int x=sq[i];
		--f[ff[x]],++aa[f[ff[x]]];
	}
	for(int i=0;i<=n;++i) aa[i]=1ll*aa[i]*fac[i]%mod;
	for(int i=0;i<=n;++i) bb[i]=iac[n-i];
	int len=1;
	while(len<=n+n+2) len<<=1;
	ntt(aa,len,1),ntt(bb,len,1);
	for(int i=0;i<len;++i) aa[i]=1ll*aa[i]*bb[i]%mod;
	ntt(aa,len,0);
	for(int i=0;i<n;++i) printf("%d ",(int)(1ll*aa[n+i]*iac[i]%mod));
	return 0;
}
原文地址:https://www.cnblogs.com/smyjr/p/12681278.html