[LuoguP4426][AHOI2018]毒瘤(动态DP)

[LuoguP4426][AHOI2018]毒瘤(动态DP)

题面

给出一个(n)个点(m)条边的无向图,求独立集个数。
(n leq 10^5,n-1 leq m leq n+10)

分析

注意到(|m-n|)很小,我们可以暴力枚举这些非树边((u,v))的状态,按两边选和不选有(0,0)(0,1)(1,0)三种。其实可以合并为2种:

  1. (u)强制不选,(v)可任意选
  2. (u)强制选,(v)强制不选

那么直接暴力枚举每条边的状态,然后在树上修改,做动态DP即可。

(f_{x,0},f_{x,1})分别表示(x)不选/选,(x)子树中的独立集个数,那么:
(f_{x,0}=1+prod_{y in child(x)} (f_{y,0}+f_{y,1}))
(f_{x,1}=1+prod_{y in child(x)} f_{y,0})

最终答案为(f_{x,0}+f_{x,1})

(g_{x,0}=1+prod_{y in child(x)-{son(x)}} (f_{y,0}+f_{y,1}))

(g_{x,1}=1+prod_{y in child(x)-{son(x)}} f_{y,0})

g维护了所有轻儿子的DP贡献,那么有:

(f_{x,0}=(f_{son(x),0}+f_{son(x),1})cdot g_{x,0})
(f_{x,1}=f_{son(x),0} cdot g_{x,1})

写成矩阵的形式(注意这里是+,(cdot)矩阵乘法)

[egin{bmatrix}f_{x,0} \ f_{x,1} end{bmatrix}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} 0 end{bmatrix} egin{bmatrix}f_{son(x),0} \ f_{son(x),1} end{bmatrix} ]

(m{M_x}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} 0 end{bmatrix})。为了处理强制选和不选的情况,我们还需要对每个节点定义一个矩阵(m{C_x}),求区间矩阵积的时候把乘(m{M_x})变成乘(m{C_xM_x})

注意到(egin{bmatrix} 0 0 \ 0 1end{bmatrix}egin{bmatrix}f_{x,0} \ f_{x,1} end{bmatrix}=egin{bmatrix}0 \ f_{x,1} end{bmatrix}),于是使得(f_{x,0}=0),那么(m{C_x}=egin{bmatrix} 0 0 \ 0 1end{bmatrix})就表示强制选(x).同理(m{C_x}=egin{bmatrix} 1 0 \ 0 0end{bmatrix})就表示强制不选(x),(m{C_x}=egin{bmatrix} 1 0 \ 0 1end{bmatrix})就表示选和不选(x)均可。于是枚举的时候单点修改即可。

但是还有一个问题,在动态DP的过程中,我们需要把儿子的影响从父亲中消除,也就是说要做除法。但是万一(f_y=0),就会出现除0的问题。于是我们可以对于每个(f)(g)值,记录它们被乘进去了几个0,做除法的时候0的个数会减少。如果减到了0,就变成了它们的真实值。具体实现可以定义一个新的类,重载它的*,/运算符

struct mynum { //为了消除下方g对上方g的影响,要支持撤回乘0操作
	ll val;
	int cnt;//记录被乘上去的0个数
	mynum() {
		val=cnt=0;
	}
	mynum(ll _val) {
		if(_val==0) val=cnt=1;
		else val=_val,cnt=0;
	}
	friend mynum operator * (mynum p,mynum q) {
		mynum ans;
		ans.val=p.val*q.val%mod;//把0的val设为1,这样乘的时候val就不变
		ans.cnt=p.cnt+q.cnt;
		return ans;
	}
	friend mynum operator / (mynum p,mynum q) {
		mynum ans;
		ans.val=p.val*inv(q.val)%mod;
		ans.cnt=p.cnt-q.cnt;
		return ans;
	}
	ll value() {
		if(cnt==0) return val;
		else return 0;
	}
};

用LCT实现,复杂度(O(n+m+2^{m-n}log n)),常数还可以。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 200000
#define mod 998244353
using namespace std;
typedef long long ll;
template<typename T> void qread(T &x) {
	x=0;
	T sign=1;
	char c=getchar();
	while(c<'0'||c>'9') {
		if(c=='-') sign=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9') {
		x=x*10+c-'0';
		c=getchar();
	}
	x=x*sign;
}
template<typename T> void qprint(T x) {
	if(x<0) {
		putchar('-');
		qprint(-x);
	} else if(x==0) {
		putchar('0');
		return;
	} else {
		if(x>=10) qprint(x/10);
		putchar('0'+x%10);
	}
}

