NOI2018 情报中心

NOI2018 情报中心 [* hard]

给定一棵树,大小为 (n),边有边权。

给定 (m) 条链,每条链有权值 (w_i),从 (u_i o v_i)

选两条链,满足两者有交,且两者的链并的边权和减去两条链的权值最大。

(nle 5 imes 10^4,mle 10^5,sum nle 10^6,sum mle 2 imes 10^6),时限 ( m 8s)

( m Sol:)

神仙题。

根据 LCA 进行分类讨论:

  1. 一条链的 LCA 在另一条链上。

此时一定形如两条直上直下的链的交,所以将链拆开。

对于两条直上直下的链,答案为独立的权值减去交的部分。

枚举交点中深度最大的点,那么交肯定是 (dep_x-max{dep_{u},dep_v})

对于 (dep_x),其要求答案来自其不同的子树,不然存在深度更大的点。

对于每个点 (x),设 (f(x,j)) 表示下端在 (x) 的子树内,上端在深度 (j) 处的最大值。

对于某个 (j),答案是所有比他小的深度的最大值,即前缀最大值。

实际上每个点对答案的贡献只有两种,同时 ( m dep_x) 本身是对答案没有影响的部分

通过线段树合并维护即可,或者说每次并一棵新树的时候更新一下答案就可以了,类似于 PKUWC MinMax

也可以启发式合并来更新答案。

  1. LCA 相同

枚举 LCA,答案的形式建议画图。

然后会发现一个非常神仙的点:

[frac{1}{2}(w_i+w_j+dis(a_i,a_j)+dis(b_i,b_j))-c_i-c_j ]

就是答案。

对相同的 LCA 建虚树,枚举 (a_i)(a_j) 的 LCA 为 (u),那么 (dis(a_i,a_j)) 可以用 ((d_i+d_j-2d_{lca})),注意到 (2d_{lca}) 是固定的,(w_i+a_i) 设为定值 (u),那么考虑 dfs,在 (b_i) 下面增加节点,不难发现目标是求原树的直径。

由于 (2d_{lca}) 会减少,所以我们需要两个节点来自不同的子树,相当于查询两个点集独立点构成的直径。

这个可以直接套性质,在边权非负的时候成立,同时注意到负数只可能在补充的叶子节点上产生,所以这个结论仍然是对的。

建立虚树后合并答案,复杂度为 (mathcal O(mlog n)),瓶颈是 LCA 的查询。

(Code:)

