Codeforces Round #474 E. Alternating Tree

E. Alternating Tree
time limit per test
2 seconds
memory limit per test
256 megabytes
input
standard input
output
standard output

Given a tree with $n$ nodes numbered from $1$ to $n$. Each node $i$ has an associated value $V_i$.

If the simple path from $u_1$ to $u_m$ consists of $m$ nodes, namely $u_1 \to u_2 \to u_3 \to \dots \to u_{m-1} \to u_m$, then its alternating function $A(u_1, u_m)$ is defined as $A(u_1, u_m) = \sum_{i=1}^{m} (-1)^{i+1} \cdot V_{u_i}$. A path can also have $0$ edges, i.e. $u_1 = u_m$.

Compute the sum of alternating functions of all unique simple paths. Note that the paths are directed: two paths are considered different if the starting vertices differ or the ending vertices differ. The answer may be large, so compute it modulo $10^9+7$.

Input

The first line contains an integer $n$ $(2 \le n \le 2 \cdot 10^5)$ — the number of vertices in the tree.

The second line contains $n$ space-separated integers $V_1, V_2, \dots, V_n$ $(-10^9 \le V_i \le 10^9)$ — values of the nodes.

The next $n-1$ lines each contain two space-separated integers $u$ and $v$ $(1 \le u, v \le n,\ u \ne v)$ denoting an edge between vertices $u$ and $v$. It is guaranteed that the given graph is a tree.

Output

Print the total sum of alternating functions of all unique simple paths modulo $10^9+7$.

Examples
input
4
-4 1 5 -2
1 2
1 3
1 4
output
40
input
8
-2 6 -4 -4 -9 -3 -7 23
8 2
2 3
1 4
6 5
7 6
4 7
5 8
output
4
思路:只需求节点数为奇数的路径的交错和之和即可,因为节点数为偶数的路径正反两个方向的交错和互为相反数,相加后抵消(而节点数为奇数的路径正反两个方向的值相等)。树上dp:对每个点维护以该点为起点、终点在其子树内的路径中,节点数为奇数的路径条数、节点数为偶数的路径条数、奇数节点路径的权值和、偶数节点路径的权值和即可。

(这题写完调了近4个小时,一直以为是dp转移或者更新答案的时候错了,今早起来一看发现是设置访问标记的地方错了)

  1 #include <iostream>
  2 #include <fstream>
  3 #include <sstream>
  4 #include <cstdlib>
  5 #include <cstdio>
  6 #include <cmath>
  7 #include <string>
  8 #include <cstring>
  9 #include <algorithm>
 10 #include <queue>
 11 #include <stack>
 12 #include <vector>
 13 #include <set>
 14 #include <map>
 15 #include <list>
 16 #include <iomanip>
 17 #include <cctype>
 18 #include <cassert>
 19 #include <bitset>
 20 #include <ctime>
 21 
 22 using namespace std;
 23 
 24 #define pau system("pause")
 25 #define ll long long
 26 #define pii pair<int, int>
 27 #define pb push_back
 28 #define mp make_pair
 29 #define clr(a, x) memset(a, x, sizeof(a))
 30 
 31 const double pi = acos(-1.0);
 32 const int INF = 0x3f3f3f3f;
 33 const int MOD = 1e9 + 7;
 34 const double EPS = 1e-9;
 35 
 36 /*
 37 #include <ext/pb_ds/assoc_container.hpp>
 38 #include <ext/pb_ds/tree_policy.hpp>
 39 
 40 using namespace __gnu_pbds;
 41 tree<pli, null_type, greater<pli>, rb_tree_tag, tree_order_statistics_node_update> T;
 42 */
 43 
 44 int n;
 45 ll V[200015], sume[200015], sumo[200015], cnte[200015], cnto[200015], sum;
 46 void mod(ll &x) {
 47     x = x % MOD;
 48     if (x < 0) x += MOD;
 49 }
 50 vector<int> E[200015];
 51 void dfs(int x, int p) {
 52     ll cur_sume = 0, cur_sumo = 0, cur_cnte = 0, cur_cnto = 0;
 53     for (int i = 0; i < E[x].size(); ++i) {
 54         int y = E[x][i];
 55         if (y == p) continue;
 56         dfs(y, x);
 57         cur_sume += sume[y];
 58         cur_sumo += sumo[y];
 59         cur_cnte += cnte[y];
 60         cur_cnto += cnto[y];
 61     }
 62     mod(cur_cnte), mod(cur_cnto), mod(cur_sume), mod(cur_sumo);
 63     for (int i = 0; i < E[x].size(); ++i) {
 64         int y = E[x][i];
 65         if (y == p) continue;
 66         ll cnte_y = cnte[y], cnte_other = cur_cnte - cnte_y;
 67         ll cnto_y = cnto[y], cnto_other = cur_cnto - cnto_y;
 68         ll sume_y = sume[y], sume_other = cur_sume - sume_y;
 69         ll sumo_y = sumo[y], sumo_other = cur_sumo - sumo_y;
 70         sum += cnte_y * sume_other + cnte_other * sume_y + cnte_other * cnte_y % MOD * V[x];
 71         sum += cnto_y * sumo_other + cnto_other * sumo_y - cnto_other * cnto_y % MOD * V[x];
 72         mod(sum);
 73     }
 74     cnto[x] = 1 + cur_cnte;
 75     cnte[x] = cur_cnto;
 76     sumo[x] = cur_cnte * V[x] + cur_sume;
 77     sume[x] = -cur_cnto * V[x] + cur_sumo;
 78     sum += 2 * sumo[x];
 79     sumo[x] += V[x];
 80     mod(cnto[x]), mod(cnte[x]), mod(sumo[x]), mod(sume[x]), mod(sum);
 81 }
 82 int main() {
 83     scanf("%d", &n);
 84     for (int i = 1; i <= n; ++i) {
 85         scanf("%lld", &V[i]);
 86         sum += V[i];
 87     }
 88     for (int i = 1; i < n; ++i) {
 89         int u, v;
 90         scanf("%d%d", &u, &v);
 91         E[u].pb(v);
 92         E[v].pb(u);
 93     }
 94     dfs(1, 0);
 95     printf("%lld
", sum);
 96     return 0;
 97 }
 98 /*
 99 4
100 4 1 5 2
101 1 2
102 1 3
103 1 4
104 */
View Code
原文地址:https://www.cnblogs.com/BIGTOM/p/8776395.html