inline ll fast_pow(ll x,ll k) {
	ll ans=1;
	while(k) {
		if(k&1) ans=ans*x%mod;
		x=x*x%mod;
		k>>=1;
	}
	return ans;
}
inline ll inv(ll x) {
	return fast_pow(x,mod-2);
}

int n,m;
struct edge {
	int from;
	int to;
	int next;
} E[maxn*2+5];
int head[maxn+5];
int esz=1;
void add_edge(int u,int v) {
	esz++;
	E[esz].from=u;
	E[esz].to=v;
	E[esz].next=head[u];
	head[u]=esz;
}

struct mynum { //为了消除下方g对上方g的影响,要支持撤回乘0操作
	ll val;
	int cnt;//记录被乘上去的0个数
	mynum() {
		val=cnt=0;
	}
	mynum(ll _val) {
		if(_val==0) val=cnt=1;
		else val=_val,cnt=0;
	}
	friend mynum operator * (mynum p,mynum q) {
		mynum ans;
		ans.val=p.val*q.val%mod;//把0的val设为1,这样乘的时候val就不变
		ans.cnt=p.cnt+q.cnt;
		return ans;
	}
	friend mynum operator / (mynum p,mynum q) {
		mynum ans;
		ans.val=p.val*inv(q.val)%mod;//把0的val设为1,这样乘的时候val就不变
		ans.cnt=p.cnt-q.cnt;
		return ans;
	}
	ll value() {
		if(cnt==0) return val;
		else return 0;
	}
};
mynum mat[maxn+5][2][2];//存储每个点的初始矩阵

struct matrix {
	ll a[2][2];//因为矩阵乘法需要加减,这里不能用mynum
	matrix() {
		a[0][0]=a[0][1]=a[1][0]=a[1][1]=0;
	}
	ll* operator [](int i) {
		return a[i];
	}

