[牛客] [#1108 J] [树形结构] 买一送一

2019牛客国庆集训派对day3

链接:https://ac.nowcoder.com/acm/contest/1108/J
来源:牛客网

题意

ICPCCamp 有 n 个商店,用 $1,2,...,n$ 编号。对于任意 i > 1,有从商店 $p_i$ 到 i 的单向道路。
同时,商店 i 出售类型为 $a_i$ 的商品。
Bobo 从商店 1 出发前往商店 i。他要在两个不同的商店购买商品(包括商店 1 和 i)。设他先购买的商品类型是 x,后购买的商品类型是 y,他用 $f_i$ 表示不同的有序对 $langle x, y angle$ 的数量。
求出 $f_2, f_3, dots, f_n$ 的值。

思路

由于每个点只有一个入度,所以可以转化为根为1的树形结构。

那么问题就转化为,从根$1$ 到 $i$ 的路径上不同有序对的个数。

// #pragma GCC optimize(2)
// #pragma GCC optimize(3)
// #pragma GCC optimize(4)
#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <cstdlib>
#include   <iomanip>
#include    <bitset>
#include    <cctype>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <stack>
#include     <cmath>
#include     <queue>
#include      <list>
#include       <map>
#include       <set>
#include   <cassert>
//#include <unordered_set>
//#include <unordered_map>
// #include<bits/extc++.h>
// using namespace __gnu_pbds;
using namespace std;
#define pb push_back
#define fi first
#define se second
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
#define FOR(a, b, c) for(int a = b; a <= c; ++ a)

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9+7;


template<typename T>
inline T read(T&x){
    x=0;int f=0;char ch=getchar();
    while (ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar();
    while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x=f?-x:x;
}

/**********showtime************/
            const int maxn = 1e5+9;
            int p[maxn], a[maxn];
            ll f[maxn];
            int sum,kind[maxn],cnt[maxn];
            vector<int>mp[maxn];
            
            ll add(int col) {
                ll res = 0;
                res = sum - kind[col];
                if(cnt[col] == 1) res ++;
                else if(cnt[col] == 0) sum ++;
                cnt[col] ++;
                kind[col] = sum;
                return res;
            }
            void del(int col) {
                cnt[col] --;
                if(cnt[col] == 0) sum--;
            }
            
            void dfs(int u, int fa, int dp) {
                f[u] = f[fa];
                int yk = kind[a[u]];
                f[u] += add(a[u]);

                for(int v : mp[u]) {
                    dfs(v, u, dp+1);
                }

                del(a[u]);

                kind[a[u]] = yk;
            }
int main(){
            int n;
            while(~scanf("%d", &n)) {

                for(int i=1; i<=n; i++) mp[i].clear(), f[i] = 0, kind[i] = 0;
                sum = 0;

                for(int i=2; i<=n; i++) {
                    scanf("%d", &p[i]);
                    mp[p[i]].pb(i);
                }
                for(int i=1; i<=n; i++) scanf("%d", &a[i]);

                dfs(1, 1, 1);

                for(int i=2; i<=n; i++) printf("%lld
", f[i]);
            }
            return 0;
}
View Code
原文地址:https://www.cnblogs.com/ckxkexing/p/11621858.html