hdu4670(树上点分治+状态压缩)

树上路径的f(u,v)=路径上所有点的乘积。

树上每个点的权值都是由给定的k个素数组合而成的,如果f(u,v)是立方数,那么就说明f(u,v)是可行的方案。

问有多少种可行的方案。

f(u,v)可是用状态压缩来表示,因为最多只有30个素数, 第i位表示第i个素数的幂,那么每一位的状态只有0,1,2因为3和0是等价的,所以用3进制状态来表示就行了。

其他代码就是裸的树分。

另外要注意的是,因为counts函数没有统计只有一个点的情况,所以需要另外统计。

  1 #pragma warning(disable:4996)
  2 #pragma comment(linker, "/STACK:1024000000,1024000000")
  3 #include <stdio.h>
  4 #include <string.h>
  5 #include <time.h>
  6 #include <math.h>
  7 #include <map>
  8 #include <set>
  9 #include <queue>
 10 #include <stack>
 11 #include <vector>
 12 #include <bitset>
 13 #include <algorithm>
 14 #include <iostream>
 15 #include <string>
 16 #include <functional>
 17 #include <unordered_map>
 18 const int INF = 1 << 30;
 19 typedef __int64 LL;
 20 /*
 21 用三进制的每一位表示第i个素数的幂
 22 如果幂都是0,那么说明是立方
 23 */
 24 const int N = 50000 + 10;
 25 std::vector<int> g[N];
 26 std::unordered_map<LL, int> mp;
 27 struct Node
 28 {
 29     int sta[33];
 30 }node[N];
 31 LL prime[33];
 32 std::vector<Node> dist;
 33 int n, k;
 34 int size[N], vis[N], total, root, mins;
 35 LL _3bit[33];
 36 void init()
 37 {
 38     _3bit[0] = 1;
 39     for (int i = 1;i <= 32;++i)
 40         _3bit[i] = _3bit[i - 1] * 3;
 41 }
 42 void getRoot(int u, int fa)
 43 {
 44     int maxs = 0;
 45     size[u] = 1;
 46     for (int i = 0;i < g[u].size();++i)
 47     {
 48         int v = g[u][i];
 49         if (v == fa || vis[v]) continue;
 50         getRoot(v, u);
 51         size[u] += size[v];
 52         maxs = std::max(maxs, size[v]);
 53     }
 54     maxs = std::max(maxs, total - size[u]);
 55     if (mins > maxs)
 56     {
 57         mins = maxs;
 58         root = u;
 59     }
 60 }
 61 void getDis(int u, int fa, Node d)
 62 {
 63     dist.push_back(d);
 64     for (int i = 0;i < g[u].size();++i)
 65     {
 66         int v = g[u][i];
 67         if (v == fa || vis[v]) continue;
 68         Node tmp;
 69         for (int j = 0;j < k;++j)
 70             tmp.sta[j] = (d.sta[j] + node[v].sta[j]) % 3;
 71         getDis(v, u, tmp);
 72     }
 73 }
 74 LL counts(int u)//计算经过u点的路径
 75 {
 76     mp.clear();
 77     mp[0] = 1;
 78     LL ret = 0;
 79     for (int i = 0;i < g[u].size();++i)
 80     {
 81         int v = g[u][i];
 82         if (vis[v]) continue;
 83         dist.clear();
 84         getDis(v, u, node[v]);
 85         for (int j = 0;j < dist.size();++j)
 86         {
 87             LL sta = 0;
 88             for (int z = 0;z < k;++z)
 89             {
 90                 sta += (3 - (node[u].sta[z] + dist[j].sta[z]) % 3) % 3 * _3bit[z];
 91             }
 92             ret += mp[sta];
 93         }
 94         for (int j = 0;j < dist.size();++j)
 95         {
 96             LL sta = 0;
 97             for (int z = 0;z < k;++z)
 98                 sta += dist[j].sta[z] * _3bit[z];
 99             mp[sta]++;
100         }
101     }
102     return ret;
103 }
104 LL ans;
105 void go(int u)
106 {
107     vis[u] = true;
108     ans += counts(u);
109     for (int i = 0;i < g[u].size(); ++i)
110     {
111         int v = g[u][i];
112         if (vis[v]) continue;
113         total = size[v];
114         mins = INF;
115         getRoot(v, u);
116         go(root);
117     }
118     
119 }
120 int main()
121 {
122     int u, v;
123     LL x;
124     init();
125     while (scanf("%d%d", &n, &k) != EOF)
126     {
127         for (int i = 0;i < k;++i)
128             scanf("%I64d", &prime[i]);
129         ans = 0;
130         for (int i = 1;i <= n;++i)
131         {
132             g[i].clear();
133             vis[i] = 0;
134             scanf("%I64d", &x);
135             memset(node[i].sta, 0, sizeof(node[i].sta));
136             int tmp = 0;
137             for (int j = 0;j <k;++j)
138             {
139                 
140                 while (x%prime[j] == 0 && x)
141                 {
142                     node[i].sta[j]++;
143                     x /= prime[j];
144                 }
145                 node[i].sta[j] %= 3;
146                 if (node[i].sta[j] != 0)tmp++;
147             }
148             if (tmp == 0)//统计只有一个点的
149                 ans++;
150         }
151         for (int i = 1;i < n;++i)
152         {
153             scanf("%d%d", &u, &v);
154             g[u].push_back(v);
155             g[v].push_back(u);
156         }
157         total = n;
158         mins = INF;
159         getRoot(1, -1);
160         go(root);
161         printf("%I64d
", ans);
162     }
163     return 0;
164 }
View Code
原文地址:https://www.cnblogs.com/justPassBy/p/4770566.html