51nod1812树的双直径(换根树DP)

传送门:http://www.51nod.com/Challenge/Problem.html#!#problemId=1812

题解:头一次写换根树DP。

求两条不相交的直径乘积最大,所以可以这样考虑:把一条边割掉,然后分别求两棵子树内的最长链乘起来就行了。由于负负得正,所以要再求一次最短链,就是把边权全部取负求一下就行了。然后就能通过dfs维护子树i内的答案dn[i]和不含以i为根的子树的答案up[i],dn[i]很好维护,重点是维护up[i],共5种可能:(1)从父亲的up继承过来(2)前后缀中的最大值f+出边+入边(3)父亲的g+兄弟节点中最大的f+出边(4)前驱/后继中的最大和次大(5)前驱/后继中的子树中的直径。然后转移状态就行了。

细节太多……还要__int128。为了方便,计算时答案用long long维护,乘起来再转long long……

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e5+7;
int n,tot,hd[N],v[N<<1],w[N<<1],nxt[N<<1],p[N<<1],len[N<<1];
ll f[N],g[N],pre[N],suf[N],dn[N],up[N];
__int128 ans;
void print(__int128 x){if(x>9)print(x/10);putchar('0'+x%10);}
void add(int x,int y,int z){v[++tot]=y,nxt[tot]=hd[x],hd[x]=tot,w[tot]=z;}
void dfs1(int u, int fa)
{
    f[u]=dn[u]=0;
    for(int i=hd[u];i;i=nxt[i])
    if(v[i]!=fa)
    {
        dfs1(v[i],u);
        dn[u]=max(dn[u],f[u]+f[v[i]]+w[i]);
        f[u]=max(f[u],f[v[i]]+w[i]);
        dn[u]=max(dn[u],dn[v[i]]);
    }
}
void dfs2(int u,int fa)
{
    int cnt=0;
    for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)p[++cnt]=v[i],len[cnt]=w[i];
    pre[0]=suf[cnt+1]=0;
    for(int i=1;i<=cnt;i++)pre[i]=max(pre[i-1],f[p[i]]+len[i]);
    for(int i=cnt;i;i--)suf[i]=max(suf[i+1],f[p[i]]+len[i]);
/*一个点向上的直径:
(1)从父亲的up继承过来
(2)前后缀中的最大值f+出边+入边
(3)父亲的g+兄弟节点中最大的f+出边
(4)前驱/后继中的最大和次大
(5)前驱/后继中的子树中的直径*/
    for(int i=1;i<=cnt;i++)
    {
        g[p[i]]=max(g[p[i]],g[u]+len[i]);
        g[p[i]]=max(g[p[i]],max(pre[i-1],suf[i+1])+len[i]);
        up[p[i]]=max(up[p[i]],up[u]);
        up[p[i]]=max(up[p[i]],pre[i-1]+suf[i+1]);
        up[p[i]]=max(up[p[i]],g[u]+max(pre[i-1],suf[i+1]));
    }
    ll mx1=-1e18,mx2=-1e18,mx=-1e18,tmp;
    for(int i=1;i<=cnt;i++)
    {
        up[p[i]]=max(up[p[i]],max(mx1+mx2,mx));
        tmp=f[p[i]]+len[i];
        if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp;
        mx=max(mx,dn[p[i]]);
    }
    mx1=mx2=mx=-1e18;
    for(int i=cnt;i;i--)
    {
        up[p[i]]=max(up[p[i]],max(mx1+mx2,mx));
        tmp=f[p[i]]+len[i];
        if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp;
        mx=max(mx,dn[p[i]]);
    }
    for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)dfs2(v[i],u);
}
int main()
{
    scanf("%d",&n);
    for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),add(x,y,z),add(y,x,z);
    dfs1(1,0),dfs2(1,0);
    for(int i=2;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]);
    for(int i=1;i<=tot;i++)w[i]=-w[i];
    memset(up,0,sizeof up); 
    memset(g,0,sizeof g);
    dfs1(1,0),dfs2(1,0);
    for(int i=2;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]);
    print(ans);
}
原文地址:https://www.cnblogs.com/hfctf0210/p/10600715.html