hdu4812 D Tree

题意:

给定一棵树,每个点有一个点权。一条路径的权值定义为路径上所有点权的累乘模 $10^6+3$。求一条权值恰好为 $k$ 的路径,输出它的两个端点(编号较小的在前);若有多条,输出端点对字典序最小的一条;若不存在,输出 "No solution"。

题解:

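树上路径问题,可以用点分治来做。按照点分治的套路,对经过分治中心的路径,设两段半路径的乘积分别为 $x$ 和 $y$,要求 $x \cdot y \equiv k \pmod{10^6+3}$,所以枚举 $y$ 时需要查找的另一半为 $x \equiv k \cdot y^{-1} \pmod{10^6+3}$(代码中即 `k * inv[y] % mod`)。刚开始用 set 来维护每个乘积对应的最小点编号,结果超时(TLE)了;看了题解发现可以直接用以乘积为下标的数组代替 set,从而优化掉一个 log。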

  // Common headers, compile-time limits and type aliases.
#include <bits/stdc++.h>
using namespace std;

constexpr int N = 100010;     // capacity for per-node arrays (edge arrays use 2 * N)
constexpr int M = 1000010;    // capacity for tables indexed by a residue; exceeds mod
constexpr int mod = 1000003;  // 1e6 + 3, the prime modulus for path products

using ll = long long;
using PII = pair<int, int>;
  // Fast reader: returns the next (optionally negative) decimal integer on stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If EOF is reached before any digit,
// returns 0 instead of looping forever.
inline int read()
{
    int ch = getchar();   // int, not char: EOF must stay distinguishable from byte values
    int x = 0, f = 1;
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
 15 
 16 int n, k, size[N], rt, tot, st[N], mx[N],a[N];
 17 
 18 int ans1, ans2;
 19 int inv[M];
 20 int h[N], e[N * 2], ne[N * 2], idx;
 21 ll tmp[N], top, id[N];
 22 ll mp[M];
 23 
 24 void add(int a, int b)
 25 {
 26     e[idx] = b;
 27     ne[idx] = h[a];
 28     h[a] = idx ++;
 29 }
 30 
 31 void getrt(int u, int fa)
 32 {
 33     mx[u] = 0; size[u] = 1;
 34     for (int i = h[u]; ~i; i = ne[i])
 35     {
 36         int j = e[i];
 37         if(j == fa || st[j]) continue;
 38         getrt(j, u);
 39         size[u] += size[j];
 40         mx[u] = max(mx[u], size[j]);
 41     }
 42     mx[u] = max(mx[u], tot - size[u]);
 43     if(mx[rt] > mx[u])
 44         rt = u;
 45 }
 46 
 47 void dfs(int u, int fa, int t)
 48 {
 49     t = 1ll * t * a[u] % mod;
 50     tmp[++ top] = t; id[top] = u;
 51    for (int i = h[u]; ~i; i = ne[i])
 52     {
 53         int j = e[i];
 54         if(j == fa || st[j]) continue;
 55         dfs(j, u, t);
 56     }
 57 }
 58 
 59 void up(int x, int y)
 60 {
 61     x = 1ll * inv[x] * k % mod;
 62     x = mp[x];
 63     if(x == 0) return ;
 64     if(x > y) swap(x, y);
 65     if(ans1 > x)
 66     {
 67         ans1 = x;
 68         ans2 = y;
 69     }
 70     else if(ans1 == x && y < ans2)
 71     {
 72         ans2 = y;
 73     }
 74 }
 75 
 76 void calc(int u, int fa)
 77 {
 78     mp[a[u]] = u;
 79     for (int i = h[u]; ~i; i = ne[i])
 80     {
 81         int j = e[i];
 82         if(j == fa || st[j]) continue;
 83         top = 0;
 84         dfs(j, u, 1);
 85         for (int z = 1; z <= top; z ++)
 86             up(tmp[z], id[z]);
 87         top = 0;
 88         dfs(j, u, a[u]);
 89         for (int z = 1; z <= top; z ++)
 90         {
 91             int now = mp[tmp[z]];
 92             if(!now || id[z] < now)
 93                 mp[tmp[z]] = id[z];
 94         }
 95     }
 96     mp[a[u]] = 0;
 97     for (int i = h[u]; ~i; i = ne[i])
 98     {
 99         int j = e[i];
100         if(j == fa || st[j]) continue;
101         top = 0;
102         dfs(j, u, a[u]);
103         for (int z = 1; z <= top; z ++)
104             mp[tmp[z]] = 0;
105     }
106 }
107 
108 void solve(int u)
109 {
110     st[u] = 1;
111     calc(u, 0);
112     for (int i = h[u]; ~i; i = ne[i])
113     {
114         int j = e[i];
115         if(st[j]) continue;
116         rt = 0;
117         tot = size[j];
118         getrt(j, 0);
119         solve(rt);
120     }
121 }
122 
123 int main()
124 {
125    inv[1]=1;
126    for(int i=2;i< M;i++)
127            inv[i]= 1ll * (mod-mod/i)*inv[mod%i]%mod;
128     while(scanf("%d%d", &n, &k) != EOF)
129     {
130         ans1 = 1e9; ans2 = 1e9;
131         idx = 0;
132         for (int i = 1; i <= n; i ++)  st[i] = 0, h[i] = -1;
133         for (int i = 1; i <= n; i ++)
134             a[i] = read();
135         for (int i = 1; i <= n - 1; i ++)
136         {
137             int a = read(), b = read();
138             add(a, b); add(b, a);
139         }
140         rt = 0; tot = mx[rt] = n;
141         getrt(1, 0);
142         solve(rt);
143         if(ans1 == 1e9)
144         {
145             puts("No solution");
146         }
147         else
148         printf("%d %d
", ans1, ans2);
149     }
150 }
View Code
原文地址:https://www.cnblogs.com/xwdzuishuai/p/14093711.html