[loj#3124] [CTS2019] 氪金手游

题意简述

(n) 种卡,每种有权值 (W_i)(W_i)(p_{i,j}) 的概率取 (j) ( (j =1,2,3) )
不断抽卡,抽到卡 (i) 的概率是 (frac{W_i}{sumlimits_{j=1}^n W_j})
(T_i) 表示第一次抽到 (i) 的时间
给定 (n-1) 个限制 ((u_i,v_i)) ,要求 (T_{u_i}<T_{v_i}),以所有限制做边可形成一棵 (n) 个节点的树
求满足所有限制的概率

(n leq 10^6)


想法

乍一看不好下手,于是考虑子树 (dp)

假设是外向树,且已知 (W_i) ,则中奖概率为 (prodfrac{W_i}{size_i})

则总中奖概率为 (sum [(prodlimits_{i=1}^n P_{i,W_i}) imes prodlimits_{i=1}^n frac{W_i}{size_i}])

(dp) 计算,设以 (i) 为根的子树,(size_i)(s) 的情况下,整个子树的中奖概率为 (f[u][s])

(u) 仅有自己一个点时, (f[u][s]=P_{u,s}=frac{a_{u,s}}{a_{u,1}+a_{u,2}+a_{u,3}})

((u,v)) 将子节点 (v)(dp) 值更新 (u)(f'[u][s]) 表示更新前的),(f[u][s]+=sum f[v][i] imes f'[u][s-i] imes frac{s-i}{s})

进行树形背包 (dp) 就可以了,复杂度 (O(n^2))

然而上面做法只针对外向树,如果有从儿子指向父亲的反向边怎么办呢?

容斥!

(ans=sum (-1)^i imes 至少 i 条反向边不合法)

直接在整棵树中枚举哪些反向边不合法的复杂度过高,考虑仍在树形 (dp) 中容斥

考虑连接 (u) 与子节点 (v) 的边 ((u,v))

1.如果它原本是正向边,则转移方式不变:(f[u][s]+=sum f[v][i] imes f'[u][s-i] imes frac{s-i}{s})

2.如果它原本是反向边,则讨论它在不在“至少 (i) 条反向边不合法”中

  • 如果不在,说明这条边是“正向反向无所谓”的,直接不考虑这条边的限制,即:(f[u][s]+=sumlimits_i f'[u][s] imes f[v][i]=f'[u][s] imes (sum f[v][i]))

  • 如果在,说明这条边是正向边,需要乘上容斥系数 (-1) ,即:(f[u][s]-=sum f[v][i] imes f'[u][s-i] imes frac{s-i}{s})


总结

技巧

咋一看复杂的问题,可先从小规模问题入手分析(如树形结构的问题考虑子树)

代码

树上背包的写法,可以考虑是对每条边 ((u,v)) ,用 (v) 更新 (u)(dp) 值。
取模检查 (1ll*)(\%P)


代码

#include<cstdio>
#include<iostream>
#include<algorithm>

#define P 998244353

using namespace std;

int read(){
	int x=0;
	char ch=getchar();
	while(!isdigit(ch)) ch=getchar();
	while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
	return x;
}

const int N = 1005;

int n,a[N][4];

struct node{
	int v,dir;
	node *nxt;
}pool[N*2],*h[N];
int cnt;
void addedge(int u,int v){
	node *p=&pool[++cnt],*q=&pool[++cnt];
	p->v=v;p->nxt=h[u];h[u]=p; p->dir=v;
	q->v=u;q->nxt=h[v];h[v]=q; q->dir=v;
}

int Pow_mod(int x,int y){
	int ret=1;
	while(y){
		if(y&1) ret=1ll*ret*x%P;
		x=1ll*x*x%P;
		y>>=1;
	}
	return ret;
}
int Plus(int x,int y) { return x+y<P?x+y:x+y-P; }
int Minus(int x,int y) { return x>=y?x-y:x-y+P; }

int inv[N*3];
int f[N][N*3],sz[N],g[N*3];
void dfs(int u,int fa){
	int v;
	sz[u]=1;
	f[u][1]=a[u][1]; f[u][2]=a[u][2]; f[u][3]=a[u][3];
	for(node *p=h[u];p;p=p->nxt)
		if((v=p->v)!=fa) {
			dfs(v,u);
			if(p->dir==v){
				for(int i=1;i<=sz[u]*3;i++)
					for(int j=1;j<=sz[v]*3;j++)
						g[i+j]=Plus(g[i+j],1ll*f[u][i]*f[v][j]%P*i%P*inv[i+j]%P);
				sz[u]+=sz[v];
				for(int i=1;i<=sz[u]*3;i++)
					f[u][i]=g[i],g[i]=0;
			}
			else{
				for(int i=1;i<=sz[u]*3;i++)
					for(int j=1;j<=sz[v]*3;j++)
						g[i+j]=Minus(g[i+j],1ll*f[u][i]*f[v][j]%P*i%P*inv[i+j]%P);
				int sum=0;
				for(int i=1;i<=sz[v]*3;i++) sum=Plus(sum,f[v][i]);
				for(int i=1;i<=sz[u]*3;i++)
					g[i]=Plus(g[i],1ll*f[u][i]*sum%P);
				sz[u]+=sz[v];
				for(int i=1;i<=sz[u]*3;i++)
					f[u][i]=g[i],g[i]=0;
			}
		}
}

int main()
{
	n=read();
	for(int i=1;i<=n;i++){
		a[i][1]=read(); a[i][2]=read(); a[i][3]=read();
		int Inv=Pow_mod(Plus(a[i][1],Plus(a[i][2],a[i][3])),P-2);
		a[i][1]=1ll*a[i][1]*Inv%P;
		a[i][2]=1ll*a[i][2]*Inv%P;
		a[i][3]=1ll*a[i][3]*Inv%P;
	}
	int u,v;
	for(int i=1;i<n;i++){
		u=read(); v=read();
		addedge(u,v);
	}
	
	inv[1]=1;
	for(int i=2;i<=n*3;i++) inv[i]=P-1ll*(P/i)*inv[P%i]%P;
	dfs(1,0);
	
	int ans=0;
	for(int i=1;i<=n*3;i++) ans=Plus(ans,f[1][i]);
	printf("%d
",ans);
	
	return 0;
}
原文地址:https://www.cnblogs.com/lindalee/p/13222071.html