牛客练习赛81D 小Q与树

题意

Link

给定一棵树,每个点 (x) 有点权 (a_x),求:

[sum_{u eq v}operatorname{dis}(u,v)min{a_u,a_v} ]

Solution

考虑 dsu on tree。考虑当前我们在遍历 (l) 的后代,遍历到了 (u),那么其贡献为:

[sum_{operatorname{lca}(u,v)=l} (dep_u+dep_v-2dep_l)min{a_u,a_v} ]

对于所有 (a_u<a_v),其贡献为:

[egin{align*} &sum_{operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_u\ =&cnt(dep_u-2dep_l)a_u+a_usum_{operatorname{lca}(u,v)=l}dep_v end{align*} ]

其中 (cnt=sum_{operatorname{lca}(u,v)=l} [a_u<a_v])

对于所有 (a_uge a_v),其贡献为:

[egin{align*} &sum_{operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_v\ =&(dep_u-2dep_l)sum_{operatorname{lca}(u,v)=l}a_v+sum_{operatorname{lca}(u,v)=l}dep_va_v end{align*} ]

树状数组分别维护 (cnt)(sum a_v)(sum dep_v)(sum a_vdep_v) 即可。

#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;
#define mp make_pair
#define pb push_back
#define fi first
#define se second
inline int read()
{
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
const int N=2e5+10,M=4e5+10,maxn=2e5,mod=998244353;
struct bit{
	int c[N];
	bit(){memset(c,0,sizeof(c));}
	void modify(int x,int d){for(;x<=maxn;x+=x&-x)c[x]+=d;}
	int query(int x){int ans=0;for(;x;x^=x&-x)ans+=c[x];return ans;}
}T,T1,T2,T3;
//T:dep[x],   T1:cnt,   T2:a[i],   T3:a[i]dep[i] 
int head[N],ver[M],nxt[M],tot=0;
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
int sz[N],son[N],f[N],dep[N];
void dfs(int x,int fa)
{
	sz[x]=1,dep[x]=dep[fa]+1;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa)continue;
		dfs(y,x),f[x]=(f[x]+1ll*sz[x]*sz[y]%mod)%mod,sz[x]+=sz[y];
		if(!son[x]||sz[y]>sz[son[x]])son[x]=y;
	}
}
int Ans[N],ans=0,a[N],t[N],dt=0;
void ff(int x)
{
	int sum1=T2.query(a[x]-1)%mod,xx=(dep[x]-dt*2+mod)%mod,Sum1=T3.query(a[x]-1)%mod	;
	int ans1=(1ll*sum1*xx%mod+Sum1)%mod;
	
	int cnt2=T1.query(maxn)-T1.query(a[x]-1),sum2=T.query(maxn)-T.query(a[x]-1);
	int ans2=(1ll*cnt2*t[a[x]]%mod*xx%mod+1ll*t[a[x]]*sum2%mod)%mod;
	
	ans+=(ans1+ans2)%mod;
	ans%=mod;
}
void calc(int x,int fa,int op)
{
	if(op==0)T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
	else if(op==1)ff(x);
	else T.modify(a[x],-dep[x]),T1.modify(a[x],-1),T2.modify(a[x],-t[a[x]]),T3.modify(a[x],-t[a[x]]*dep[x]%mod);;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa)continue;
		calc(y,x,op);
	}
}
void dsu(int x,int fa,int op)
{
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa||y==son[x])continue;
		dsu(y,x,0);
	}
	if(son[x])dsu(son[x],x,1);
	dt=dep[x];ff(x);
	T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa||y==son[x])continue;
		calc(y,x,1),calc(y,x,0);
	}
	Ans[x]=ans;
	if(!op)calc(x,fa,-1);
	ans=0;
}
signed main()
{
	int n=read(),m=n;
	for(int i=1;i<=n;i++)t[i]=a[i]=read();
	for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
	sort(t+1,t+m+1),m=unique(t+1,t+m+1)-t-1;
	for(int i=1;i<=n;i++)a[i]=lower_bound(t+1,t+m+1,a[i])-t;
//	for(int i=1;i<=n;i++)printf("a[%d]=%d
",i,a[i]);
	dfs(1,0),dsu(1,0,1);
//	for(int i=1;i<=n;i++)printf("dep[%d]=%d
",i,dep[i]);
	int sum=0;
	for(int i=1;i<=n;i++)sum+=Ans[i],sum%=mod;
	printf("%lld",sum*2%mod);
        return 0;
}
原文地址:https://www.cnblogs.com/juruo-zzt/p/15479045.html