hdu5293 lca+dp+树状数组+时间戳

题意是给了 n 个点的树,会有m条链条 链接两个点,计算出他们没有公共点的最大价值,  公共点时这样计算的只要在他们 lca 这条链上有公共点的就说明他们相交

dp[i]为这个点包含的子树所能得到的最大价值 

sum[i]表示这个点没有选择经过i这个点链条的总价值

两种选择 

这个点没有被选择 

         dp[i]=sum[i]=sigma(dp[k])k为i的子树

选择了某个链 

        假设这条链 为(tyuijk)

       那么dp[i]=(sum[i]-dp[u]-dp[j])+(sum[j]-dp[k])+dp[k] +(sum[u]-dp[y])+(sum[y]-dp[t])+sum[t];

      整理后发现 dp[i]=sum[i] +(sum[j]-dp[j])+(sum[k]-dp[k])+(sum[u]-dp[u])+(sum[y]-dp[y])+(sum[t]-dp[t]);

使用lca计算出每条链的最近公共祖先,在这个最近公共祖先上判断是否使用这条链,还有我们可以使用时间戳加树状数组来求得sum和dp

#include <iostream>
#include <algorithm>
#include <string.h>
#include <cstdio>
#include <vector>
using namespace std;
const int maxn=100000+10;
int to[maxn*2],nx[maxn*2],H[maxn*2],numofedg,timoflook;
int fa[maxn][20],first[maxn],last[maxn],depth[maxn];
void addedg(int u, int v)
{
     numofedg++; to[numofedg]=v; nx[numofedg]=H[u]; H[u]=numofedg;
     numofedg++; to[numofedg]=u; nx[numofedg]=H[v]; H[v]=numofedg;
}
void dfs(int cur, int per, int dep)
{
    first[cur]=++timoflook;
    depth[cur]=dep;
    fa[cur][0]=per;
    for(int i=1; i<20; i++)
    {
        fa[cur][i]=fa[ fa[cur][i-1] ][ i-1 ];
    }
    for(int i=H[cur]; i; i=nx[i])
        {
            if(to[i]==per)continue;
            dfs(to[i],cur,dep+1);
        }
    last[cur]=++timoflook;
}
int getlca(int u,int v)
{
     if(depth[u]<depth[v])swap(u,v);
     for(int i=19; i>=0; i--)
        {
             if(depth[fa[u][i]]>=depth[v])
                u=fa[u][i];
             if(u==v)return u;
        }
     for(int i=19; i>=0; i--)
        {
             if(fa[u][i]!=fa[v][i])
             {
                 u=fa[u][i];
                 v=fa[v][i];
             }
        }
        return fa[u][0];
}
struct Edg
{
  int u,v,lca,val;
}P[maxn];
vector<int>E[maxn];
int dp[maxn],sum[maxn],CS[maxn*3],CD[maxn*3];
int lowbit(int x)
{
     return x&-x;
}
void add(int x, int d, int *C)
{
      while(x<=timoflook)
        {
             C[x]+=d;
             x+=lowbit(x);
        }
}
int getsum(int x, int *C)
{
     int ret=0;
      while(x>0)
        {
            ret+=C[x];
            x-=lowbit(x);
        }
        return ret;
}
void solve(int cur, int per)
{
     dp[cur]=sum[cur]=0;
     for(int i=H[cur]; i; i=nx[i])
        {
            if(to[i]==per)continue;
            solve(to[i],cur);
            sum[cur]+=dp[to[i]];
        }
     dp[cur]=sum[cur];
     for(int i=0; i<E[cur].size(); i++)
        {
              int id=E[cur][i];
              int u=P[id].u;
              int v=P[id].v;
              int t1=getsum(first[u],CS);
              int t2=getsum(first[v],CS);
              int t3=getsum(first[u],CD);
              int t4=getsum(first[v],CD);
              int tmp=t1+t2-t3-t4;
              dp[cur]=max(dp[cur],tmp+P[id].val+sum[cur]);
        }
     add(first[cur],sum[cur],CS);
     add(last[cur],-sum[cur],CS);
     add(first[cur],dp[cur],CD);
     add(last[cur],-dp[cur],CD);

}
int main()
{
    int cas;
    scanf("%d",&cas);
    for(int cc=1; cc<=cas; cc++)
        {
              int n,m;
              timoflook=numofedg=0;
              scanf("%d%d",&n,&m);
              for(int i=0; i<=n; i++)
                {
                    CS[i*2]=CS[i*2+1]=CD[i*2]=CD[i*2+1]=0;
                    H[i]=0;E[i].clear();

                }

              for(int i=1; i<n; i++)
                {
                    int u,v;
                    scanf("%d%d",&u,&v);
                    addedg(u,v);
                }
                fa[1][0]=1;
                dfs(1,1,0);
                for(int i=0; i<m; i++)
                    {
                           scanf("%d%d%d",&P[i].u,&P[i].v,&P[i].val);
                           P[i].lca=getlca(P[i].u,P[i].v);
                           E[P[i].lca].push_back(i);
                    }
                solve(1,-1);
                printf("%d
",dp[1]);
        }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/Opaser/p/4788669.html