洛谷 P3676 【小清新数据结构题】

怎么说呢,非常神的一道题

我们先忽略修改操作,考虑换根操作,假设我们的根从(u)换成了(v)那么可以注意到发生变化的(sz)只有两个,(u)(v)

于是我们有这次操作后的点权平方总和为(sum_{i=1}^nsz_i^2)

变化的权值则为(sum^2->(sum-sz_v)^2,sz_v^2->sum^2)

你会发现这个还是非常不好维护

但是貌似(sum_{i=1}^nsz_i*(sum-sz_i))是个定值

这是因为一次换根只会影响到(u,v)两个点,所以变化的只有那两个点,但他们的变化恰好是一个从(sz)变成了(sum-sz),另一个则变成了(sum)

所以这指示这对于任意的一个根(sum_{i=1}^nsz_i*(sum-sz_i))是一个定值

仔细思考会发现这个定值:

[sum_{i=1}^nsum*sz_i-sz_i^2=S ]

[sum sz^2=sum_{i=1}^nsum*sz_i-S ]

由于(S)为定值,假设我们可以求出来,那么我们就只需要维护(sum sum*sz_i)

显然的是(sum)可以提到前面去,于是我们就只需要维护(sz)之和

进一步思考,在换根了之后,一个点对(sz)之和的贡献貌似是其到根的距离(+1)

于是我们就有:

[sum sz = sum w_i*(dist(i,u)+1) ]

其中(dist)表示两点之间的路径,(u)为根

(+1)是可以提出来的,于是我们只需要管前面那一坨

然而有趣的是前面这一坨是要我们对于每个(u)维护(sum dist(i,u)*w_i)

这个东西好像可以用动态点分治来维护?

我们考虑建出点分树,对每个点维护一下子树内到其的(f(u)=sum dist(i,u)*w_i)(s(u)=sum_{}w_i)(g(u)=sum dist(i,fa)*w_i)

好像就可以转移了

[f(u)=f(fa)-g(u)+f(u)+(s(fa)-s(u))*dist(fa,u) ]

在点分树上暴力转移的复杂度是(O(log^2 n))

每次修改就暴力修改这三个值即可,复杂度同样(O(log^2n))

接下来考虑如何在修改之后维护( m S)

仔细思考( m S)的定义发现它是对于每个点其子树内的点权和*子树外的点权和

实际上是(sum w_i*w_j*dist(i,j))

[sum_{i=1}^nsum_{j=i+1}^{n}w_i*w_j*dist(i,j) ]

考虑修改,你会发现貌似对于绝大部分的点对其对答案的贡献也都没有变

貌似唯一变化的就是(sum_{i=1}^nw_u*w_i*dist(u,i))

而且因为只变了(w_u),所以我们提出来就是(w_u*sum_{i=1}^nw_i*dist(i,u))

这个就是答案的变化率

于是好像和换根要维护的东西是一样的...

所以也是点分树上暴力跳然后修改...

复杂度(O(log ^2n))

总体复杂度(O(nlog^2n+qlog^2n))

