「2019 集训队互测 Day 1」最短路径 (点分治+NTT/FFT+线段树)

「2019 集训队互测 Day 1」最短路径 (点分治+NTT/FFT+线段树)

题意:给定了一棵基环树,求所有的(d(u,v)^k)的期望

(k)较小时,可以想到用斯特林数/二项式定理展开 维护+1操作,对于树的可以从儿子合并上来,对于环上可以枚举每个块求得答案

复杂度为(O(nk))

当图为一棵树时,由于不好处理(x^k),考虑直接求出(d(u,v)=i)的数量

比较容易想到用用点分治+( ext{NTT})求解,复杂度为(O(nlog ^2n))

环上的情况比较麻烦,不妨为每个块标号(1,2,cdots m),每个块包含(sz_i)个结点

显然((i,j))的距离为(minlbrace|i-j|,m-|i-j| brace)

考虑计算所有块((i,j)(i<j))之间的贡献,令(d=lfloor frac{m}{2} floor),则对于(jin[i+1,i+d])在环上的距离为(j-i),否则距离为(m-(j-i))

对于两种情况分类讨论,这里以计算(jin[i+1,i+d])为例

因为是一段区间,考虑直接在线段树的([i+1,i+d])加入(i),然后对于线段树上每个结点计算

推论1:能够被添加到线段树结点([l,r])上的(i)构成一段连续的区间

推论2:从区间([l,r])的一端出发,( ext{dfs})区间内的块得到的(max dis_uleq sum_{i=l}^r sz_i)

因此同样考虑用( ext{NTT})维护该答案,每次更新答案可以看做是区间([l1,r1],[l2,r2](r1<l2))之间的贡献

分别从(r1,l2)开始( ext{dfs})得到(dis_u),然后( ext{NTT})合并,不把([r1+1,l2-1])这一部分在环上的加入( ext{NTT})大小

这样就能保证卷积大小(leq sum_{i=l1}^{r1} sz_i+sum_{i=l2}^{r2} sz_i)

同理可以类似处理(j>i+d)的情况

分析复杂度:每个(i)会出现在线段树上(log n)个位置,每个(j)会在线段树上(log n)层被计算

因此每个点被加入卷积大小的次数为(O(log n)),复杂度为(O(nlog ^2 n))与前面的点分治同阶

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
#define Mod1(x) ((x>=P)&&(x-=P))
#define Mod2(x) ((x<0)&&(x+=P))
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
template <class T> inline void cmin(T &a,T b){ ((a>b)&&(a=b)); }
template <class T> inline void cmax(T &a,T b){ ((a<b)&&(a=b)); }

char IO;
int rd(){
	int s=0,f=0;
	while(!isdigit(IO=getchar())) if(IO=='-') f=1;
	do s=(s<<1)+(s<<3)+(IO^'0');
	while(isdigit(IO=getchar()));
	return f?-s:s;
}

bool Mbe;
const int N=1<<18|10,P=998244353;

int n,m,k;
int A[N];
ll qpow(ll x,ll k=P-2) {
	ll res=1;
	for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
	return res;
}
int Pow[N];
struct Edge{
	int to,nxt;
}e[N];
int head[N],ecnt,deg[N];
void AddEdge(int u,int v) {
	e[++ecnt]=(Edge){v,head[u]};
	head[u]=ecnt,deg[v]++;
}
#define erep(u,i) for(int i=head[u];i;i=e[i].nxt)

int w[N];
void Init() {
	int R=1<<18;
	int t=qpow(3,(P-1)/R);
	w[R/2]=1;
	rep(i,R/2+1,R-1) w[i]=1ll*w[i-1]*t%P;
	drep(i,R/2-1,1) w[i]=w[i<<1];
}

int rev[N];
void NTT(int n,int *a,int f) {
	static int e[N>>1];
	rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int i=e[0]=1,t;i<n;i<<=1) {
		int *e=w+i;
		for(int l=0;l<n;l+=i*2) {
			for(int j=l;j<l+i;++j) {
				t=1ll*a[j+i]*e[j-l]%P;
				a[j+i]=a[j]-t,Mod2(a[j+i]);
				a[j]+=t,Mod1(a[j]);
			}
		}
	}
	if(f==-1) {
		reverse(a+1,a+n);
		ll base=qpow(n);
		rep(i,0,n-1) a[i]=a[i]*base%P;
	}
}
int Init(int n) {
	int R=1,c=-1;
	while(R<=n) R<<=1,c++;
	rep(i,0,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<c);
	return R;
}

