[LuoguP6144][USACO20FEB]Help Yourself P(DP+组合数学+线段树)

[LuoguP6144][USACO20FEB]Help Yourself P(DP+组合数学+线段树)

题面

Bessie 现在有 N条在一条数轴上的线段,第 i条线段覆盖了 ([l_i,r_i](1 leq l_i,r_i leq 2N))的所有实数。定义一个线段集合的并为所有至少被一条线段覆盖的实数。定义一个线段集合的复杂度为该集合并的联通块个数的 K 次方。
Bessie现在想计算这N条线段的(2^N)个子集的复杂度之和模 (10^9+7)

分析

先把线段排序。
(dp_{i,j,t})表示前(i)条线段,覆盖到最右边的点为(j)的所有子集,每个子集的连通块个数的(t)次方之和。(相当于把题面中的复杂度改成了t次方,原因是下面要用二项式定理转移).
显然有初始值(dp_{0,0,0}=1),答案为(sum_{j=1}^{2n}dp_{n,j,K})

考虑添加第(i)条线段([l_i,r_i])对答案的影响。

(1) 对于(j<l_i)的状态,加入([l_i,r_i])后连通块个数会+1,最右边的点变为(r_i). 设(dp_{i-1,j,t})对应的子集为(S),(cnt(S))为子集的连通块个数。(dp_{i,r_i,t})会增加的值为

[sum_{S} (cnt(S)+1)^t=sum_{S} sum_{i=1}^K C_{t}^i cnt(S)^t=sum_{S} sum_{i=1}^K C_{t}^i dp_{i-1,j,t} ]

(2)对于(l_i leq j leq r_i)的状态,加入([l_i,r_i])后连通块个数不变,最右边的点变为(r_i). (dp_{i,r_i,t})会增加的值为(dp_{i-1,j,t})

(3)对于(j>r_i)的状态,加入([l_i,r_i])后连通块个数不变,最右边的点也不变。但是子集的个数乘了2(每个子集都可以选或不选第(i)个区间)。因为每个子集的连通块个数不变,所以把(dp_{i,j,t})乘2即可。

那么我们就可以用线段树维护DP转移。显然(i)这一维可以去掉,线段树的叶子节点([j,j])维护一个(K)维向量代表(dp_{i,j}).
(1) 求([0,l-1])的区间和,再按照上面的组合式计算出增加量,然后对(r)单点增加.
(2) 求([l,r])的区间和,然后对(r)单点增加.
(3) 对区间([r+1,n])区间乘2
注意所有查询操作要在修改操作前。那么维护一个支持单点加向量,区间数乘向量,查询区间向量和的线段树即可。

复杂度(O(nKlog n))

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 200000
#define maxk 10
#define mod 1000000007
using namespace std;
typedef long long ll;
int n,K;
struct seg{
	int l;
	int r;
	friend bool operator < (seg p,seg q){
		if(p.l==q.l) return p.r<q.r;
		else return p.l<q.l;
	}
}a[maxn+5];

ll C[maxk+5][maxk+5];
void ini(int m){
	for(int i=0;i<=m;i++){
		C[i][0]=C[i][i]=1;
		for(int j=0;j<i;j++) C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;
	}
}
struct val_type{//维护一个K维向量 
	ll a[maxk+5];
	val_type(){
		memset(a,0,sizeof(a));
	}
	ll & operator [] (const int i){
		return a[i];
	}
	friend val_type operator + (val_type p,val_type q){
		val_type ans;
		for(int i=0;i<=K;i++) ans[i]=(p[i]+q[i])%mod;
		return ans;
	}
	friend val_type operator * (val_type p,ll x){
		val_type ans;
		for(int i=0;i<=K;i++) ans[i]=p[i]*x%mod;
		return ans;
	}
	void print(){
		printf("debug:");
		for(int i=0;i<=K;i++) printf("%lld ",a[i]);
		printf("
");
	}
};

struct segment_tree{
	struct node{
		int l;
		int r;
		ll mtag;
		val_type v;
	}tree[maxn*4+5];
	void push_up(int pos){
		tree[pos].v=tree[pos<<1].v+tree[pos<<1|1].v;
	}
	void mul_tag(int pos,int v){
		tree[pos].mtag=tree[pos].mtag*v%mod;
		tree[pos].v=tree[pos].v*v; 
	}
	void push_down(int pos){
		if(tree[pos].mtag!=1){
			mul_tag(pos<<1,tree[pos].mtag);
			mul_tag(pos<<1|1,tree[pos].mtag);
			tree[pos].mtag=1;
		}
	}
	void build(int l,int r,int pos){
		tree[pos].l=l;
		tree[pos].r=r;
		tree[pos].mtag=1;
		if(l==r) return;
		int mid=(l+r)>>1;
		build(l,mid,pos<<1);
		build(mid+1,r,pos<<1|1); 
	} 
	void add_point(int upos,val_type &uval,int pos){//单点加向量 
		if(tree[pos].l==tree[pos].r){
			tree[pos].v=tree[pos].v+uval;
			return;
		}
		push_down(pos);
		int mid=(tree[pos].l+tree[pos].r)>>1;
		if(upos<=mid) add_point(upos,uval,pos<<1);
		else add_point(upos,uval,pos<<1|1);
		push_up(pos); 
	}
	void mul_seg(int L,int R,ll uval,int pos){//区间数乘 
		if(L<=tree[pos].l&&R>=tree[pos].r){
			mul_tag(pos,uval);
			return;
		}
		push_down(pos);
		int mid=(tree[pos].l+tree[pos].r)>>1;
		if(L<=mid) mul_seg(L,R,uval,pos<<1);
		if(R>mid) mul_seg(L,R,uval,pos<<1|1);
		push_up(pos);
	} 
	val_type query(int L,int R,int pos){//查询区间向量和 
		if(L<=tree[pos].l&&R>=tree[pos].r){
			return tree[pos].v;
		}
		push_down(pos);
		int mid=(tree[pos].l+tree[pos].r)>>1;
		val_type ans;
		if(L<=mid) ans=ans+query(L,R,pos<<1);
		if(R>mid) ans=ans+query(L,R,pos<<1|1);
		return ans;
	}
}T; 
int main(){
	scanf("%d %d",&n,&K);
	ini(K);
	for(int i=1;i<=n;i++) scanf("%d %d",&a[i].l,&a[i].r);
	sort(a+1,a+1+n);
	T.build(0,n*2,1);
	val_type tmp;
	tmp[0]=1; 
	T.add_point(0,tmp,1);//dp[0]初始化为1
	for(int i=1;i<=n;i++){
		int l=a[i].l,r=a[i].r;
		val_type last=T.query(0,l-1,1);
//		last.print();
		val_type now;
		for(int i=0;i<=K;i++){
			for(int j=0;j<=i;j++){
				now[i]+=last[j]*C[i][j]%mod; 
				now[i]%=mod;
			}
		}
		now=now+T.query(l,r,1);
//		now.print();
		T.add_point(r,now,1);
		if(r!=n*2) T.mul_seg(r+1,n*2,2,1);
//		T.tree[1].v.print();
	} 
	printf("%lld
",T.tree[1].v[K]);
}

原文地址:https://www.cnblogs.com/birchtree/p/12555498.html