// hihocoder 1455 Rikka with Tree III (bitset + Mo's algorithm + DFS order / bitset 莫队 dfs序)

//http://www.cnblogs.com/IMGavin/
//http://hihocoder.com/problemset/problem/1455
//https://media.hihocoder.com/contests/challenge25/solution.pdf
//bitset + Mo's algorithm (莫队) + DFS order (dfs序)

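/*
in1[] is the bitset of values present in the subtree currently being processed; in2[] is in1 mirrored (bit n + 1 - x stands for value x).
out1[] is the bitset of values present outside that subtree; out2[] is out1 mirrored the same way.
*/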
#include <iostream>
#include <stdio.h>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <vector>
#include <map>
#include <stack>
#include <set>
#include <bitset>
#include <cmath>
#include <algorithm>

using namespace std;

typedef long long LL;  // not referenced in the visible code

// Capacity of every node-, timestamp- and value-indexed array below.
const int N = 51200;
// Not referenced in the visible code; kept from the author's template.
const int INF = 0x3F3F3F3F;
const int MOD = 1003;

int n;  // number of tree nodes (read in main)

int val[N];  // val[u]: value written on node u
int ic[N];   // ic[x]: nodes with value x inside the current Mo window
int oc[N];   // oc[x]: nodes with value x outside the current Mo window

// Value-indexed presence sets maintained by add()/remove(); in2/out2 hold
// in1/out1 mirrored (bit n + 1 - x). ans gathers the differences found by solve().
bitset<N> ans;
bitset<N> in1;
bitset<N> in2;
bitset<N> out1;
bitset<N> out2;

int head[N];  // head[u]: first edge index of u's adjacency list, -1 if none

int lf[N];        // lf[u]: DFS entry timestamp of u
int ri[N];        // ri[u]: largest timestamp inside u's subtree
int seq2node[N];  // seq2node[t]: node whose entry timestamp is t

int tot;  // number of directed edges stored in edge[]
int lab;  // last DFS timestamp handed out

int fa[N];  // fa[u]: parent of u in the DFS tree, -1 for the root

// Adjacency-list edge record. head[u] holds u's first edge index and `next`
// chains to the following one (-1 ends the list). main() stores each
// undirected tree edge twice, hence N * 2 slots.
struct Edge {
    int to;    // endpoint this edge leads to
    int next;  // next edge index with the same source, or -1
};
Edge edge[N * 2];

void init(){
    memset(head, -1, sizeof(head));
    tot = 0;
}

void add(int st, int to){
    edge[tot].to =to;
    edge[tot].next = head[st];
    head[st]= tot++;
}

/*
 * Pre-order DFS from u, whose parent is f (pass -1 for the root).
 * Fills fa[], lf[] (entry timestamp), ri[] (largest timestamp inside the
 * subtree) and seq2node[] (timestamp -> node), so the subtree of x is exactly
 * the DFS-order interval [lf[x], ri[x]]. Timestamps continue from global lab.
 *
 * Uses an explicit stack instead of recursion: a path-shaped tree with up to
 * N - 1 nodes would otherwise need that many nested call frames. Children are
 * still visited in adjacency-list order, so every output array is identical
 * to the recursive version's.
 */
void dfs(int u, int f){
    // Each frame: (node, next adjacency-list edge index still to examine).
    vector< pair<int, int> > stk;
    fa[u] = f;
    lf[u] = ++lab;
    seq2node[lab] = u;
    stk.push_back(make_pair(u, head[u]));
    while(!stk.empty()){
        int x = stk.back().first;
        int e = stk.back().second;
        if(e == -1){
            // All neighbours handled: close x's subtree interval.
            ri[x] = lab;
            stk.pop_back();
            continue;
        }
        // Advance this frame before a push can reallocate the vector.
        stk.back().second = edge[e].next;
        int v = edge[e].to;
        if(v != fa[x]){
            fa[v] = x;
            lf[v] = ++lab;
            seq2node[lab] = v;
            stk.push_back(make_pair(v, head[v]));
        }
    }
}

// One Mo's-algorithm query: the DFS-order interval [l, r] that covers the
// subtree of node u.
struct node {
    int l;  // lf[u]
    int r;  // ri[u]
    int u;  // node whose subtree the interval represents
};
node q[N];
// Block length used to bucket left endpoints (assigned in solve()).
int blk;
// Mo's ordering: by block of the left endpoint, then by right endpoint.
bool cmp(const node &a, const node &b){
    const int blockA = a.l / blk;
    const int blockB = b.l / blk;
    if(blockA == blockB){
        return a.r < b.r;
    }
    return blockA < blockB;
}

/*
 * Mo's step: DFS position x enters the current window (the subtree being
 * processed). Its value w moves from the "outside" counters to the "inside"
 * ones. The presence bits change only on count transitions 0 -> 1 (inside)
 * and 1 -> 0 (outside). in2/out2 mirror the sets at bit n + 1 - w.
 * NOTE: presumably 1 <= w <= n, so that n + 1 - w is a valid bit index.
 */
inline void add(int x){
    const int w = val[seq2node[x]];
    const int mirrored = n + 1 - w;
    if(++ic[w] == 1){
        in1[w] = 1;
        in2[mirrored] = 1;
    }
    if(--oc[w] == 0){
        out1[w] = 0;
        out2[mirrored] = 0;
    }
}

/*
 * Mo's step: DFS position x leaves the current window; the inverse of
 * add(int). Its value w moves from the "inside" counters back to the
 * "outside" ones. The presence bits change only on count transitions
 * 1 -> 0 (inside) and 0 -> 1 (outside), mirrored at bit n + 1 - w.
 * NOTE: presumably 1 <= w <= n, so that n + 1 - w is a valid bit index.
 */
inline void remove(int x){
    const int w = val[seq2node[x]];
    const int mirrored = n + 1 - w;
    if(--ic[w] == 0){
        in1[w] = 0;
        in2[mirrored] = 0;
    }
    if(++oc[w] == 1){
        out1[w] = 1;
        out2[mirrored] = 1;
    }
}

void solve(){
    blk = (int)sqrt(n + 0.5);
    sort(q + 1, q + 1 + n, cmp);
    for(int i = 1; i <= n; i++){
        oc[val[i]]++;
        if(oc[val[i]] == 1){
            out1[val[i]] = 1;
            out2[n + 1 - val[i]] = 1;
        }
    }

    int l = 1, r = 0;
    for(int i = 1; i <= n; i++){

        while(l < q[i].l){
            remove(l);
            l++;
        }
        while(l > q[i].l){
            l--;
            add(l);
        }

        while(r < q[i].r){
            r++;
            add(r);
        }
        while(r > q[i].r){
            remove(r);
            r--;
        }

        if(fa[q[i].u] != -1){
            int v = val[fa[q[i].u]];
            ans |= ((in1 >> v) & (out2>>(n + 1 - v))) | ((out1>>v) & (in2>>(n + 1 - v)));
        }
    }
}

int main(){
    cin >> n;
    for(int i = 1; i <= n; i++){
        scanf("%d", &val[i]);
    }
    init();
    for(int i = 2; i <= n; i++){
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v);
        add(v, u);
    }
    lab = 0;
    dfs(1, -1);
    for(int i = 1; i <= n; i++){
        q[i].u = i;
        q[i].l = lf[i];
        q[i].r = ri[i];
    }
    solve();
    int cnt = 0;
    for(int i = 1; i <= n; i++){
        cnt += ans[i];
    }
    cout<<cnt<<endl;

    return 0;
}

  

// Original article (原文地址): https://www.cnblogs.com/IMGavin/p/6279580.html