D Tree HDU

https://vjudge.net/problem/HDU-4812

点分就没一道不卡常的?

卡常记录:

1.求逆元忘开longlong

2.把solve中分离各个子树的方法,由“一开始全部加入,处理某个子树前先删除该子树”,变为“逐渐加入,每一次加入某个子树之前处理该子树,不用删除“(由于点对是无序的,只要一个方向处理过就行了,另一个方向不需要处理)

  1 //%:pragma GCC optimize(3)
  2 #include<cstdio>
  3 #include<algorithm>
  4 #include<cstring>
  5 #include<vector>
  6 #include<set>
  7 using namespace std;
  8 #define fi first
  9 #define se second
 10 #define mp make_pair
 11 #define pb push_back
 12 typedef long long ll;
 13 typedef unsigned long long ull;
 14 typedef pair<int,int> pi;
 15 #define md 1000003
 16 struct E
 17 {
 18     int to,nxt;
 19 }e[200100];
 20 int f1[100100],ne;
 21 int sz[100100],fx[100100],a[100100],d[100100];
 22 bool vis[100100];
 23 int root,sum;pi ans;
 24 set<pi> s;
 25 int inv[1001000];
 26 int n,K;
 27 void getroot(int x,int fa)
 28 {
 29     sz[x]=1;fx[x]=0;
 30     for(int k=f1[x];k;k=e[k].nxt)
 31         if(!vis[e[k].to]&&e[k].to!=fa)
 32         {
 33             getroot(e[k].to,x);
 34             sz[x]+=sz[e[k].to];
 35             fx[x]=max(fx[x],sz[e[k].to]);
 36         }
 37     fx[x]=max(fx[x],sum-sz[x]);
 38     if(fx[x]<fx[root])    root=x;
 39 }
 40 void getsz(int x,int fa)
 41 {
 42     sz[x]=1;
 43     for(int k=f1[x];k;k=e[k].nxt)
 44         if(!vis[e[k].to]&&e[k].to!=fa)
 45         {
 46             getsz(e[k].to,x);
 47             sz[x]+=sz[e[k].to];
 48         }
 49 }
 50 void getd(int u,int fa)
 51 {
 52     for(int k=f1[u];k;k=e[k].nxt)
 53         if(!vis[e[k].to]&&e[k].to!=fa)
 54         {
 55             d[e[k].to]=ll(d[u])*a[e[k].to]%md;
 56             getd(e[k].to,u);
 57         }
 58 }
 59 void addd(int u,int fa)
 60 {
 61     s.insert(mp(d[u],u));
 62     for(int k=f1[u];k;k=e[k].nxt)
 63         if(!vis[e[k].to]&&e[k].to!=fa)
 64             addd(e[k].to,u);
 65 }
 66 void deld(int u,int fa)
 67 {
 68     s.erase(mp(d[u],u));
 69     for(int k=f1[u];k;k=e[k].nxt)
 70         if(!vis[e[k].to]&&e[k].to!=fa)
 71             deld(e[k].to,u);
 72 }
 73 void calc(int u,int fa)
 74 {
 75     int t=ll(K)*inv[ll(d[u])*a[root]%md]%md;
 76     auto it=s.lower_bound(mp(t,0));
 77     if(it!=s.end()&&it->fi==t)
 78     {
 79         int a=u,b=it->se;
 80         if(a>b)    swap(a,b);
 81         ans=min(ans,mp(a,b));
 82     }
 83     for(int k=f1[u];k;k=e[k].nxt)
 84         if(!vis[e[k].to]&&e[k].to!=fa)
 85             calc(e[k].to,u);
 86 }
 87 void solve(int u)
 88 {
 89     s.clear();d[u]=1;s.insert(mp(1,u));vis[u]=1;
 90     for(int k=f1[u];k;k=e[k].nxt)
 91         if(!vis[e[k].to])
 92         {
 93             d[e[k].to]=a[e[k].to];getd(e[k].to,u);
 94             calc(e[k].to,u);
 95             addd(e[k].to,u);
 96         }
 97     for(int k=f1[u];k;k=e[k].nxt)
 98         if(!vis[e[k].to])
 99         {
100             root=0;
101             getsz(e[k].to,0);sum=sz[e[k].to];
102             getroot(e[k].to,0);
103             solve(root);
104         }
105 }
106 int main()
107 {
108     int i,x,y;
109     inv[1]=1;
110     for(i=2;i<=1000100;i++)    inv[i]=ll(md-md/i)*inv[md%i]%md;
111     fx[0]=0x3f3f3f3f;
112     while(scanf("%d%d",&n,&K)==2)
113     {
114         for(i=1;i<=n;i++)    f1[i]=vis[i]=0;
115         ne=0;ans=mp(n+1,n+1);
116         for(i=1;i<=n;i++)    scanf("%d",&a[i]);
117         for(i=1;i<n;i++)
118         {
119             scanf("%d%d",&x,&y);
120             e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
121             e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
122         }
123         root=0;sum=n;getroot(1,0);
124         solve(root);
125         if(ans==mp(n+1,n+1))    puts("No solution");
126         else    printf("%d %d
",ans.fi,ans.se);
127     }
128     return 0;
129 }
原文地址:https://www.cnblogs.com/hehe54321/p/9285605.html