[hdu5293]Tree chain problem

记一条链两端的lca为该链的lca,f[i]表示所有lca在i子树内的链的最大价值和(为方便递推,假设存在i-i的链价值为0),有递推式$f[i]=max(sum_{son}f[son]+val)$(其中son是链上所有点的儿子且不再链上的节点,val表示该链的价值),时间复杂度为$o(n^{2})$
考虑快速计算f,记$sum[i]=sum_{son}f[son]$(这里的son是i的儿子),那么答案变为$sum_{k}sum[k]-f[k]$(其中k是链上的点),如果把sum[i]-f[i]看成一个整体,那么相当于支持单点修改,链求和,用树剖维护,时间复杂度两个log
实际上还可以优化,先对链差分,变为询问某个点当根(1)的和,然后变为子树修改单点查询,可以用线段树+dfs序来维护,时间复杂度降为一个log

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 100005
  4 #define ll long long
  5 #define L (k<<1)
  6 #define R (L+1)
  7 #define mid (l+r>>1)
  8 struct ji{
  9     int nex,to;
 10 }edge[N<<1];
 11 vector<int>v[N];
 12 int E,n,m,x,y,z,head[N],s[N],sz[N],id[N],xx[N],yy[N],zz[N],f[N][21];
 13 ll d[N],dp[N],sum[N],tr[N<<2],laz[N<<2];
 14 void add(int x,int y){
 15     edge[E].nex=head[x];
 16     edge[E].to=y;
 17     head[x]=E++;
 18 }
 19 int lca(int x,int y){
 20     if (s[x]<s[y])swap(x,y);
 21     for(int i=20;i>=0;i--)
 22         if (s[f[x][i]]>=s[y])x=f[x][i];
 23     if (x==y)return x;
 24     for(int i=20;i>=0;i--)
 25         if (f[x][i]!=f[y][i]){
 26             x=f[x][i];
 27             y=f[y][i];
 28         }
 29     return f[x][0];
 30 }
 31 void dfs(int k,int fa){
 32     sz[k]=1;
 33     id[k]=++x;
 34     s[k]=s[fa]+1;
 35     f[k][0]=fa;
 36     for(int i=1;i<=20;i++)f[k][i]=f[f[k][i-1]][i-1];
 37     for(int i=head[k];i!=-1;i=edge[i].nex)
 38         if (edge[i].to!=fa){
 39             dfs(edge[i].to,k);
 40             sz[k]+=sz[edge[i].to];
 41         }
 42 }
 43 void upd(int k,int l,int r,ll x){
 44     laz[k]+=x;
 45     tr[k]+=x*(r-l+1);
 46 }
 47 void down(int k,int l,int r){
 48     upd(L,l,mid,laz[k]);
 49     upd(R,mid+1,r,laz[k]);
 50     laz[k]=0;
 51 }
 52 void update(int k,int l,int r,int x,int y,ll z){
 53     if ((l>y)||(x>r))return;
 54     if ((x<=l)&&(r<=y)){
 55         upd(k,l,r,z);
 56         return;
 57     }
 58     down(k,l,r);
 59     update(L,l,mid,x,y,z);
 60     update(R,mid+1,r,x,y,z);
 61     tr[k]=tr[L]+tr[R];
 62 }
 63 ll query(int k,int l,int r,int x){
 64     if (l==r)return tr[k];
 65     down(k,l,r);
 66     if (x<=mid)return query(L,l,mid,x);
 67     return query(R,mid+1,r,x);
 68 }
 69 void dfs2(int k,int fa){
 70     for(int i=head[k];i!=-1;i=edge[i].nex)
 71         if (edge[i].to!=fa){
 72             dfs2(edge[i].to,k);
 73             sum[k]+=dp[edge[i].to];
 74         }
 75     for(int i=0;i<v[k].size();i++)
 76         dp[k]=max(dp[k],d[v[k][i]]+query(1,1,n,id[xx[v[k][i]]])+query(1,1,n,id[yy[v[k][i]]]));
 77     update(1,1,n,id[k],id[k]+sz[k]-1,-dp[k]);
 78     dp[k]+=sum[k];
 79 }
 80 int main(){
 81     scanf("%*d");
 82     while (scanf("%d%d",&n,&m)!=EOF){
 83         E=0;
 84         memset(tr,0,sizeof(tr));
 85         memset(laz,0,sizeof(laz));
 86         memset(dp,0,sizeof(dp));
 87         memset(sum,0,sizeof(sum));
 88         memset(head,-1,sizeof(head));
 89         for(int i=1;i<=n;i++)v[i].clear();
 90         for(int i=1;i<n;i++){
 91             scanf("%d%d",&x,&y);
 92             add(x,y);
 93             add(y,x);
 94         }
 95         x=0;
 96         dfs(1,1);
 97         for(int i=1;i<=m;i++){
 98             scanf("%d%d%lld",&xx[i],&yy[i],&d[i]);
 99             zz[i]=lca(xx[i],yy[i]);
100             v[zz[i]].push_back(i);
101         }
102         dfs2(1,0);
103         printf("%lld
",dp[1]);
104     }
105 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/11803303.html