暴力写挂

题目描述

题解

考虑把式子化一下,因为只有一个式子跟第二棵树有关,所以我们可以考虑把前面的式子化成跟 $ ext{lca}$ 没有关系,即 $frac{1}{2}(dp_u+dp_v+dis(u,v))$ 。因此我们可以利用边分治,每次把两边的点黑白染色,构成虚树,然后做 $ ext{dp}$ 即可。这里要注意 $ ext{lca}$ 要 $O(1)$ 求,虚树构成过程中不能排序,故我们可以一开始就按照第二棵树的dfs排序好,之后分治下去即可。效率: $O(nlogn)$ 。

代码

#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int N=4e5+5,N2=N<<1,N4=N<<2;
int n,m,t=1,fa[22][N2],Lg[N2],d[N],e[N],o,rt,su,hd[N2],sz[N2];
int V[N4],W[N4],nx[N4],b[N],id[N],col[N],tp,S[N],h[2][N],c;
LL dp[N],Dp[N],sm[N],f[2][N],ans=-2e18;
bool vis[N2];vector<int>X[N],Y[N];
void Add(int u,int v,int w){
    X[u].push_back(v);Y[u].push_back(w);
}
void add(int u,int v,int w){
    nx[++t]=hd[u];V[hd[u]=t]=v;W[t]=w;
}
void add(int u,int v){X[u].push_back(v);}
void rebuild(int u,int fr){
    int x=0,z=X[u].size();
    for (int v,w,i=0;i<z;i++){
        v=X[u][i];w=Y[u][i];
        if (v==fr) continue;dp[v]=dp[u]+w;
        if (!x) add(u,v,w),add(v,u,w),x=u;
        else m++,add(x,m,0),add(m,x,0),
            add(m,v,w),add(v,m,w),x=m;
        rebuild(v,u);
    }
}
void dfs(int u,int fr){
    int z=X[u].size();
    fa[0][e[u]=++c]=u;b[id[u]=++t]=u;
    for (int v,w,i=0;i<z;i++){
        v=X[u][i];w=Y[u][i];
        if (v==fr) continue;
        Dp[v]=Dp[u]+w;d[v]=d[u]+1;
        dfs(v,u);fa[0][++c]=u;
    }
}
int Min(int u,int v){return d[u]<d[v]?u:v;}
int qry(int l,int r){
    l=e[l];r=e[r];
    if (l>r) swap(l,r);int i=Lg[r-l+1];
    return Min(fa[i][l],fa[i][r-(1<<i)+1]);
}
void Sz(int u,int fr){
    sz[u]=1;
    for (int v,i=hd[u];i;i=nx[i])
        if (!vis[i>>1] && (v=V[i])!=fr)
            Sz(v,u),sz[u]+=sz[v];
}
void Rt(int u,int fr){
    for (int v,w,i=hd[u];i;i=nx[i])
        if (!vis[i>>1] && (v=V[i])!=fr){
            w=max(sz[v],o-sz[v]);
            if (w<su) rt=i,su=w;Rt(v,u);
        }
}
void find(int u,int fr,LL w,int cl){
    if (u<=n) col[u]=cl,sm[u]=w;
    for (int v,i=hd[u];i;i=nx[i])
        if (!vis[i>>1] && (v=V[i])!=fr)
            find(V[i],u,w+W[i],cl);
}
void ins(int u){
    if (tp<1){S[++tp]=u;return;}
    int x=qry(S[tp],u);
    if (x==S[tp]){S[++tp]=u;return;}
    while(tp>1 && id[S[tp-1]]>=id[x])
        add(S[tp-1],S[tp]),tp--;
    if (x!=S[tp]) add(x,S[tp]),S[tp]=x;
    S[++tp]=u;
}
void get(int u){
    int z=X[u].size();
    f[0][u]=f[1][u]=-2e18;
    if (~col[u]) f[col[u]][u]=dp[u]+sm[u];
    for (int v,i=0;i<z;i++){
        v=X[u][i];get(v);
        for (int j=0;j<2;j++)
            ans=max(ans,f[j][u]+f[!j][v]-2ll*Dp[u]);
        for (int j=0;j<2;j++)
            f[j][u]=max(f[j][u],f[j][v]);
    }
    X[u].clear();
}
void solve(int u,int l,int r){
    Sz(u,0);o=sz[u];rt=0;su=1e9;
    Rt(u,0);if (!rt) return;
    int x=V[rt],y=V[rt^1];vis[rt>>1]=1;
    find(x,y,0,0);find(y,x,W[rt],1);
    if (b[l]!=1) ins(1);
    for (int i=l;i<=r;i++) ins(b[i]);
    while(tp>1) add(S[tp-1],S[tp]),tp--;
    tp=0;get(1);int v[2]={0};
    for (int w,i=l;i<=r;i++)
        w=col[b[i]],h[w][++v[w]]=b[i],col[b[i]]=-1;
    for (int i=0;i<v[0];i++) b[i+l]=h[0][i+1];
    for (int i=0;i<v[1];i++) b[r-i]=h[1][v[1]-i];
    solve(x,l,l+v[0]-1);solve(y,r-v[1]+1,r);
}
int main(){
    cin>>n;m=n;
    for (int i=1,x,y,z;i<n;i++)
        scanf("%d%d%d",&x,&y,&z),
        Add(x,y,z),Add(y,x,z);rebuild(1,0);
    for (int i=1;i<=n;i++)
        X[i].clear(),Y[i].clear(),col[i]=-1;t=0;
    for (int i=1,x,y,z;i<n;i++)
        scanf("%d%d%d",&x,&y,&z),
        Add(x,y,z),Add(y,x,z);dfs(1,0);
    for (int i=2;i<=c;i++) Lg[i]=Lg[i>>1]+1;
    for (int i=c;i;i--)
        for (int j=1;i+(1<<j)<=c+1;j++)
            fa[j][i]=Min(fa[j-1][i],fa[j-1][i+(1<<(j-1))]);
    for (int i=1;i<=n;i++) X[i].clear();
    solve(1,1,n);ans>>=1;
    for (int i=1;i<=n;i++)
        ans=max(ans,dp[i]-Dp[i]);
    cout<<ans<<endl;return 0;
}
原文地址:https://www.cnblogs.com/xjqxjq/p/12368634.html