HDU-5977

HDU - 5977

题意:

  给定一棵树,每个节点上有一种苹果(共 K 种)。问树上有多少个有序节点对 (u, v),使得 u 到 v 的路径上包含了全部 K 种苹果。

思路:

  点分治。对于每个节点,用一个二进制状态记录从重心到该节点的路径上出现过的苹果种类(状态压缩,因为 K<=10)。然后处理与每个重心相连的各点的状态:直接枚举每个点,找出该点状态的每个子集,累加与该子集互补的状态的个数。

  枚举一个数的子集,例如 1010,它的子集包括 1010、1000、0010、0000。这里有个技巧(注意下面的循环在 s 变为 0 时结束,因此不会枚举到空集 0000,空集需要单独处理):

    for(int s = x; s; s = (s - 1) & x){
              res += 1ll*cnt[((1<<k)-1) ^ s];
       }
//#pragma GCC optimize(3)
//#pragma comment(linker, "/STACK:102400000,102400000")  //c++
// #pragma GCC diagnostic error "-std=c++11"
// #pragma comment(linker, "/stack:200000000")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")

#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <cstdlib>
#include   <iomanip>
#include    <bitset>
#include    <cctype>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <stack>
#include     <cmath>
#include     <queue>
#include      <list>
#include       <map>
#include       <set>
#include   <cassert>

using namespace std;

// Argument-list shorthands; they expand to the identifiers l, mid, r and rt,
// which must be defined wherever the macro is used.
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
// Writes "<expression text> = <value>" followed by a newline to cerr.
// The trailing semicolon is part of the macro.
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue

typedef long long ll;
typedef unsigned long long ull;
//typedef __int128 bll;
typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;
typedef pair<int,pii> p3;

//priority_queue<int> q;                            // this is a max-heap q
//priority_queue<int,vector<int>,greater<int> >q;   // this is a min-heap q
#define fi first
#define se second
//#define endl '\n'

// Disables C stdio synchronisation and unties cin from cout (no trailing semicolon).
#define OKC ios::sync_with_stdio(false);cin.tie(0)
#define FT(A,B,C) for(int A=B;A <= C;++A)  // inclusive loop, used to shorten lines
#define REP(i , j , k)  for(int i = j ; i <  k ; ++i)
// Expression macros: arguments are parenthesised and there is no trailing
// semicolon, so they can be used inside larger expressions.
#define max3(a,b,c) max(max((a),(b)), (c))
#define min3(a,b,c) min(min((a),(b)), (c))
//priority_queue<int ,vector<int>, greater<int> >que;

const ll mos = 0x7FFFFFFF;          // 2147483647
const ll nmos = -0x80000000LL;      // -2147483648 (negated as long long so the sign is kept)
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f; // about 4.56e18
const int mod = 1e9+7;
const double esp = 1e-8;
const double PI=acos(-1.0);
const double PHI=0.61803399;        // golden-section ratio
const double tPHI=0.38196601;       // 1 - PHI


// Reads one decimal integer from stdin into x and returns the stored value.
//
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the result is negated. Reading stops at the first
// non-digit after the number (that character is consumed).
//
// If end of input is reached before any digit is found, x is set to 0 and 0 is
// returned instead of looping forever. getchar() returns int precisely so that
// EOF can be told apart from every valid character, hence `ch` is an int.
template<typename T>
inline T read(T&x){
    x = 0;
    bool negative = false;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
    return x;
}