int Q[N],L,R,vis[N];

namespace pt1{ 
	const int N=1010;
	int dis[N];
	void Bfs(int u) {
		rep(i,1,n) dis[i]=-1;
		dis[Q[L=R=1]=u]=0;
		while(L<=R) {
			u=Q[L++];
			erep(u,i){
				int v=e[i].to;
				if(~dis[v]) continue;
				dis[v]=dis[u]+1,Q[++R]=v;
			}
		}
	}
	void Solve() {
		int ans=0;
		rep(i,2,n) {
			Bfs(i);
			rep(j,1,i-1) ans=(ans+Pow[dis[j]])%P;
		}
		ans=ans*qpow(n*(n-1)/2)%P;
		printf("%d
",ans);
	}
}

int Ans[N],sz[N];
namespace pt2{ 
	int mi=1e9,rt;
	void FindRt(int n,int u,int f) {
		int ma=0; sz[u]=1;
		erep(u,i) {
			int v=e[i].to;
			if(v==u || v==f || vis[v]) continue;
			FindRt(n,v,u),sz[u]+=sz[v],cmax(ma,sz[v]);
		}
		cmax(ma,n-sz[u]);
		if(mi>ma) mi=ma,rt=u;
	}

	int F[N],A[N],B[N];
	void Solve(int n,int k) {
        // 容斥型 点分治
		int R=Init(n*2+1);
		rep(i,0,R) F[i]=0;
		rep(i,0,n) F[i]=A[i];
		NTT(R,F,1);
		rep(i,0,R-1) F[i]=1ll*F[i]*F[i]%P;
		NTT(R,F,-1);
		if(k==1) rep(i,0,n*2) Ans[i]+=F[i],Mod1(Ans[i]);
		else rep(i,0,n*2) Ans[i]-=F[i],Mod2(Ans[i]);
	}
	int maxd;
	void dfs(int u,int f,int d=0) {
		A[d]++,sz[u]=1,cmax(maxd,d);
		erep(u,i) {
			int v=e[i].to;
			if(v==u || v==f || vis[v]) continue;
			dfs(v,u,d+1),sz[u]+=sz[v];
		}
	}
	void Divide(int n,int u) {
		mi=1e9,FindRt(n,u,0),u=rt;
		vis[u]=1;
		int D=0;B[0]=1;
		erep(u,i) {
			int v=e[i].to;
			if(vis[v]) continue;
			maxd=0,dfs(v,u,1);
			Solve(maxd,-1);
			rep(j,0,maxd) B[j]+=A[j],A[j]=0;
			cmax(D,maxd);
		}
		rep(i,0,D) A[i]=B[i],B[i]=0;
		Solve(D,1);
		rep(i,0,D) A[i]=0;
		erep(u,i) {
			int v=e[i].to;
			if(vis[v]) continue;
			Divide(sz[v],v);
		}
	}
	void Solve() {
		rep(i,1,n) vis[i]=0;
		Divide(n,1);
		int ans=0;
		rep(i,1,n) ans=(ans+1ll*Ans[i]*Pow[i])%P;
		ans=ans*qpow(1ll*n*(n-1)%P)%P;
		printf("%d
",ans);
	}
}

int QL[N<<2],QR[N<<2];
void Add(int p,int l,int r,int ql,int qr,int x) {
    // 在线段树上加入结点
	if(ql<=l && r<=qr) {
		if(!QL[p]) QL[p]=x;
		QR[p]=x;
		return;
	}
	int mid=(l+r)>>1;
	if(ql<=mid) Add(p<<1,l,mid,ql,qr,x);
	if(qr>mid) Add(p<<1|1,mid+1,r,ql,qr,x);
}

int typ;
int X[N],Y[N],D;

void dfs(int *C,int u,int f,int d) {
	cmax(D,d),C[d]++;
	for(int i=head[u];i;i=e[i].nxt) {
		int v=e[i].to;
		if(v==f || vis[v])  continue;
		dfs(C,v,u,d+1);
	}
}

void Mark(int i,int k) {
	int l=A[i==1?m:i-1],r=A[i==m?1:i+1];
	vis[l]=vis[r]=k;
}

void Get(int p,int l,int r) { 
	if(QL[p]) {
        // 计算区间QL,QR到l,r的贡献
		if(typ==0) {
			int qr=QR[p];
			rep(x,QL[p],QR[p]) Mark(x,1),dfs(X,A[x],0,qr-x),Mark(x,0);
			int T=D; D=0;
			rep(x,l,r) Mark(x,1),dfs(Y,A[x],0,x-l),Mark(x,0);
			int R=Init(T+D+1);
			NTT(R,X,1),NTT(R,Y,1);
			rep(i,0,R-1) X[i]=1ll*X[i]*Y[i]%P;
			NTT(R,X,-1);
			rep(i,0,T+D) Ans[i+l-qr]+=X[i],Mod1(Ans[i+l-qr]);
			rep(i,0,R) X[i]=Y[i]=0;
		} else {
			int ql=QL[p];
			rep(x,QL[p],QR[p]) Mark(x,1),dfs(X,A[x],0,x-ql),Mark(x,0); 
			int T=D; D=0;
			rep(x,l,r) Mark(x,1),dfs(Y,A[x],0,r-x),Mark(x,0);
			int R=Init(T+D+1);
			NTT(R,X,1),NTT(R,Y,1);
			rep(i,0,R-1) X[i]=1ll*X[i]*Y[i]%P;
			NTT(R,X,-1);
			int d=ql+m-r;
			rep(i,0,T+D) Ans[i+d]+=X[i],Mod1(Ans[i+d]);
			rep(i,0,R) X[i]=Y[i]=0;
		}
		QL[p]=QR[p]=0;
	}
	if(l==r) return;
	int mid=(l+r)>>1;
	Get(p<<1,l,mid),Get(p<<1|1,mid+1,r);
}

int main() {
	freopen("path.in","r",stdin),freopen("path.out","w",stdout);
	n=rd(),k=rd();
	rep(i,1,n) Pow[i]=qpow(i,k);
	rep(i,1,n) {
		int u=rd(),v=rd();
		AddEdge(u,v),AddEdge(v,u);
	}
	if(n<=1000) return pt1::Solve(),0;
	Init(),L=1;
    // 拓扑求环
	rep(i,1,n) if(deg[i]==1) sz[Q[++R]=i]=1;
	while(L<=R) {
		int u=Q[L++]; vis[u]=1;
		for(int i=head[u];i;i=e[i].nxt) {
			int v=e[i].to;
			if(deg[v]<=1) sz[u]+=sz[v];
			if(--deg[v]==1) Q[++R]=v;
		}
	}
	for(int u=1;u<=n;++u) if(!vis[u]) {
		while(1) {
			vis[u]=1,A[++m]=u;
			int nxt=-1;
			for(int i=head[u];i;i=e[i].nxt) {
				int v=e[i].to;
				if(!vis[v]) nxt=v;
			}
			if(nxt==-1) break;
			u=nxt;
		}
		break;
	}
	if(m==1) return pt2::Solve(),0;
	fprintf(stderr,"Circle Length =%d
",m);
	rep(i,1,n) vis[i]=0;

	k=m/2;
	rep(i,1,m) {
		Mark(i,1);
		pt2::Divide(sz[A[i]],A[i]);
		Mark(i,0);
	}
	rep(i,1,n) Ans[i]=1ll*Ans[i]*(P+1)/2%P;
	rep(i,1,n) vis[i]=0;
	rep(i,1,m-1) Add(1,1,m,i+1,min(i+k,m),i);
	typ=0,Get(1,1,m);
	rep(i,1,m-k-1) Add(1,1,m,i+k+1,m,i);
	typ=1,Get(1,1,m);
	int ans=0;
	rep(i,1,n) ans=(ans+1ll*Ans[i]*Pow[i])%P;
	ans=ans*qpow(1ll*n*(n-1)/2%P)%P;
	printf("%d
",ans);
}
原文地址:https://www.cnblogs.com/chasedeath/p/13954352.html