#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
#define re register
#define int long long
inline int gi() {
    char cc = getchar() ; int cn = 0, flus = 1 ; 
    while( cc > '9' || cc < '0' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
    while( cc <= '9' && cc >= '0' ) cn = cn * 10 + cc - '0', cc = getchar() ; 
    return cn * flus ; 
}
const int N = 2e5 + 5 ; 
int n, q, head[N], w[N], dp[N], vis[N], RS ; 
int d[N], f[N], g[N], fa[N], cnt, rt, root, sum, S ; 
int Fa[N], sz[N], Top[N], Son[N], dep[N], sw[N] ; 
struct E {
	int to, next ; 
} e[N * 2] ;
inline void add( int x, int y ) {
	e[++ cnt] = (E){ y, head[x] }, head[x] = cnt ,
	e[++ cnt] = (E){ x, head[y] }, head[y] = cnt ; 
}
inline void dfs1( int x, int ff ) {
	dep[x] = dep[ff] + 1, sz[x] = 1, Fa[x] = ff, sw[x] = w[x] ; 
	Next( i, x ) {
		int v = e[i].to ; if( v == ff ) continue ; 
		dfs1( v, x ), sz[x] += sz[v], sw[x] += sw[v], d[x] += d[v] ; 
		if( sz[v] > sz[Son[x]] ) Son[x] = v ;  
	}
	S += ( RS - sw[x] ) * sw[x] ; 
}
inline void dfs2( int x, int ff ) {
	Top[x] = ff ; 
	if( Son[x] ) dfs2( Son[x], ff ) ;
	Next( i, x ) {
		int v = e[i].to ; if( v == Fa[x] || v == Son[x] ) continue ;
		dfs2( v, v ) ;  
	} 
}
int LCA( int x, int y ) {
	while( Top[x] != Top[y] ) {
		if( dep[Top[x]] < dep[Top[y]] ) swap( x, y ) ;
		x = Fa[Top[x]] ; 
	}
	return ( dep[x] > dep[y] ) ? y : x ; 
}
inline void get_rt( int x, int ff ) {
	sz[x] = 1, dp[x] = 0 ; 
	Next( i, x ) {
		int v = e[i].to ; if( v == ff || vis[v] ) continue ; 
		get_rt( v, x ), sz[x] += sz[v] ;
		dp[x] = max( sz[v], dp[x] ) ;
	}
	dp[x] = max( dp[x], sum - sz[x] ) ;
	if( dp[x] <= dp[rt] ) rt = x ; 
}
inline void solve( int x ) {
	vis[x] = 1 ; 
	Next( i, x ) {
		int v = e[i].to ; if( vis[v] ) continue ; 
		rt = 0, dp[0] = sum = sz[v], get_rt( v, x ),
		fa[rt] = x, solve( rt ) ;
	}
}
inline int dist( int x, int y ) {
	return dep[x] + dep[y] - 2 * dep[LCA(x, y)] ; 
}
void Init( int x ) {
	rep( i, 1, n ) {
		int u = i, fr ; d[i] += w[i] ;
		while( fa[u] ) {
			fr = dist( i, fa[u] ), f[fa[u]] += fr * w[i], 
			d[fa[u]] += w[i], g[u] += fr * w[i], u = fa[u] ;
		}
	}
}
int Query( int x ) {
	int Ans = f[x], u = x ; 
	while( fa[u] ) {
		Ans += ( f[fa[u]] - g[u] ) ;
		Ans += ( d[fa[u]] - d[u] ) * dist( fa[u], x ) ;
		u = fa[u] ; 
	}
	return Ans ; 
}
void Update( int x, int p ) {
	int u = x, fr, y = p - w[x] ; 
	int ru = Query( x ) ;
	S += ru * y, d[x] += y, RS += y ; ;
	while( fa[u] ) {
		fr = dist( x, fa[u] ), f[fa[u]] += fr * y, 
		d[fa[u]] += y, g[u] += fr * y, u = fa[u] ;
	}
	w[x] = p ; 
}
signed main() {
    sum = n = gi(), q = gi(), rt = 0, dp[0] = n + 1 ; int opt, x, y ; 
    rep( i, 2, n ) x = gi(), y = gi(), add( x, y ) ;
    rep( i, 1, n ) w[i] = gi(), RS += w[i] ; 
    dfs1( 1, 1 ), dfs2( 1, 1 ), get_rt( 1, 1 ), 
	root = rt, solve( rt ), Init(root) ; 
	while( q-- ) {
		opt = gi(), x = gi() ;
		if( opt == 1 ) y = gi(), Update( x, y ) ; 
		else printf("%lld
", ( Query( x ) + RS ) * RS - S ) ;
	}
	return 0 ; 
}
原文地址:https://www.cnblogs.com/Soulist/p/11637321.html