CF671D(set 维护整体 dp)

翻别人博客的时候翻到的一道题

  • 给定一棵 n 个点的以 1 为根的树。
  • 有 m 条路径 (x,y),保证 y 是 x 或 x 的祖先,每条路径有一个权值。
  • 你要在这些路径中选择若干条路径,使它们能覆盖每条边,同时权值和最小。
  • (n,m le 3 imes 10^5)

首先可以想到一个显然的 dp。

(dp[i][j]) 表示以 i 为根的子树,向上延申了 j 个点。然后转移就是了复杂度 (O(n^2))

但这个复杂度不太能过得去,我们换一种形式。

(dp[i][j]) 表示在 i 为根的子树,向上支配到 j 深度。并且设 (f(i)=minlimits_{i=1}^{dep[i]-1}dp[i][j])

然后有

[dp[i][j] = sum_{vin son(i)}f(v)-min_{vin son(i)}(dp[v][j]-f(v)) ]

[dp[i][anc] = sum_{vin son(i)}f(v)+c ]

仔细分析一波,第一个式子相当于一个线段树合并维护 dp,而第二个式子则是在做一个全局加法。

可以想到使用线段树合并来维护 dp。

不过 (operatorname{256MB}) 空间复杂度好像不太能过得去。

有一种奇妙的做法是用 set 来维护这个 dp。

具体来说,每个节点维护一个 set,里面存放形如 ((j,dp[i][j])) 的二元组,转移的时候相当于做区间加法、合并两个 set,并且需要支持随时取出 (f(i))

注意到我们可以随时维护二元组的第一位 j,不超过 (dep[i]) 并且对每一个 (j) 只有一个二元组,这样前两个操作非常容易处理。

而对于第二个操作,数据结构的优势已经基本用尽,要还想随时维护的话,就得继续嵌套其他数据结构。这时候注意到一个贪心性质,set 中的二元组是单调递减的,所以只需要修改的时候顺便维护 set 的单调性就行。

代码细节一般,注意合并 set 的时候启发式合并。

// 代码中没有启发式合并就草过去了/xk
#include<bits/stdc++.h>

using namespace std;

#define int long long
#define pb emplace_back
#define pii pair<int,int>

template<typename _T>
inline void read(_T &x)
{
    x= 0 ;int f =1;char s=getchar();
    while(s<'0' ||s>'9') {f =1;if(s == '-')f =- 1;s=getchar();}
    while('0'<=s&&s<='9'){x = (x<<3) + (x<<1) + s - '0';s= getchar();}
    x*=f;
}

const int np = 3e5 + 5;
const int INF = 1 << 30;
int head[np],ver[np * 2],nxt[np * 2],tit;
int f[np],id[np];
vector<pii> vec[np];
int tmp[np],dep[np],son[np];
int n,m,pre[np],siz[np];
set<pii> s[np];
int tag[np];

inline void add(int x,int y)
{
    ver[++tit] = y;
    nxt[tit] = head[x];
    head[x] = tit;
}

inline void solve(int x)
{
	int a1(-1),a2(-1);
	for(set<pii>::iterator it = s[id[x]].begin(),iter;it != s[id[x]].end();it ++)
	{
		iter = it;
		int a1_ = (*it).first;
		int a2_ = (*it).second;
		if(a1==-1)
		{
			a1 = a1_;
			a2 = a2_;
			continue;
		}
		if(a2 <= a2_){
			iter++;
			s[id[x]].erase(it);
			it = iter;
			it--;
		}else{
			a1 = a1_;
			a2 = a2_;
		}
	}
}

inline void ins(int u,int j,int val)
{
    val += tag[id[u]];
    if(j > dep[u]) return ;
    set<pii>::iterator it = s[id[u]].lower_bound((pii){j,-INF});
    set<pii>::iterator iter,aux;
    if(it == s[id[u]].end() || (*it).first != j)
    {
        s[id[u]].insert((pii){j,val});
        iter = s[id[u]].lower_bound((pii){j,val});
//        it = iter;
//        if(iter == s[id[u]].begin()) return ;
//		iter--;
//        if((*iter).second <= val){
//            s[id[u]].erase(it);
//        } 
        
    }
    else{
        if((*it).second < val) return ;
        s[id[u]].erase(it);
        s[id[u]].insert((pii){j,val});
    }
}

inline void dfs(int x,int ff)
{
    int F(0);
    dep[x] = dep[ff] + 1;
    for(int i=head[x],v;i;i=nxt[i])
    {
        v = ver[i];
        if(v == ff) continue;
        dfs(v,x);
        F += f[v];
		if(f[v] == -1){
			f[x]=-1;
			return ;
		}
    }
    tag[id[x]] += F;
    for(auto pi:vec[x]){
        int anc = pi.first;
        int sd = pi.second;
        ins(x,dep[anc],sd);
    }
    ins(x,dep[x],0);
    for(int i=head[x],v;i;i=nxt[i])
    {
        v = ver[i];
        if(v == ff) continue;
//        if(s[x].size() > s[v].size()) swap(id[x],id[v]);
        for(set<pii>::iterator it = s[id[v]].begin();it!=s[id[v]].end();it ++)
        {
            int j = (*it).first;
            int val = (*it).second;
			ins(x,j,val-f[v]);
        }
    }
    solve(x);
//    cerr<< x <<" : ";
//    printf("%d : ",x);
//    for(auto i:s[id[x]])
//    {
//    	cerr<<"("<<i.first<<","<<i.second<<")"; 
//	}
//	cerr<<'
';
//	puts("");
	if(x == 1) return;
//	cout<<((*(s[x].rbegin())).second)<<'
';
    if((*(s[id[x]].rbegin())).first != dep[x])f[x] = (*(s[id[x]].rbegin())).second;
	else {
		set<pii >::iterator it = s[id[x]].end();
		--it;
		if(it == s[id[x]].begin()) {
			f[x]=-1;
			return ;	
		}//f[x] =-1;
		it--;
		f[x] = (*it).second;
	}//f[x] = (*(--s[x].rbegin())).second;
//	printf("%d
",f[x]);
}

signed main()
{
    read(n),read(m);
    for(int i=1,x,y;i <= n- 1;i ++)
    {
        read(x),read(y);
        add(x,y),add(y,x);
    }
    for(int i=1,x,y,val;i <= m;i ++){
        read(x),read(y),read(val);
        vec[x].pb((pii){y,val});
    }
    for(int i=1;i <= n;i ++) id[i] = i;
    dfs(1,0);
    if(f[1] == -1)
    {
    	puts("-1");
    	return 0;
	}
    printf("%lld
",(*(s[1].lower_bound((pii){1,-INF}))).second);
}

原文地址:https://www.cnblogs.com/-Iris-/p/15342836.html