description
给定两棵树,节点数分别为(n)、(m),请你在两棵树间连接一条路径使得形成的新大树中由(N=n+m)个点两两组成的(frac{N(N-1)}{2})个点对距离之和最小.
solution
通过大眼观察法,我们可以得出连接处两端点分别是两棵树的重心,所以可以进行树形dp.然后(Omicron(n+m))进行扫描求解答案即可.
至于求解答案的方法,我们可以注意到,对于每一个点(x)到其父节点的边,会有(size[x])个点在其下方,(n+m-size[x])个点在其上方,这些点两两构成点对会经过这条边,于是这条边的贡献即为(size[x]*(n+m-size[x]))
另外值得注意的是,本题点数较多,极端情况下递归dfs会爆栈,所以要手写栈实现非递归式dfs,从来没写过的我瞎yy出来一种,凑合看吧.
code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#define R register
#define next MabLcdG
#define mod 1
#define debug puts("mlg")
#define Mod(x) ((x%mod+mod)%mod)
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
inline ll read();
inline void write(ll x);
inline void writesp(ll x);
inline void writeln(ll x);
const ll maxn=360000;
ll to[maxn<<2],head[maxn<<1],next[maxn<<2],tot,w[maxn<<2];
inline void add(ll x,ll y){to[++tot]=y;next[tot]=head[x];head[x]=tot;}
ll n,m;
ll siz[maxn<<1],maxs[maxn<<1];
ll ans;
ll root1,root2;
ll _stack[2000000],h,fa[2000000];
bool book[2000000];
inline void dfs1(ll begin,bool type){
_stack[++h]=begin;
while(h){
ll x=_stack[h];
if(!book[x]){
siz[x]=1;
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver==fa[x]) continue;
fa[ver]=x;
_stack[++h]=ver;
}
book[x]=true;
continue;
}
else{
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver==fa[x]) continue;
siz[x]+=siz[ver];
maxs[x]=max(maxs[x],siz[ver]);
}
maxs[x]=max(maxs[x],(type?m:n)-siz[x]);
if(maxs[x]<maxs[type?root2:root1]) type?(root2=x):(root1=x);
book[x]=false;
--h;
}
}
}
inline void dfs2(ll begin){
_stack[++h]=begin;
while(h){
ll x=_stack[h];
if(!book[x]){
siz[x]=1;
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver==fa[x]) continue;
_stack[++h]=ver;fa[ver]=x;
}
book[x]=true;
continue;
}
else{
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver==fa[x]) continue;
siz[x]+=siz[ver];
}
ans+=siz[x]*(n+m-siz[x]);
--h;
book[x]=false;
}
}
}
int main(){
freopen("unite.in","r",stdin);
freopen("unite.out","w",stdout);
maxs[0]=((ull)1<<63)-1;
n=read();m=read();
for(R ll i=1,x,y;i<n;i++){
x=read();y=read();
add(x,y);add(y,x);
}
for(R ll i=1,x,y;i<m;i++){
x=read()+n,y=read()+n;
add(x,y);add(y,x);
}
dfs1(1,0);dfs1(n+1,1);
add(root1,root2);add(root2,root1);
dfs2(1);
writeln(ans);
}
inline ll read(){ll x=0,t=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') t=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*t;}
inline void write(ll x){if(x<0){putchar('-');x=-x;}if(x<=9){putchar(x+'0');return;}write(x/10);putchar(x%10+'0');}
inline void writesp(ll x){write(x);putchar(' ');}
inline void writeln(ll x){write(x);putchar('
');}
//inline void dfs1(ll x,ll fa,bool type){
// siz[x]=1;
// for(R ll i=head[x],ver;i;i=next[i]){
// ver=to[i];
// if(ver==fa) continue;
// dfs1(ver,x,type);
// siz[x]+=siz[ver];
// maxs[x]=max(maxs[x],siz[ver]);
// }
// maxs[x]=max(maxs[x],(type?m:n)-siz[x]);
// if(maxs[x]<maxs[type?root2:root1]){
// if(type) root2=x;
// else root1=x;
// }
//}
//inline void dfs2(ll x,ll fa){
// siz[x]=1;
// for(R ll i=head[x],ver;i;i=next[i]){
// ver=to[i];
// if(ver==fa) continue;
// dfs2(ver,x);
// siz[x]+=siz[ver];
// }
// ans+=siz[x]*(n+m-siz[x]);
//}