CF766E Mahmoud and a xor trip

首先按照套路，我们按位考虑：对每一位 $j$，问题就变成统计 $0$、$1$ 两种情况组合出第 $j$ 位为 $1$ 的路径数目，最后乘以 $2^j$ 累加。

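那么记 $dp_{i,j,0/1}$ 表示从节点 $i$ 出发、终点在以 $i$ 为根的子树内的所有路径中，第 $j$ 位异或值为 $0$ / $1$ 的路径条数。然后 dfs 一遍，用儿子的答案更新父亲的答案：合并儿子 $v$ 到父亲 $u$ 时，先把 $\left(dp_{u,j,1}\cdot dp_{v,j,0}+dp_{u,j,0}\cdot dp_{v,j,1}\right)\cdot 2^j$ 计入答案，再把 $v$ 的路径接上 $u$ 并入 $dp_{u}$，最后累加就好了。注意要考虑长度为 $0$ 的路径（即单个节点，贡献为 $a_i$）。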

#include <bits/stdc++.h>
// `register` is ill-formed as a storage-class specifier since C++17 (and was
// only ever an ignorable hint before), so `reg` expands to nothing. The loop
// macros that prefix their counters with `reg` keep compiling on every standard.
#define reg
#define ll long long
// NOTE: from here on every `int` token -- including those produced by later
// macro expansions such as `pi`, `vi`, `SZ` and the loop macros -- becomes
// `long long`. This is why the entry point is declared `signed main()`.
#define int long long
#define ull unsigned long long
#define db double
// Short aliases for STL pair/vector types. Because of the `int` macro above,
// `pi` and `vi` actually hold `long long` elements.
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vi vector<int>
#define vl vector<ll>
#define vpi vector<pi>
#define vpl vector<pl>
// Short aliases for container members / algorithms.
#define pb push_back
#define er erase
// NOTE(review): the argument of SZ is not parenthesized; pass a plain name.
#define SZ(x) (int) x.size()
#define lb lower_bound
#define ub upper_bound
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define mkp make_pair
// ms: zero-fill a whole array. msn: memset it with byte value `num`. memset
// fills bytes, so only byte patterns such as 0 or -1 give the same value in
// every multi-byte element.
#define ms(data_name) memset(data_name, 0, sizeof(data_name))
#define msn(data_name, num) memset(data_name, num, sizeof(data_name))
// Loop macros (counter declared with `reg`):
//   For(i, j)            i = 1 .. j          (inclusive)
//   For0(i, j)           i = 0 .. j-1
//   Forx(i, j, k)        i = j .. k          (inclusive)
//   Forstep(i, j, k, st) i = j, j+st, ... while i <= k
//   fOR(i, j)            i = j .. 1          (descending)
//   fOR0(i, j)           i = j-1 .. 0        (descending)
//   fORx(i, j, k)        i = k .. j          (descending)
// The bound expression is re-evaluated on every iteration.
#define For(i, j) for(reg int (i) = 1; (i) <= (j); ++(i))
#define For0(i, j) for(reg int (i) = 0; (i) < (j); ++(i))
#define Forx(i, j, k) for(reg int (i) = (j); (i) <= (k); ++(i))
#define Forstep(i , j, k, st) for(reg int (i) = (j); (i) <= (k); (i) += (st))
#define fOR(i, j) for(reg int (i) = (j); (i) >= 1; (i)--)
#define fOR0(i, j) for(reg int (i) = (j) - 1; (i) >= 0; (i)--)
#define fORx(i, j, k) for(reg int (i) = (k); (i) >= (j); (i)--)
// tour(i, u): iterate edge ids i of vertex u's adjacency list (head/nxt
// arrays), stopping at the -1 end-of-list sentinel.
#define tour(i, u) for(reg int (i) = head[(u)]; (i) != -1; (i) = nxt[(i)])
using namespace std;
// Buffered stdin reader state: B holds a chunk of up to 1 MiB read by fread,
// S points at the next unread byte and T one past the last valid byte.
// `ch` is the scratch character shared by the rd* readers.
char ch, B[1 << 20], *S = B, *T = B;
// Returns the next input byte, refilling B from stdin when it is used up.
// Yields 0 once fread returns nothing (end of input). Note: this zero-argument
// macro hides the standard getc for the rest of the file.
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 20, stdin), S == T) ? 0 : *S++)
// True when c is an ASCII decimal digit. The argument is parenthesized so that
// expressions such as isd(x & 0x7F) are tested as a whole.
#define isd(c) ((c) >= '0' && (c) <= '9')
int rdint() {
  int aa, bb;
  while(ch = getc(), !isd(ch) && ch != '-');
  ch == '-' ? aa = bb = 0 : (aa = ch - '0', bb = 1);
  while(ch = getc(), isd(ch))
    aa = aa * 10 + ch - '0';
  return bb ? aa : -aa;
}
// Reads the next decimal integer (as long long) from the buffered reader.
// Non-digit characters before the number are skipped; a '-' met while
// skipping makes the result negative. If the input ends (getc() yields 0)
// before a number starts, returns 0 instead of spinning forever on the
// end-of-input sentinel.
ll rdll() {
  // Skip to the first digit or minus sign, bailing out at end of input.
  do {
    ch = getc();
    if (!ch) return 0;
  } while (!isd(ch) && ch != '-');
  const bool neg = (ch == '-');
  ll val = neg ? 0 : ch - '0';
  while (ch = getc(), isd(ch))
    val = val * 10 + ch - '0';
  return neg ? -val : val;
}
// Modulus used by mod_t; the common alternative is kept commented out below.
const int mod = 998244353;
// const int mod = 1e9 + 7;
// Integer modulo `mod`, kept in [0, mod). Operands are assumed to be already
// normalized; constructing from a raw value does NOT reduce it.
struct mod_t {
  // Maps x from (-mod, mod) into [0, mod) by adding mod to negatives.
  // Uses a comparison rather than a sign shift (`x >> 31`), so it does not
  // depend on the width of `int` or on right-shifting negative values.
  static int norm(int x) {
    return x < 0 ? x + mod : x;
  }
  int x;
  // Default-constructed values are 0 rather than uninitialized.
  mod_t() : x(0) {  }
  mod_t(int v) : x(v) {  }
  // mod_t(ll v) : x(v) {  }
  mod_t(char v) : x(v) {  }
  mod_t operator +(const mod_t &rhs) const {
    return mod_t(norm(x + rhs.x - mod));
  }
  mod_t operator -(const mod_t &rhs) const {
    return mod_t(norm(x - rhs.x));
  }
  mod_t operator *(const mod_t &rhs) const {
    // Widen before multiplying (the product of two residues needs ~60 bits),
    // then construct explicitly so overload resolution between the int and
    // char constructors is never ambiguous.
    return mod_t(static_cast<int>(static_cast<long long>(x) * rhs.x % mod));
  }
};
// Capacity for vertex arrays (vertices are indexed 1..n).
const int MAXN = 2e5 + 10;
// a[v]: value on vertex v.
// dp[v][j][b]: number of paths that start at v, end inside v's subtree, and
// whose XOR has bit j equal to b (filled by dfs).
// Only the low 20 bits are tracked -- assumes every a[v] < 2^20.
int n, a[MAXN], dp[MAXN][20][2];
// Running answer: sum of path XORs, including single-vertex paths.
ll ans = 0;
// Array-based adjacency list: head[v] = first edge id of v (-1 = none),
// nxt[e] = next edge id of the same vertex, pnt[e] = far endpoint of edge e.
// Each undirected edge is stored twice, hence the MAXN << 1 capacity.
int E, head[MAXN], nxt[MAXN << 1], pnt[MAXN << 1];
// Empties the adjacency list: every head[] slot gets the -1 end-of-list
// sentinel (all bytes 0xFF), and the edge counter restarts at 0.
inline void clear() {
  memset(head, -1, sizeof(head));
  E = 0;
}
// Appends a directed edge x -> y by pushing a new edge id onto the front of
// x's linked adjacency list.
inline void addedge(int x, int y) {
  const int id = E++;
  pnt[id] = y;
  nxt[id] = head[x];
  head[x] = id;
}
inline void dfs(int u, int f) {
  For0(i, 20)
    if(a[u] & (1 << i))
      dp[u][i][1] = 1;
    else
      dp[u][i][0] = 1;
  tour(i, u) {
    int v = pnt[i];
    if(v != f) {
      dfs(v, u);
      For0(j, 20) {
        int tmp = (a[u] >> j) & 1;
        ans += (dp[u][j][1] * dp[v][j][0] + dp[u][j][0] * dp[v][j][1]) << j;
        dp[u][j][tmp ^ 0] += dp[v][j][0];
        dp[u][j][tmp ^ 1] += dp[v][j][1];
      }
    }
  }
}
inline void work() {
  n = rdint();
  For(i, n) {
    a[i] = rdint();
    ans += a[i];
  }
  clear();
  Forx(i, 2, n) {
    int u = rdint(), v = rdint();
    addedge(u, v);
    addedge(v, u);
  }
  dfs(1, -1);
  printf("%lld
", ans);
}
// Entry point. Declared `signed` rather than `int` because `int` is
// macro-expanded to `long long` in this file, and main must return int.
signed main() {
  // For local debugging, redirect stdin from a file:
  // freopen("input.txt", "r", stdin);
  work();
  // Falling off the end of main returns 0.
}
原文地址:https://www.cnblogs.com/Lonely-233/p/13659199.html