洛谷P4719 【模板】动态dp

https://www.luogu.org/problemnew/show/P4719

大概就是一条链一条链的处理(“链”在这里指重链),对于每一条链,对于其上每一个点,先算出它自身和所有轻儿子的贡献,当做这一步中这个点的“权值”,然后就变成序列上dp,直接用线段树维护

线段树版本O(n*log^2)

  1 #include<cstdio>
  2 #include<algorithm>
  3 #include<cstring>
  4 using namespace std;
  5 typedef long long ll;
  6 struct E
  7 {
  8     int to,nxt;
  9 }e[200011];
 10 int f1[100011],ne;
 11 struct P1
 12 {
 13     ll d[2][2];//左侧不选/选,右侧不选/选
 14 };
 15 struct P2
 16 {
 17     ll d[2];//自身不选/选
 18 };
 19 ll a[100101];
 20 int sz[100101],hson[100101],ff[100101];
 21 int b[100101],pl[100101];
 22 int n,m;
 23 inline ll max1(ll a,ll b)
 24 {
 25     return a>b?a:b;
 26 }
 27 const ll inf1=-0x3f3f3f3f3f3f3f3f;
 28 #define max max1
 29 #define G(x) max1((x),inf1)
 30 inline void merge(P1 &c,const P1 &a,const P1 &b)
 31 {
 32     c.d[0][0]=G(max(a.d[0][0]+max(b.d[1][0],b.d[0][0]),
 33         a.d[0][1]+b.d[0][0]));
 34     c.d[0][1]=G(max(a.d[0][0]+max(b.d[1][1],b.d[0][1]),
 35         a.d[0][1]+b.d[0][1]));
 36     c.d[1][0]=G(max(a.d[1][0]+max(b.d[1][0],b.d[0][0]),
 37         a.d[1][1]+b.d[0][0]));
 38     c.d[1][1]=G(max(a.d[1][0]+max(b.d[1][1],b.d[0][1]),
 39         a.d[1][1]+b.d[0][1]));
 40 }
 41 inline void initnode(P1 &c,const P2 &a)
 42 {
 43     c.d[0][0]=a.d[0];c.d[1][1]=a.d[1];
 44     c.d[0][1]=c.d[1][0]=inf1;
 45 }
 46 namespace S
 47 {
 48 #define lc (num<<1)
 49 #define rc (num<<1|1)
 50     P1 d[400101];
 51     inline void upd(int num){merge(d[num],d[lc],d[rc]);}
 52     P1 x;int L;
 53     void _setx(int l,int r,int num)
 54     {
 55         if(l==r)
 56         {
 57             d[num]=x;
 58             return;
 59         }
 60         int mid=(l+r)>>1;
 61         if(L<=mid)    _setx(l,mid,lc);
 62         else    _setx(mid+1,r,rc);
 63         upd(num);
 64     }
 65     P1 getx(int L,int R,int l,int r,int num)
 66     {
 67         if(L<=l&&r<=R)    return d[num];
 68         int mid=(l+r)>>1;
 69         if(L<=mid&&mid<R)
 70         {
 71             P1 x;
 72             merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+1,r,rc));
 73             return x;
 74         }
 75         else if(L<=mid)
 76             return getx(L,R,l,mid,lc);
 77         else if(mid<R)
 78             return getx(L,R,mid+1,r,rc);
 79         else
 80             exit(-1);
 81     }
 82 }
 83 void dfs1(int u,int fa)
 84 {
 85     sz[u]=1;
 86     for(int v,k=f1[u];k;k=e[k].nxt)
 87         if(e[k].to!=fa)
 88         {
 89             v=e[k].to;
 90             ff[v]=u;
 91             dfs1(v,u);
 92             sz[u]+=sz[v];
 93             if(sz[v]>sz[hson[u]])    hson[u]=v;
 94         }
 95 }
 96 P2 d1[100101];//d1[i]维护i节点及其轻儿子的贡献
 97 P2 d2[100101];//d2[i]维护i节点(是重链顶)所在重链的dp值
 98 int tp[100101],dwn[100101];//链顶,链底
 99 void dfs2(int u,int fa)
100 {
101     d1[u].d[0]=0;d1[u].d[1]=a[u];
102     b[++b[0]]=u;pl[u]=b[0];
103     tp[u]=(u==hson[fa])?tp[fa]:u;
104     if(hson[u])    dfs2(hson[u],u);
105     dwn[u]=hson[u]?dwn[hson[u]]:u;
106     int v,k;
107     for(k=f1[u];k;k=e[k].nxt)
108         if(e[k].to!=fa&&e[k].to!=hson[u])
109         {
110             v=e[k].to;
111             dfs2(v,u);
112             d1[u].d[0]+=max(d2[v].d[0],d2[v].d[1]);
113             d1[u].d[1]+=d2[v].d[0];
114         }
115     initnode(S::x,d1[u]);S::L=pl[u];S::_setx(1,n,1);
116     if(u==tp[u])
117     {
118         P1 t=S::getx(pl[u],pl[dwn[u]],1,n,1);
119         d2[u].d[0]=max(t.d[0][0],t.d[0][1]);
120         d2[u].d[1]=max(t.d[1][0],t.d[1][1]);
121     }
122 }
123 int main()
124 {
125     int i,x,y;ll z;P1 t;
126     scanf("%d%d",&n,&m);
127     for(i=1;i<=n;++i)    scanf("%lld",a+i);
128     for(i=1;i<n;++i)
129     {
130         scanf("%d%d",&x,&y);
131         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
132         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
133     }
134     dfs1(1,0);
135     dfs2(1,0);
136     while(m--)
137     {
138         scanf("%d%lld",&x,&z);
139         d1[x].d[1]-=a[x];a[x]=z;d1[x].d[1]+=z;
140         while(x)
141         {
142             initnode(S::x,d1[x]);S::L=pl[x];S::_setx(1,n,1);
143             x=tp[x];y=ff[x];
144             t=S::getx(pl[x],pl[dwn[x]],1,n,1);
145             d1[y].d[0]-=max(d2[x].d[0],d2[x].d[1]);
146             d1[y].d[1]-=d2[x].d[0];
147             d2[x].d[0]=max(t.d[0][0],t.d[0][1]);
148             d2[x].d[1]=max(t.d[1][0],t.d[1][1]);
149             d1[y].d[0]+=max(d2[x].d[0],d2[x].d[1]);
150             d1[y].d[1]+=d2[x].d[0];
151             x=y;
152         }
153         //printf("%lld %lld
",d2[1].d[0],d2[1].d[1]);
154         printf("%lld
",max(d2[1].d[0],d2[1].d[1]));
155     }
156     return 0;
157 }
View Code

bst版本O(n*log)

待写

原文地址:https://www.cnblogs.com/hehe54321/p/10087583.html