【HDOJ5956】The Elder(树形DP,斜率优化)

题意:有一棵n个点的有根树,每条边上有一个边权。给定P,从i跳到它的祖先j的费用是距离的平方+P,问所有点中到根节点1的总花费最大值

n<=1e5,p<=1e6,w<=1e2

思路:对于根节点到每个点i的路径上是一个下凸壳,是经典的斜率优化

考虑在dfs时维护这个下凸壳,在斜率优化加入与删除点时记录下时间戳和操作的类型,dfs结束时恢复即可

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 #include<algorithm>
  5 #include<cmath>
  6 typedef long long ll;
  7 using namespace std;
  8 #define N   210000
  9 #define oo  10000000
 10 #define MOD 1000000007
 11 
 12 struct node
 13 {
 14     int t,x,y;
 15 }stk[N];
 16 
 17 ll dp[N],s[N],P;
 18 int dep[N],head[N],vet[N],nxt[N],len[N],q[N],flag[N],n,top,tot,tim,t,w;
 19 
 20 int add(int a,int b,int c)
 21 {
 22     nxt[++tot]=head[a];
 23     vet[tot]=b;
 24     len[tot]=c;
 25     head[a]=tot;
 26 }
 27 
 28 ll sqr(ll x)
 29 {
 30     return x*x;
 31 }
 32 
 33 ll calc(int i,int j)
 34 {
 35     return dp[j]+sqr(s[i]-s[j])+P;
 36 }
 37 
 38 int cmp(int x,int y,int z)
 39 {
 40     ll x1=dp[x]-dp[y]+sqr(s[x])-sqr(s[y]);
 41     ll y1=s[x]-s[y];
 42     ll x2=dp[y]-dp[z]+sqr(s[y])-sqr(s[z]);
 43     ll y2=s[y]-s[z];
 44     return x1*y2>=x2*y1;
 45 }
 46 
 47 void dfs(int u)
 48 {
 49     tim++;
 50     flag[u]=1;
 51     if(u==1) 
 52     {
 53         t=1; w=1; dp[u]=-P; q[1]=1;
 54     }
 55      else
 56      {
 57          while(t<w&&calc(u,q[t])>=calc(u,q[t+1]))
 58         {
 59             stk[++top].t=tim; stk[top].x=1; stk[top].y=q[t];
 60             t++;
 61         }
 62         dp[u]=calc(u,q[t]);
 63         while(t<w&&cmp(q[w-1],q[w],u)) 
 64         {
 65             stk[++top].t=tim; stk[top].x=2; stk[top].y=q[w];
 66             w--;
 67         }
 68         q[++w]=u;
 69         stk[++top].t=tim; stk[top].x=3; 
 70     }
 71             
 72     int tmp=tim;
 73     int e=head[u];
 74     while(e)
 75     {
 76         int v=vet[e];
 77         if(!flag[v])
 78         {
 79             s[v]=s[u]+len[e];
 80             dfs(v);
 81         }
 82         e=nxt[e];
 83     }
 84     while(stk[top].t==tmp)
 85     {
 86         if(stk[top].x==1) q[--t]=stk[top].y;
 87         if(stk[top].x==2) q[++w]=stk[top].y;
 88         if(stk[top].x==3) w--;
 89         top--;
 90     }
 91 }
 92              
 93 int main()
 94 { 
 95     int cas;
 96     scanf("%d",&cas);
 97     while(cas--)
 98     {
 99         int n;
100         scanf("%d%d",&n,&P);
101         s[1]=0;
102         tot=0;
103         for(int i=1;i<=n;i++) head[i]=flag[i]=0;
104         for(int i=1;i<=n-1;i++) 
105         {
106             int x,y,z;
107             scanf("%d%d%d",&x,&y,&z);
108             add(x,y,z);
109             add(y,x,z);
110         }
111         tim=0;
112         t=1; w=0; top=0; 
113         dfs(1);
114         ll ans=0;
115         for(int i=2;i<=n;i++) ans=max(ans,dp[i]);
116         printf("%I64d
",ans); 
117     } 
118     return 0;
119 }
120     
原文地址:https://www.cnblogs.com/myx12345/p/9963703.html