#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define int long long
#define pb push_back
int gi() {
	char cc = getchar() ; int cn = 0, flus = 1 ;
	while( cc < '0' || cc > '9' ) {  if( cc == '-' ) flus = - flus ; cc = getchar() ; }
	while( cc >= '0' && cc <= '9' )  cn = cn * 10 + cc - '0', cc = getchar() ;
	return cn * flus ;
} 
const int inf = 1e17 + 5 ; 
const int Inf = 1e16 + 7 ;  
const int N = 2e5 + 5 ; 
struct E { int to, next, w ; } e[N << 1] ; 
struct Line { int a, b, len, w, lca ; } A[N] ;
int n, m, cnt, Ans, idx, head[N], dfn[N], dep[N], dis[N], top[N], fa[N], son[N], sz[N] ; 
void add( int x, int y, int z ) {
	e[++ cnt] = (E){ y, head[x], z }, head[x] = cnt,
	e[++ cnt] = (E){ x, head[y], z }, head[y] = cnt ;  
}
void dfs1(int x, int ff) {
	dep[x] = dep[ff] + 1, fa[x] = ff, sz[x] = 1 ; 
	Next(i, x) {
		int v = e[i].to ; if( v == ff ) continue ; 
		dis[v] = dis[x] + e[i].w, dfs1(v, x), sz[x] += sz[v] ;
		if( sz[son[x]] <= sz[v] ) son[x] = v ;  
	}
}
void dfs2(int x, int high) {
	top[x] = high, dfn[x] = ++ idx ; if( son[x] ) dfs2(son[x], high) ; 
	Next(i, x) {
		int v = e[i].to ; if( v == son[x] || v == fa[x] ) continue ; 
		dfs2(v, v) ; 
	}
}
int LCA(int x, int y) {
	while(top[x] != top[y]) {
		if( dep[top[x]] < dep[top[y]] ) swap(x, y) ; 
		x = fa[top[x]] ;  
	} return ( dep[x] > dep[y] ) ? y : x ; 
}
int Dis(int x, int y) {
	return dis[x] + dis[y] - 2 * dis[LCA(x, y)] ;  
}
void Init() {
	rep( i, 1, n ) head[i] = son[i] = sz[i] = dis[i] = dep[i] = top[i] = fa[i] = 0 ;
	cnt = 0, Ans = -inf ; 
}
bool cmp( int x, int y ) { return dfn[x] < dfn[y] ; }
namespace Solve1 {
	#define ls(x) tr[x].l
	#define rs(x) tr[x].r
	struct Tr {
		int l, r, f, g ; 
		void init() { l = r = 0, f = g = -inf ; }
	} tr[N * 15] ; 
	struct node { int l, f, g ; } ;
	vector<node> F[N] ; 
	int num, rt[N], uA ; 
	void merge( int &x, int u, int l, int r, int type ) {
		if( !x || !u ) return x = x + u, void() ; 
		if( l == r ) { 
			return tr[x].f = max( tr[x].f, tr[u].f ), tr[x].g = max( tr[x].g, tr[u].g ), void() ; 
		} 
		int mid = ( l + r ) >> 1 ; 
		if( type ) uA = max( uA, tr[rs(u)].g + tr[ls(x)].f ), uA = max( uA, tr[ls(u)].f + tr[rs(x)].g ) ; 
		merge(ls(x), ls(u), l, mid, type), merge(rs(x), rs(u), mid + 1, r, type) ; 
		tr[x].f = max( tr[ls(x)].f, tr[rs(x)].f ), tr[x].g = max( tr[ls(x)].g, tr[rs(x)].g ) ;
	} 
	void del(int x, int l, int r, int k) {
		if( l == r ) return tr[x].f = tr[x].g = -inf, void() ; 
		int mid = ( l + r ) >> 1 ; 
		if( k <= mid ) del( ls(x), l, mid, k ) ;
		else del( rs(x), mid + 1, r, k ) ; 
		tr[x].f = max( tr[ls(x)].f, tr[rs(x)].f ), 
		tr[x].g = max( tr[ls(x)].g, tr[rs(x)].g ) ; 
	}
	void ins(int &x, int l, int r, int k, int f, int g, int type) { 
		if( !x ) x = ++ num, tr[x].init() ; 
		if( l == r ) { 
			tr[x].f = max( tr[x].f, f ), tr[x].g = max( tr[x].g, g ) ; return ; 
		} int mid = ( l + r ) >> 1 ; 
		if( k <= mid ) {
			if(type) uA = max( uA, tr[rs(x)].g + f ) ; 
			ins( ls(x), l, mid, k, f, g, type ) ; 
		}
		else {
			if(type) uA = max( uA, tr[ls(x)].f + g ) ; 
			ins( rs(x), mid + 1, r, k, f, g, type ) ; 
		}
		tr[x].f = max( tr[x].f, f ), tr[x].g = max( tr[x].g, g ) ;
	}
	void dfs(int x, int fa) {
		int fl = 0, ans = -inf ; uA = -inf ; 
		for( node u : F[x] ) if( u.l < dep[x] ) uA = -inf, 
		ins(rt[x], 1, n, u.l, u.f, u.g, fl ), ans = max( ans, uA ), fl = 1 ; 
		ans = max( ans, uA ) ; 
		Next( i, x ) {
			int v = e[i].to ; if( v == fa ) continue ; 
			dfs(v, x), uA = -inf, merge(rt[x], rt[v], 1, n, fl), ans = max( ans, uA ), fl = 1 ; 
		}
		Ans = max( Ans, ans - dis[x] ) ; 
		if( x != 1 ) del(rt[x], 1, n, dep[x] - 1) ; 
	}
	void solve() {
		node u ; 
		rep( i, 1, m ) 
			u = (node){ dep[A[i].lca], A[i].len - A[i].w, A[i].len - A[i].w + dis[A[i].lca] },
			F[A[i].a].pb(u), F[A[i].b].pb(u) ; 
		tr[0].f = tr[0].g = -inf, dfs(1, 1) ; 
	}
	void init() {
		rep( i, 1, n ) F[i].clear(), F[i].shrink_to_fit(), rt[i] = 0 ; 
		rep( i, 1, num ) tr[i].init() ; 
		num = 0 ; 
	}
}
namespace S2 {
	struct node { int a, b, len, w ; } ;
	struct qwq { int u, w ; } ;
	struct Li { 
		int a, b, w1, w2, d ; 
		void init() { a = b = 0, w1 = w2 = d = -inf ; }
	} zk[N] ;
	int num, K[N], top, st[N] ; vector<qwq> R[N] ; 
	vector<node> F[N] ; vector<int> G[N] ; 
	void Add(int x, int y) { G[x].pb(y) ; }
	void insert(int x) {
		int u = LCA( x, st[top] ) ; if( u == x ) return ;
		while( dfn[u] < dfn[st[top - 1]] ) Add( st[top - 1], st[top] ), -- top ; 
		if( dfn[u] < dfn[st[top]] ) Add( u, st[top] ), -- top ;
		if( dfn[u] > dfn[st[top]] ) st[++ top] = u ; 
		st[++ top] = x ; 
	}
	void inc(int x, qwq o, int &sz) {
		if( sz == 0 ) { ++ sz, zk[x].a = o.u, zk[x].w1 = o.w ; return ; }
		if( sz == 1 ) { ++ sz, zk[x].b = o.u, zk[x].w2 = o.w, zk[x].d = Dis(zk[x].a, zk[x].b) + zk[x].w1 + zk[x].w2 ; return ; }
		if( sz >= 2 ) {
			int v = o.u, w = o.w ;
			int d1 = Dis( zk[x].a, v ) + w + zk[x].w1 ;
			int d2 = Dis( zk[x].b, v ) + w + zk[x].w2 ; 
			if( d1 > zk[x].d && d1 >= d2 ) zk[x].b = v, zk[x].w2 = w, zk[x].d = d1 ; 
			if( d2 > zk[x].d && d2 > d1 ) zk[x].a = v, zk[x].w1 = w, zk[x].d = d2 ; 
		}
	}
	void Dfs(int x) {
		int ans = -inf, fl = 0, sz = 0 ; zk[x].init() ; 
		for( qwq o : R[x] ) o.w += dis[x], inc( x, o, sz ) ;
		Ans = max( Ans, ( zk[x].d - 2 * dis[x] ) / 2 ) ; 
		for(int v : G[x]) {
			Dfs(v) ; 
			if( zk[v].a ) {
				if( zk[x].a ) Ans = max( Ans, (Dis(zk[x].a, zk[v].a) + zk[x].w1 + zk[v].w1 - 2 * dis[x]) / 2 ) ;
				if( zk[x].b ) Ans = max( Ans, (Dis(zk[x].b, zk[v].a) + zk[x].w2 + zk[v].w1 - 2 * dis[x]) / 2 ) ;
			}
			if( zk[v].b ) {
				if( zk[x].a ) Ans = max( Ans, (Dis(zk[x].a, zk[v].b) + zk[x].w1 + zk[v].w2 - 2 * dis[x]) / 2 ) ;
				if( zk[x].b ) Ans = max( Ans, (Dis(zk[x].b, zk[v].b) + zk[x].w2 + zk[v].w2 - 2 * dis[x]) / 2 ) ;
			}
			if( zk[v].a ) inc(x, (qwq){zk[v].a, zk[v].w1}, sz) ;
			if( zk[v].b ) inc(x, (qwq){zk[v].b, zk[v].w2}, sz) ; 
			zk[v].init() ; 
		} R[x].clear(), G[x].clear(), R[x].shrink_to_fit(), G[x].shrink_to_fit() ; 
	}
	void Solve(int x) {
		if( F[x].size() < 2 ) return ; 
		for( node u : F[x] ) K[++ num] = u.a, K[++ num] = u.b ; 
		for( node u : F[x] ) R[u.a].pb((qwq){u.b, u.len - 2 * u.w}), R[u.b].pb((qwq){u.a, u.len - 2 * u.w}) ; 
		st[++ top] = x ; 
		sort( K + 1, K + num + 1, cmp ) ; 
		rep( i, 1, num ) insert(K[i]) ; 
		while( top > 1 ) Add( st[top - 1], st[top] ), -- top ; 
		for( int v : G[x] ) Dfs(v), zk[v].init() ; 
		R[x].clear(), G[x].clear(), zk[x].init(), top = num = 0 ;
		R[x].shrink_to_fit(), G[x].shrink_to_fit() ;
	}
	void solve() {
		rep( i, 1, n ) zk[i].init() ; 
		rep( i, 1, m ) 
			F[A[i].lca].pb((node){A[i].a, A[i].b, A[i].len, A[i].w}) ;
		rep( i, 1, n ) Solve(i) ; 
		rep( i, 1, n ) F[i].clear(), F[i].shrink_to_fit() ;
	}
}
signed main()
{
	int T = gi() ; 
	while( T-- ) {
		n = gi() ; int x, y, z ; Ans = -inf ;
		rep( i, 2, n ) x = gi(), y = gi(), z = gi(), add(x, y, z) ;
		dfs1(1, 1), dfs2(1, 1) ;
		m = gi() ; 
		rep( i, 1, m ) 
			A[i].a = gi(), A[i].b = gi(), A[i].w = gi(), 
			A[i].lca = LCA(A[i].a, A[i].b), A[i].len = Dis(A[i].a, A[i].b) ; 
		Solve1::solve() ; S2::solve() ; 
		if( Ans < -Inf ) puts("F") ; 
		else printf("%lld
", Ans ) ; 
		Init(), Solve1::init() ; 
	}
	return 0 ;
}
原文地址:https://www.cnblogs.com/Soulist/p/13672558.html