Solution -「51nod 1868」彩色树

(mathcal{Description})

  Link & 双倍经验 Link.

  给定一棵 (n) 个结点的树,每个结点有一种颜色。记 (g(u,v)) 表示 (u)(v) 简单路径上的颜色种数,求

[sum_{{p_n}}sum_{i=1}^{n-1}g(p_i,p_{i+1}) ]

  其中 ({p_n}) 表示 (1sim n) 的排列。

  (nle10^5),答案对 ((10^9+7)) 取模。

(mathcal{Solution})

  常见但不熟悉的 trick 题。(

  不难想到分颜色计算贡献。而“路径上至少出现某种颜色”不好计算,考虑反向计算“路径上不包含某种颜色”的路径条数。

  对于颜色 (c),删除所有颜色为 (c) 的结点,记得到联通块的大小为 (s_{1..m}),那么上述路径数量为 (sum_{i=1}^mfrac{s_i(s_i-1)}2)。称一个联通块深度最小的结点为其顶点,我们尝试在顶点处对这个联通块计数。可以看出,要不顶点是根,要不顶点父亲的颜色为 (c)。前者单独考虑,后者仅需用顶点子树大小减去顶点子树内以 (c) 颜色的结点作为根的子树大小。所以用前后作差的 trick:DFS 进入子树前,记录目前以 (c) 色点为根的子树大小和 (s),退出子树后,得到当前以 (c) 色点为根的子树大小和 (t)。那么该联通块的大小即为 (t-s)

  所以 DFS 一遍就可以啦。复杂度 (mathcal O(n))

(mathcal{Code})

/* Clearink */

#include <cstdio>
#include <vector>
#include <algorithm>

#define rep( i, l, r ) for ( int i = l, repEnd##i = r; i <= repEnd##i; ++i )
#define per( i, r, l ) for ( int i = r, repEnd##i = l; i >= repEnd##i; --i )

inline int rint () {
	int x = 0, f = 1; char s = getchar ();
	for ( ; s < '0' || '9' < s; s = getchar () ) f = s == '-' ? -f : f;
	for ( ; '0' <= s && s <= '9'; s = getchar () ) x = x * 10 + ( s ^ '0' );
	return x * f;
}

const int MAXN = 1e5, MOD = 1e9 + 7;
int n, ecnt, ans, clr[MAXN + 5], head[MAXN + 5], siz[MAXN + 5], sum[MAXN + 5];
struct Edge { int to, nxt; } graph[MAXN * 2 + 5];

inline int mul ( const long long a, const int b ) { return a * b % MOD; }
inline int sub ( int a, const int b ) { return ( a -= b ) < 0 ? a + MOD : a; }
inline int add ( int a, const int b ) { return ( a += b ) < MOD ? a : a - MOD; }

inline void link ( const int s, const int t ) {
	graph[++ecnt] = { t, head[s] };
	head[s] = ecnt;
}

inline int count ( const int s ) { return ( s * ( s - 1ll ) >> 1 ) % MOD; }

inline void dfs ( const int u, const int fa ) {
	siz[u] = 1, ++sum[clr[u]];
	for ( int i = head[u], v; i; i = graph[i].nxt ) {
		if ( ( v = graph[i].to ) ^ fa ) {
			int s = sum[clr[u]];
			dfs ( v, u ), siz[u] += siz[v];
			int t = sum[clr[u]], dlt = t - s;
			ans = add ( ans, count ( siz[v] - dlt ) );
			sum[clr[u]] += siz[v] - dlt;
		}
	}
}

int main () {
	n = rint ();
	rep ( i, 1, n ) clr[i] = rint ();
	for ( int i = 1, u, v; i < n; ++i ) {
		u = rint (), v = rint ();
		link ( u, v ), link ( v, u );
	}
	dfs ( 1, 0 );
	rep ( c, 1, n ) {
		if ( c ^ clr[1] ) {
			ans = add ( ans, count ( n - sum[c] ) );
		}
	}
	ans = sub ( mul ( n, count ( n ) ), ans );
	int fct = 1;
	rep ( i, 1, n - 2 ) fct = mul ( fct, i );
	printf ( "%d", mul ( ans, mul ( 2, mul ( n - 1, fct ) ) ) );
	return 0;
}

原文地址:https://www.cnblogs.com/rainybunny/p/14285093.html