/*-----------------------showtime----------------------*/


            // Capacity of the per-node arrays below (nodes are indexed from 1).
            const int maxn = 50009;

            // a[i]   : kind of node i; main() stores it zero-based and route() uses it as a bit index.
            int a[maxn];
            // g[i]   : non-zero while node i is an active divide() root; the traversals skip such nodes.
            int g[maxn];
            // dp[i]  : subtree size of node i as filled in by dfs().
            int dp[maxn];
            // cnt[m] : scratch counter used by cal(), indexed by bitmask m.
            int cnt[maxn];
            // mp[i]  : adjacency list of node i.
            vector<int> mp[maxn];

            // n: number of nodes, k: number of kinds (both read in main()).
            int n, k;
            // Running answer accumulated by divide().
            ll ans = 0;

            void dfs(int u,int fa){
                dp[u] = 1;
                for(int i=0; i<mp[u].size(); i++){
                    int v = mp[u][i];
                    if(g[v] || fa == v)continue;
                    dfs(v, u);
                    dp[u] += dp[v];
                }
            }
            // Searches the component containing u (bounded by flagged nodes) for the node
            // whose removal leaves the smallest largest piece.
            //   u  : node being visited
            //   fa : node we arrived from (-1 for the start node)
            //   sz : total size of the component, i.e. dp[] of the start node
            // Returns (largest remaining piece, node); ties go to the smaller pair.
            // Relies on dp[] having been filled by dfs() from the same start node.
            pii findg(int u,int fa, int sz){
                int heaviest = 0;
                pii best(inf, u);

                const vector<int>& adj = mp[u];
                for (size_t idx = 0; idx < adj.size(); ++idx) {
                    const int v = adj[idx];
                    if (v == fa || g[v]) continue;
                    const pii sub = findg(v, u, sz);
                    if (sub < best) best = sub;
                    if (dp[v] > heaviest) heaviest = dp[v];
                }
                // The piece on the parent's side once u is removed.
                if (sz - dp[u] > heaviest) heaviest = sz - dp[u];

                const pii here(heaviest, u);
                return here < best ? here : best;
            }

            void route(int u, int fa, vector<int>& ve, int sta){
                    sta = ((1<<a[u]) | sta);
                    ve.pb(sta);
                    for(int i=0; i<mp[u].size(); i++){
                        int v = mp[u][i];
                        if(v == fa || g[v])continue;
                        route(v, u, ve, sta);
                    } 
            }

            // Counts ordered pairs (i, j), i != j, of entries in `ve` whose bitwise OR
            // equals the full mask (1<<k)-1.
            //
            // For an entry x, a partner y completes it exactly when y contains every bit
            // missing from x, i.e. y = full ^ s for some subset s of x. The subset loop
            // visits every non-empty s; the empty subset (y == full) is added separately.
            //
            // Uses the global scratch array cnt[]; only indices [0, 1<<k) are touched, so
            // exactly that range is cleared (the previous fixed bound of 2000 left stale
            // counts behind as soon as 1<<k exceeded it). Requires (1<<k) <= maxn and
            // every mask in `ve` to be below 1<<k.
            ll cal(vector<int> &ve){
                const int full = (1 << k) - 1;
                for (int m = 0; m <= full; ++m) cnt[m] = 0;

                for (size_t i = 0; i < ve.size(); ++i) {
                    ++cnt[ve[i]];
                }

                ll res = 0;
                for (size_t i = 0; i < ve.size(); ++i) {
                    const int x = ve[i];
                    --cnt[x];                           // an entry never pairs with itself
                    res += cnt[full];                   // empty subset of x
                    for (int s = x; s; s = (s - 1) & x) {
                        res += cnt[full ^ s];
                    }
                    ++cnt[x];                           // restore for the next entry
                }
                return res;
            }
            void divide(int u){
                dfs(u,-1);
                int rt = findg(u, -1, dp[u]).se;
                g[rt] = 1;

                for(int i=0; i<mp[rt].size(); i++){
                    int v = mp[rt][i];
                    if(g[v])continue;
                    divide(v);
                }

                vector<int>all;
                all.pb((1<<a[rt]));
                for(int i=0; i<mp[rt].size(); i++){
                    vector<int>ve;
                    int v = mp[rt][i];
                    if(g[v])continue;
                    route(v, -1, ve, (1<<a[rt]));
                    ans -= 1ll*cal(ve);
                    all.insert(all.end(),ve.begin(),ve.end());
                }
                ans += 1ll*cal(all);
                g[rt] = 0;
            }
// Reads test cases until the input ends or becomes malformed. Each case is:
// n and k, then n kinds (1-based in the input, stored zero-based in a[]),
// then n-1 edges. Prints one answer per case.
int main(){
    // scanf returns the number of items converted, or EOF; require both values so
    // that a partial or failed read ends the loop instead of spinning.
    while (scanf("%d%d", &n, &k) == 2) {
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &a[i]) != 1) return 0;
            a[i]--;
        }
        for (int i = 1; i <= n; i++) mp[i].clear();
        for (int i = 1; i < n; i++) {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;
            mp[u].pb(v);
            mp[v].pb(u);
        }

        // With a single kind every ordered pair qualifies, including (u, u),
        // which cal() never counts because it excludes self-pairs.
        if (k == 1) {
            ans = 1ll * n * n;
            printf("%lld\n", ans);
            continue;
        }

        // g[] needs no reset here: divide() clears each flag it sets.
        ans = 0;
        divide(1);
        printf("%lld\n", ans);
    }
    return 0;
}
HDU-5977
原文地址:https://www.cnblogs.com/ckxkexing/p/10099650.html