「2019 集训队互测 Day 1」最短路径 (点分治+NTT/FFT+线段树)
题意:给定了一棵基环树,求所有的(d(u,v)^k)的期望
当(k)较小时,可以想到用斯特林数/二项式定理展开 维护+1操作,对于树的可以从儿子合并上来,对于环上可以枚举每个块求得答案
复杂度为(O(nk))
当图为一棵树时,由于不好处理(x^k),考虑直接求出(d(u,v)=i)的数量
比较容易想到用用点分治+( ext{NTT})求解,复杂度为(O(nlog ^2n))
环上的情况比较麻烦,不妨为每个块标号(1,2,cdots m),每个块包含(sz_i)个结点
显然((i,j))的距离为(minlbrace|i-j|,m-|i-j| brace)
考虑计算所有块((i,j)(i<j))之间的贡献,令(d=lfloor frac{m}{2} floor),则对于(jin[i+1,i+d])在环上的距离为(j-i),否则距离为(m-(j-i))
对于两种情况分类讨论,这里以计算(jin[i+1,i+d])为例
因为是一段区间,考虑直接在线段树的([i+1,i+d])加入(i),然后对于线段树上每个结点计算
推论1:能够被添加到线段树结点([l,r])上的(i)构成一段连续的区间
推论2:从区间([l,r])的一端出发,( ext{dfs})区间内的块得到的(max dis_uleq sum_{i=l}^r sz_i)
因此同样考虑用( ext{NTT})维护该答案,每次更新答案可以看做是区间([l1,r1],[l2,r2](r1<l2))之间的贡献
分别从(r1,l2)开始( ext{dfs})得到(dis_u),然后( ext{NTT})合并,不把([r1+1,l2-1])这一部分在环上的加入( ext{NTT})大小
这样就能保证卷积大小(leq sum_{i=l1}^{r1} sz_i+sum_{i=l2}^{r2} sz_i)
同理可以类似处理(j>i+d)的情况
分析复杂度:每个(i)会出现在线段树上(log n)个位置,每个(j)会在线段树上(log n)层被计算
因此每个点被加入卷积大小的次数为(O(log n)),复杂度为(O(nlog ^2 n))与前面的点分治同阶
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define Mod1(x) ((x>=P)&&(x-=P))
#define Mod2(x) ((x<0)&&(x+=P))
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
template <class T> inline void cmin(T &a,T b){ ((a>b)&&(a=b)); }
template <class T> inline void cmax(T &a,T b){ ((a<b)&&(a=b)); }
char IO;
int rd(){
int s=0,f=0;
while(!isdigit(IO=getchar())) if(IO=='-') f=1;
do s=(s<<1)+(s<<3)+(IO^'0');
while(isdigit(IO=getchar()));
return f?-s:s;
}
bool Mbe;
const int N=1<<18|10,P=998244353;
int n,m,k;
int A[N];
ll qpow(ll x,ll k=P-2) {
ll res=1;
for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
return res;
}
int Pow[N];
struct Edge{
int to,nxt;
}e[N];
int head[N],ecnt,deg[N];
void AddEdge(int u,int v) {
e[++ecnt]=(Edge){v,head[u]};
head[u]=ecnt,deg[v]++;
}
#define erep(u,i) for(int i=head[u];i;i=e[i].nxt)
int w[N];
void Init() {
int R=1<<18;
int t=qpow(3,(P-1)/R);
w[R/2]=1;
rep(i,R/2+1,R-1) w[i]=1ll*w[i-1]*t%P;
drep(i,R/2-1,1) w[i]=w[i<<1];
}
int rev[N];
void NTT(int n,int *a,int f) {
static int e[N>>1];
rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=e[0]=1,t;i<n;i<<=1) {
int *e=w+i;
for(int l=0;l<n;l+=i*2) {
for(int j=l;j<l+i;++j) {
t=1ll*a[j+i]*e[j-l]%P;
a[j+i]=a[j]-t,Mod2(a[j+i]);
a[j]+=t,Mod1(a[j]);
}
}
}
if(f==-1) {
reverse(a+1,a+n);
ll base=qpow(n);
rep(i,0,n-1) a[i]=a[i]*base%P;
}
}
int Init(int n) {
int R=1,c=-1;
while(R<=n) R<<=1,c++;
rep(i,0,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<c);
return R;
}
int Q[N],L,R,vis[N];
namespace pt1{
const int N=1010;
int dis[N];
void Bfs(int u) {
rep(i,1,n) dis[i]=-1;
dis[Q[L=R=1]=u]=0;
while(L<=R) {
u=Q[L++];
erep(u,i){
int v=e[i].to;
if(~dis[v]) continue;
dis[v]=dis[u]+1,Q[++R]=v;
}
}
}
void Solve() {
int ans=0;
rep(i,2,n) {
Bfs(i);
rep(j,1,i-1) ans=(ans+Pow[dis[j]])%P;
}
ans=ans*qpow(n*(n-1)/2)%P;
printf("%d
",ans);
}
}
int Ans[N],sz[N];
namespace pt2{
int mi=1e9,rt;
void FindRt(int n,int u,int f) {
int ma=0; sz[u]=1;
erep(u,i) {
int v=e[i].to;
if(v==u || v==f || vis[v]) continue;
FindRt(n,v,u),sz[u]+=sz[v],cmax(ma,sz[v]);
}
cmax(ma,n-sz[u]);
if(mi>ma) mi=ma,rt=u;
}
int F[N],A[N],B[N];
void Solve(int n,int k) {
// 容斥型 点分治
int R=Init(n*2+1);
rep(i,0,R) F[i]=0;
rep(i,0,n) F[i]=A[i];
NTT(R,F,1);
rep(i,0,R-1) F[i]=1ll*F[i]*F[i]%P;
NTT(R,F,-1);
if(k==1) rep(i,0,n*2) Ans[i]+=F[i],Mod1(Ans[i]);
else rep(i,0,n*2) Ans[i]-=F[i],Mod2(Ans[i]);
}
int maxd;
void dfs(int u,int f,int d=0) {
A[d]++,sz[u]=1,cmax(maxd,d);
erep(u,i) {
int v=e[i].to;
if(v==u || v==f || vis[v]) continue;
dfs(v,u,d+1),sz[u]+=sz[v];
}
}
void Divide(int n,int u) {
mi=1e9,FindRt(n,u,0),u=rt;
vis[u]=1;
int D=0;B[0]=1;
erep(u,i) {
int v=e[i].to;
if(vis[v]) continue;
maxd=0,dfs(v,u,1);
Solve(maxd,-1);
rep(j,0,maxd) B[j]+=A[j],A[j]=0;
cmax(D,maxd);
}
rep(i,0,D) A[i]=B[i],B[i]=0;
Solve(D,1);
rep(i,0,D) A[i]=0;
erep(u,i) {
int v=e[i].to;
if(vis[v]) continue;
Divide(sz[v],v);
}
}
void Solve() {
rep(i,1,n) vis[i]=0;
Divide(n,1);
int ans=0;
rep(i,1,n) ans=(ans+1ll*Ans[i]*Pow[i])%P;
ans=ans*qpow(1ll*n*(n-1)%P)%P;
printf("%d
",ans);
}
}
int QL[N<<2],QR[N<<2];
void Add(int p,int l,int r,int ql,int qr,int x) {
// 在线段树上加入结点
if(ql<=l && r<=qr) {
if(!QL[p]) QL[p]=x;
QR[p]=x;
return;
}
int mid=(l+r)>>1;
if(ql<=mid) Add(p<<1,l,mid,ql,qr,x);
if(qr>mid) Add(p<<1|1,mid+1,r,ql,qr,x);
}
int typ;
int X[N],Y[N],D;
void dfs(int *C,int u,int f,int d) {
cmax(D,d),C[d]++;
for(int i=head[u];i;i=e[i].nxt) {
int v=e[i].to;
if(v==f || vis[v]) continue;
dfs(C,v,u,d+1);
}
}
void Mark(int i,int k) {
int l=A[i==1?m:i-1],r=A[i==m?1:i+1];
vis[l]=vis[r]=k;
}
void Get(int p,int l,int r) {
if(QL[p]) {
// 计算区间QL,QR到l,r的贡献
if(typ==0) {
int qr=QR[p];
rep(x,QL[p],QR[p]) Mark(x,1),dfs(X,A[x],0,qr-x),Mark(x,0);
int T=D; D=0;
rep(x,l,r) Mark(x,1),dfs(Y,A[x],0,x-l),Mark(x,0);
int R=Init(T+D+1);
NTT(R,X,1),NTT(R,Y,1);
rep(i,0,R-1) X[i]=1ll*X[i]*Y[i]%P;
NTT(R,X,-1);
rep(i,0,T+D) Ans[i+l-qr]+=X[i],Mod1(Ans[i+l-qr]);
rep(i,0,R) X[i]=Y[i]=0;
} else {
int ql=QL[p];
rep(x,QL[p],QR[p]) Mark(x,1),dfs(X,A[x],0,x-ql),Mark(x,0);
int T=D; D=0;
rep(x,l,r) Mark(x,1),dfs(Y,A[x],0,r-x),Mark(x,0);
int R=Init(T+D+1);
NTT(R,X,1),NTT(R,Y,1);
rep(i,0,R-1) X[i]=1ll*X[i]*Y[i]%P;
NTT(R,X,-1);
int d=ql+m-r;
rep(i,0,T+D) Ans[i+d]+=X[i],Mod1(Ans[i+d]);
rep(i,0,R) X[i]=Y[i]=0;
}
QL[p]=QR[p]=0;
}
if(l==r) return;
int mid=(l+r)>>1;
Get(p<<1,l,mid),Get(p<<1|1,mid+1,r);
}
int main() {
freopen("path.in","r",stdin),freopen("path.out","w",stdout);
n=rd(),k=rd();
rep(i,1,n) Pow[i]=qpow(i,k);
rep(i,1,n) {
int u=rd(),v=rd();
AddEdge(u,v),AddEdge(v,u);
}
if(n<=1000) return pt1::Solve(),0;
Init(),L=1;
// 拓扑求环
rep(i,1,n) if(deg[i]==1) sz[Q[++R]=i]=1;
while(L<=R) {
int u=Q[L++]; vis[u]=1;
for(int i=head[u];i;i=e[i].nxt) {
int v=e[i].to;
if(deg[v]<=1) sz[u]+=sz[v];
if(--deg[v]==1) Q[++R]=v;
}
}
for(int u=1;u<=n;++u) if(!vis[u]) {
while(1) {
vis[u]=1,A[++m]=u;
int nxt=-1;
for(int i=head[u];i;i=e[i].nxt) {
int v=e[i].to;
if(!vis[v]) nxt=v;
}
if(nxt==-1) break;
u=nxt;
}
break;
}
if(m==1) return pt2::Solve(),0;
fprintf(stderr,"Circle Length =%d
",m);
rep(i,1,n) vis[i]=0;
k=m/2;
rep(i,1,m) {
Mark(i,1);
pt2::Divide(sz[A[i]],A[i]);
Mark(i,0);
}
rep(i,1,n) Ans[i]=1ll*Ans[i]*(P+1)/2%P;
rep(i,1,n) vis[i]=0;
rep(i,1,m-1) Add(1,1,m,i+1,min(i+k,m),i);
typ=0,Get(1,1,m);
rep(i,1,m-k-1) Add(1,1,m,i+k+1,m,i);
typ=1,Get(1,1,m);
int ans=0;
rep(i,1,n) ans=(ans+1ll*Ans[i]*Pow[i])%P;
ans=ans*qpow(1ll*n*(n-1)/2%P)%P;
printf("%d
",ans);
}