	friend matrix operator * (matrix p,matrix q) {
		matrix ans;
		for(int i=0; i<=1; i++) {
			for(int j=0; j<=1; j++) {
				for(int k=0; k<=1; k++) {
					ans[i][j]=(ans[i][j]+p[i][k]*q[k][j]%mod)%mod;
				}
			}
		}
		return ans;
	}
	void print(){
		for(int i=0;i<=1;i++){
			for(int j=0;j<=1;j++) printf("%lld ",a[i][j]);
			printf("
");
		}
	}
};

matrix cho[maxn+5];//修改用,如果a[x][0][0]=0,那么乘完之后f[x][0]=0,表示强制不选。如果a[x][1][1]=0,表示强制选 
struct LCT {
#define lson(x) (tree[x].ch[0])
#define rson(x) (tree[x].ch[1])
#define fa(x) (tree[x].fa)
	struct node {
		int fa;
		int ch[2];
		matrix v;
	} tree[maxn+5];
	inline bool is_root(int x){
		return !(lson(fa(x))==x||rson(fa(x))==x);
	}
	inline int check(int x){
		return rson(fa(x))==x;
	}
	void push_up(int x){
		tree[x].v[0][0]=mat[x][0][0].value();
		tree[x].v[0][1]=mat[x][0][1].value();
		tree[x].v[1][0]=mat[x][1][0].value();
		tree[x].v[1][1]=mat[x][1][1].value();
		tree[x].v=cho[x]*tree[x].v;
		if(lson(x)) tree[x].v=tree[lson(x)].v*tree[x].v;
		if(rson(x)) tree[x].v=tree[x].v*tree[rson(x)].v;
	}
	void rotate(int x){
		int y=fa(x),z=fa(y),k=check(x),w=tree[x].ch[k^1];
		tree[y].ch[k]=w;
		tree[w].fa=y;
		if(!is_root(y)) tree[z].ch[check(y)]=x;
		tree[x].fa=z;
		tree[x].ch[k^1]=y;
		tree[y].fa=x;
		push_up(y);
		push_up(x); 
	} 
	void splay(int x){
		while(!is_root(x)){
			int y=fa(x);
			if(!is_root(y)){
				if(check(x)==check(y)) rotate(y);
				else rotate(x); 
			}
			rotate(x);
		}
	}
	void access(int x){
		for(int y=0;x;y=x,x=fa(x)){
			splay(x);
			if(rson(x)){
				mat[x][0][0]=mat[x][0][0]*mynum(tree[rson(x)].v[0][0]+tree[rson(x)].v[1][0]);
				mat[x][1][0]=mat[x][1][0]*mynum(tree[rson(x)].v[0][0]);
			}
			rson(x)=y;
			if(rson(x)){
				mat[x][0][0]=mat[x][0][0]/mynum(tree[rson(x)].v[0][0]+tree[rson(x)].v[1][0]);
				mat[x][1][0]=mat[x][1][0]/mynum(tree[rson(x)].v[0][0]);
			}
			mat[x][0][1]=mat[x][0][0];
			push_up(x);
		}
	}
	ll query(){
		splay(1);
		return (tree[1].v[0][0]+tree[1].v[1][0])%mod;
	}
	void set(int x,bool is_chosen){
		access(x);
		splay(x);
		if(is_chosen) cho[x][0][0]=0;//a[0][0]=0,乘上[f[x][0],f[x][1]]后只剩下f[x][1],表示强制选 
		else cho[x][1][1]=0;
		push_up(x);
	} 
	void revert(int x,bool is_chosen){
		access(x);
		splay(x);
		if(is_chosen) cho[x][0][0]=1;
		else cho[x][1][1]=1;
		push_up(x);
	}
}T;

ll f[maxn+5][2]; 
void ini_dfs(int x,int fa){
	cho[x][0][0]=cho[x][1][1]=1;//一开始表示选和不选均可
	f[x][0]=f[x][1]=1;
	for(int i=head[x];i;i=E[i].next){
		int y=E[i].to;
		if(y!=fa){
			ini_dfs(y,x);
			f[x][0]=f[x][0]*(f[y][0]+f[y][1])%mod;
			f[x][1]=f[x][1]*f[y][0]%mod;
		}
	}
	mat[x][0][0]=mat[x][0][1]=mynum(f[x][0]);
	mat[x][1][0]=mynum(f[x][1]);
	mat[x][1][1]=mynum(0);
	T.tree[x].fa=fa;
	T.push_up(x);
}
struct disjoint_set{
	int fa[maxn+5];
	void ini(int n){
		for(int i=1;i<=n;i++) fa[i]=i;
	}
	int find(int x){
		if(fa[x]==x) return x;
		else return fa[x]=find(fa[x]);
	}
	void merge(int x,int y){
		fa[find(x)]=find(y);
	}
}S;

vector< pair<int,int> >ed; //存储非树边
//非树边(u,v)有三种情况(0,0) (0,1) (1,0),前两种可以合在一起,表示u强制不选。后一种是u强制选 
//那么我们可以状压枚举情况1
//若第j位为1,则u[j]强制选。若第j位为0,则u[j]强制不选 
ll ans=0;

int main() {
	int u,v;
	qread(n);
	qread(m);
	S.ini(n);
	for(int i=1;i<=m;i++){
		qread(u);
		qread(v);
		if(S.find(u)!=S.find(v)){
			 S.merge(u,v);
			 add_edge(u,v);
			 add_edge(v,u);
		}
		else ed.push_back(make_pair(u,v));
	}
	ini_dfs(1,0);
//	solve(0,0);
	int sz=ed.size();
	for(int i=0;i<(1<<sz);i++){
		for(int j=0;j<sz;j++){
			int u=ed[j].first,v=ed[j].second;
			if(i&(1<<j)){
				T.set(u,1);//u强制选 
				T.set(v,0);//v强制不选 (1,0)
			}else{
				T.set(u,0);//u强制选,v选不选均可 (0,0) (0,1) 
			}
		}
		ans=(ans+T.query())%mod;
		for(int j=0;j<sz;j++){
			int u=ed[j].first,v=ed[j].second;
			if(i&(1<<j)){
				T.revert(u,1);
				T.revert(v,0);
			}else{
				T.revert(u,0);
			}
		}
	}
	printf("%lld
",ans);
}
/*
3 3
1 2
1 3
2 3
*/

原文地址:https://www.cnblogs.com/birchtree/p/12681864.html