题目大意
有一棵树,每条边都有一个边权,现在你要修改边权,使得修改后根到所有叶子的距离相等。
要求所有边权非负。
修改的代价为(lvert)每条边修改前的边权(-)修改后的边权( vert)之和。
(n+mleq 300000)
题解
容易发现,设 (f(x)) 为根到所有叶子的距离为 (x) 时的最小代价,那么 (f(x))是一个下凸函数,并且每一段都是线性的。
考虑一个点 (u) 从儿子 (v) 转移过来。这个过程分两步:
把 (v) 的凸包加上 (u o v) 这条边:
要从 (f(x)) 转移到 (f'(x))
假设原来 (f(x)) 的最小值是在 ([l,r]) 时取到的,那么:
(xleq l):(f'(x)=f(x)+w):最优方案是把这条边的长度减到 (0)(因为边权不能是负数)
(lleq xleq l+w):(f'(x)=f(l)+w-(x-l)):把这条边的代价减掉(w-(x-l))
(l+wleq xleq r+w):(f'(x)=f(l)):这条边的代价不需要变
(xgeq r+w):(f'(x)=f(l)+(x-R)-w):把这条边的代价减掉((x-r)-w)
那么就是把 ([l,r]) 这段往右平移,把 ([0,l]) 这段往上平移,加入一段斜率为 (1) 的直线和一段斜率为 (-1)的直线。
考虑怎么维护这个凸包。
可以发现相邻两段的斜率之差为 (1),所以只需要维护凸包上相邻两个线段交点的横坐标即可。
还可以发现凸包最右边那条直线的斜率就是这个点的儿子个数。
所以直接把最右边儿子个数 (-1) 条个交点弹掉就能找到 ([l,r]) 了。
把两个凸包合并:
直接把所有交点相加就好了。
那么要怎么计算答案呢?
先找到 ([l,r]),然后对于左边的每一个交点 (v),它的贡献就是 (-v)。
直接相加就好了。
可以用可合并堆实现,复杂度为 (O((n+m)log (n+m)))
但是我懒。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
#include<queue>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c,b=0;
while(((c=getchar())<'0'||c>'9')&&c!='-');
if(c=='-')
{
c=getchar();
b=1;
}
do
{
s=s*10+c-'0';
}
while((c=getchar())>='0'&&c<='9');
return b?-s:s;
}
void put(int x)
{
if(!x)
{
putchar('0');
return;
}
static int c[20];
int t=0;
while(x)
{
c[++t]=x%10;
x/=10;
}
while(t)
putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
if(b<a)
{
a=b;
return 1;
}
return 0;
}
int upmax(int &a,int b)
{
if(b>a)
{
a=b;
return 1;
}
return 0;
}
int n,m;
ll w[300010];
int f[300010];
int d[300010];
priority_queue<ll> q[300010];
int main()
{
open("loj2568");
scanf("%d%d",&n,&m);
ll ans=0;
for(int i=2;i<=n+m;i++)
{
scanf("%d%lld",&f[i],&w[i]);
ans+=w[i];
d[f[i]]++;
}
for(int i=n+m;i>=2;i--)
{
ll l=0,r=0;
if(i<=n)
{
while(--d[i])
q[i].pop();
l=q[i].top();
q[i].pop();
r=q[i].top();
q[i].pop();
}
q[i].push(l+w[i]);
q[i].push(r+w[i]);
if(q[i].size()>q[f[i]].size())
q[i].swap(q[f[i]]);
while(!q[i].empty())
q[f[i]].push(q[i].top()),q[i].pop();
}
while(d[1]--)
q[1].pop();
while(!q[1].empty())
ans-=q[1].top(),q[1].pop();
printf("%lld
",ans);
return 0;
}