【LGP5439】【XR-2】永恒

题目

是个傻题

显然枚举每一条路径经过了多少次,如果(u,v)在树上不是祖先关系的话经过((u,v))这条路径的路径条数就是(sum_u imes sum_v)

于是我们子树大小映射到( m Trie)上去,树形( m dp)一下就可以求出所有点对产生的贡献了

但是这样祖先关系的节点就算错了,我们发现这也非常好算,( m dfs)的时候拿( m LCT)维护一下就好了

代码

#include<bits/stdc++.h>
#define re register
inline int read() {
	char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=3e5+5;
const int mod=998244353;
struct E{int v,nxt;}e[maxn];
inline int qm(int x) {return x>=mod?x-mod:x;}
inline int dqm(int x) {return x<0?x+mod:x;}
int n,m,num,rt,ans,sm[maxn],head[maxn],d[maxn];
char S[maxn];
struct Trie {
	E e[maxn<<1];
	int head[maxn],num,v[maxn],deep[maxn];
	inline void add(int x,int y) {
		e[++num].v=y;e[num].nxt=head[x];head[x]=num;
	}
	void pdfs(int x) {
		for(re int i=head[x];i;i=e[i].nxt) deep[e[i].v]=deep[x]+1,pdfs(e[i].v);
	}
	void dfs(int x,int dep) {
		for(re int i=head[x];i;i=e[i].nxt) {
			dfs(e[i].v,dep+1);
			ans=qm(ans+1ll*dep*v[x]%mod*v[e[i].v]%mod);
			v[x]=qm(v[x]+v[e[i].v]);
		}
	}
}T;
struct LinkCutTree {
	int fa[maxn],ch[maxn][2],rev[maxn],tag[maxn],st[maxn],top,sum[maxn],a[maxn],sz[maxn];
	inline int nrt(int x) {return ch[fa[x]][1]==x||ch[fa[x]][0]==x;}
	inline void pushup(int x) {
		sz[x]=1+sz[ch[x][0]]+sz[ch[x][1]];sum[x]=a[x];
		if(ch[x][0]) sum[x]=qm(sum[x]+sum[ch[x][0]]);
		if(ch[x][1]) sum[x]=qm(sum[x]+sum[ch[x][1]]);
	}
	inline void work(int x,int v) {
		a[x]=qm(a[x]+v);tag[x]=qm(tag[x]+v);
		sum[x]=qm(sum[x]+1ll*sz[x]*v%mod);
	}
	inline void pushdown(int x) {
		if(tag[x]) {
			if(ch[x][0]) work(ch[x][0],tag[x]);
			if(ch[x][1]) work(ch[x][1],tag[x]);
			tag[x]=0;
		}
		if(rev[x]) {
			rev[x]=0;rev[ch[x][0]]^=1;rev[ch[x][1]]^=1;
			std::swap(ch[ch[x][0]][0],ch[ch[x][0]][1]);
			std::swap(ch[ch[x][1]][0],ch[ch[x][1]][1]);
		}
	}
	inline void rotate(int x) {
		int y=fa[x],z=fa[y],w=ch[y][1]==x,k=ch[x][w^1];
		if(nrt(y)) ch[z][ch[z][1]==y]=x;
		ch[x][w^1]=y,ch[y][w]=k;
		pushup(y),pushup(x);fa[k]=y,fa[y]=x,fa[x]=z;
	}
	inline void splay(int x) {
		int y=x;top=0;st[++top]=x;
		while(nrt(y)) y=fa[y],st[++top]=y;
		while(top) pushdown(st[top--]);
		while(nrt(x)) {
			int y=fa[x];
			if(nrt(y)) rotate((ch[fa[y]][1]==y)^(ch[y][1]==x)?x:y);
			rotate(x);
		}
	}
	inline void access(int x) {
		for(re int y=0;x;y=x,x=fa[x])
			splay(x),ch[x][1]=y,pushup(x);
	}
	inline void mrt(int x) {
		access(x);splay(x);rev[x]^=1;std::swap(ch[x][0],ch[x][1]);
	}
	inline void link(int x,int y) {
		mrt(x);fa[x]=y;T.add(x,y);
	}
	inline void split(int x,int y) {
		mrt(x);access(y);splay(y);
	}
	inline void ins(int x,int y,int v) {
		split(x,y);v=dqm(v);work(y,v);
	}
	inline int query(int x,int y) {
		split(x,y);
		return dqm(sum[y]-a[y]);
	}
}lct;
inline void add(int x,int y) {
	e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs1(int x) {
	sm[x]=1;
	for(re int i=head[x];i;i=e[i].nxt) dfs1(e[i].v),sm[x]+=sm[e[i].v];
}
void dfs2(int x) {
	ans=qm(ans+1ll*sm[x]*lct.query(d[x],1)%mod);
	for(re int i=head[x];i;i=e[i].nxt) {
		lct.ins(1,d[x],n-sm[e[i].v]-sm[x]);
		dfs2(e[i].v);
		lct.ins(1,d[x],sm[x]+sm[e[i].v]-n);
	}
}
int main() {
	n=read(),m=read();
	for(re int x,i=1;i<=n;i++) {
		x=read();if(x) add(x,i);else rt=i;
	} 
	for(re int x,i=1;i<=m;i++) {
		x=read();if(x) lct.link(x,i);
	}
	dfs1(rt);scanf("%s",S+1);T.pdfs(1);
	for(re int i=1;i<=n;i++) {
		d[i]=read();
		ans=qm(ans+1ll*sm[i]*T.deep[d[i]]%mod*T.v[d[i]]%mod);
		T.v[d[i]]=qm(T.v[d[i]]+sm[i]);
	}
	T.dfs(1,0);dfs2(rt);printf("%d
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/asuldb/p/11503026.html