// http://www.cnblogs.com/IMGavin/
// http://hihocoder.com/problemset/problem/1455
// https://media.hihocoder.com/contests/challenge25/solution.pdf
//
// Technique: bitset + Mo's algorithm over the DFS order (each subtree of u is the
// contiguous label interval [lf[u], ri[u]] of the DFS sequence).
//
// Bitsets indexed by node value x:
//   in1  - bit x set iff value x occurs inside the subtree currently in the Mo window
//   in2  - in1 mirrored: bit (n + 1 - x) set iff bit x of in1 is set
//   out1 - bit x set iff value x occurs outside that subtree
//   out2 - out1 mirrored in the same way
//
// NOTE(review): assumes node values lie in [1, n] and n < N -- arrays and bitsets
// are indexed by value / DFS label without bounds checks; input is not validated.
#include <iostream>
#include <stdio.h>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <vector>
#include <map>
#include <stack>
#include <set>
#include <bitset>
#include <cmath>
#include <algorithm>

const int N = 51200;

int n;
int val[N];        // value of each node (1-based)
int ic[N], oc[N];  // ic[x] / oc[x]: occurrences of value x inside / outside the window
std::bitset<N> ans, in1, in2, out1, out2;

// Adjacency lists stored as a linked edge array:
// head[u] -> first edge of u, edge[e].next -> next edge of the same node, -1 ends.
int head[N];
int tot;
struct Edge {
    int to;
    int next;
} edge[N * 2];

int lf[N], ri[N];  // subtree of u occupies DFS labels [lf[u], ri[u]]
int seq2node[N];   // DFS label -> node
int lab;           // last DFS label assigned
int fa[N];         // parent of each node, -1 for the root

// Clears the adjacency lists.
void init() {
    std::memset(head, -1, sizeof(head));
    tot = 0;
}

// Inserts the directed edge st -> to.
void add_edge(int st, int to) {
    edge[tot].to = to;
    edge[tot].next = head[st];
    head[st] = tot++;
}

// Assigns DFS entry labels (lf), subtree-end labels (ri), label->node mapping
// (seq2node) and parents (fa) for the tree rooted at `root`, recording `parent`
// as the root's parent.
// Iterative on purpose: a path-shaped tree with ~N nodes would otherwise recurse
// ~N levels deep. Children are visited in adjacency-list order with the parent
// edge skipped, i.e. exactly the labelling the recursive version produced.
void dfs(int root, int parent) {
    static int stk[N];        // nodes on the current root-to-node path
    static int next_edge[N];  // next adjacency edge still to scan, per node on the path
    int top = 0;

    fa[root] = parent;
    lf[root] = ++lab;
    seq2node[lab] = root;
    next_edge[root] = head[root];
    stk[top++] = root;

    while (top > 0) {
        int u = stk[top - 1];
        int &e = next_edge[u];
        while (e != -1 && edge[e].to == fa[u]) {
            e = edge[e].next;  // never walk back up to the parent
        }
        if (e == -1) {         // every child finished: close u's label interval
            ri[u] = lab;
            --top;
            continue;
        }
        int v = edge[e].to;
        e = edge[e].next;      // resume after this edge when we return to u
        fa[v] = u;
        lf[v] = ++lab;
        seq2node[lab] = v;
        next_edge[v] = head[v];
        stk[top++] = v;
    }
}

// One Mo query per node u: the DFS-label interval [l, r] covering u's subtree.
struct Query {
    int l, r, u;
} q[N];

int blk;  // Mo block size

// Mo ordering: by block of the left end, then by right end.
bool cmp(const Query &a, const Query &b) {
    if (a.l / blk != b.l / blk) {
        return a.l / blk < b.l / blk;
    }
    return a.r < b.r;
}

// Moves the node at DFS label `pos` from "outside" into the window.
inline void mo_add(int pos) {
    int x = val[seq2node[pos]];
    ++ic[x];
    --oc[x];
    if (ic[x] == 1) {
        in1[x] = 1;
        in2[n + 1 - x] = 1;
    }
    if (oc[x] == 0) {
        out1[x] = 0;
        out2[n + 1 - x] = 0;
    }
}

// Moves the node at DFS label `pos` from the window back "outside".
inline void mo_remove(int pos) {
    int x = val[seq2node[pos]];
    --ic[x];
    ++oc[x];
    if (ic[x] == 0) {
        in1[x] = 0;
        in2[n + 1 - x] = 0;
    }
    if (oc[x] == 1) {
        out1[x] = 1;
        out2[n + 1 - x] = 1;
    }
}

// Runs Mo's algorithm over all subtree queries and accumulates into `ans`.
void solve() {
    blk = (int)std::sqrt(n + 0.5);
    std::sort(q + 1, q + 1 + n, cmp);

    // The window starts empty, so every value starts out "outside".
    for (int i = 1; i <= n; i++) {
        int x = val[i];
        if (++oc[x] == 1) {
            out1[x] = 1;
            out2[n + 1 - x] = 1;
        }
    }

    int l = 1, r = 0;
    for (int i = 1; i <= n; i++) {
        while (l < q[i].l) mo_remove(l++);
        while (l > q[i].l) mo_add(--l);
        while (r < q[i].r) mo_add(++r);
        while (r > q[i].r) mo_remove(r--);

        int p = fa[q[i].u];
        if (p != -1) {
            int v = val[p];
            // (in1 >> v)[d]            == in1[v + d]
            // (out2 >> (n + 1 - v))[d] == out2[n + 1 - v + d] == out1[v - d]
            // so bit d is set when v + d occurs inside u's subtree and v - d
            // occurs outside it; the second term is the mirrored case.
            ans |= ((in1 >> v) & (out2 >> (n + 1 - v))) |
                   ((out1 >> v) & (in2 >> (n + 1 - v)));
        }
    }
}

int main() {
    std::cin >> n;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
    }
    init();
    for (int i = 2; i <= n; i++) {  // n - 1 undirected edges
        int u, v;
        scanf("%d %d", &u, &v);
        add_edge(u, v);
        add_edge(v, u);
    }
    lab = 0;
    dfs(1, -1);
    for (int i = 1; i <= n; i++) {
        q[i].u = i;
        q[i].l = lf[i];
        q[i].r = ri[i];
    }
    solve();
    int cnt = 0;
    for (int i = 1; i <= n; i++) {  // bit 0 (d == 0) is deliberately not counted
        cnt += ans[i];
    }
    std::cout << cnt << '\n';
    return 0;
}