题解 hdu4624 Endless Spin

题解 hdu4624 Endless Spin

题目大意

题目链接

有长度为\(n\)的区间,每次随机选择一段(左右端点都是整数)染黑,问期望多少次全部染黑。

数据范围:\(n\leq 50\)

本题题解

\(n\) 个随机变量 \(a_1,\dots,a_n\)\(a_i\) 表示第一次覆盖到 \(i\) 的操作次数的期望。则我们要求的是 \(E(\max_{i=1}^{n}\{a_i\})\)

考虑 minmax 容斥:

\[\max_{x\in s}\{x\}=\sum_{t\subseteq s}(-1)^{|t|+1}\min_{x\in t}\{x\} \]

这样我们就转化为对每个点集 \(s\),求 \(E(\min_{i\in s}\{a_i\})\),也就是染黑其中至少一个点所需的操作次数的期望。

如果我们知道了只操作一次的情况下,染黑点集中至少一个点的概率 \(p\),则所需操作次数的期望就是 \(\frac{1}{p}\)。(例如掷一次骰子掷到某个数的概率是 \(\frac{1}{6}\),则期望掷 \(6\) 次可以第一次得到该数)。

这个还是不好求,我们转而求:操作一次,\(s\) 中的点一个都没染黑的概率 \(p'\),则 \(p=1-p'\), \(E(\min_{i\in s}\{a_i\})=\frac{1}{1-p'}\)

考虑如果暴力枚举一个子集 \(s\)。则整个数列被 \(s\) 内的点划分成若干个区间,设长度分别为:\(l_1,l_2,...,l_k\)。则\(p'=\frac{\sum_{i=1}^{k}\frac{1}{2}l_i(l_i+1)}{\frac{1}{2}n(n+1)}\)。复杂度 \(\mathcal{O}(2^nn)\),无法承受。

考虑 DP。设 \(\mathrm{dp}(i,j,k,0/1)\) 表示考虑了前 \(i\) 个位置;有 \(j\) 个不包含【点集的里点】的区间 (\(j\leq \frac{1}{2}n(n+1)\));上一个【点集里的点】距离 \(i\)\(k\);点集大小的奇偶性为 \(0/1\)。这样选出区间的方案数。

转移时考虑第 \(i+1\) 个位置是否加入点集:

  • 如果加入点集:\(\mathrm{dp}(i,j,k,0/1)\to \mathrm{dp}(i+1,j,0,1/0)\)

  • 如果不加入点集:\(\mathrm{dp}(i,j,k,0/1)\to \mathrm{dp}(i+1,j+k+1,k+1,0/1)\)

其中 \(\mathrm{dp}(a)\to \mathrm{dp}(b)\) 表示从状态 \(a\) 转移到状态 \(b\),也就是 \(\mathrm{dp}(b)\texttt{+=}\mathrm{dp}(a)\)。转移是 \(\mathcal{O}(1)\) 的,所以 DP 的复杂度就是状态数,也就是 \(\mathcal{O}(n^4)\)

统计答案时把所有 \(j\) 的情况加起来即可。即:\(\mathrm{ans}(n)=\displaystyle\sum_{j=0}^{\frac{1}{2}n(n+1)-1}\frac{\mathrm{dp}(n,j,k,0/1)\times(-1)^{1/0}}{1-\frac{j}{\frac{1}{2}n(n+1)}}\)

备注:具体实现的时候把分数上下同时乘以 \(\frac{1}{2}n(n+1)\) 会更好写。式子上面的 \(k\) 表示所有 \(k\) 的情况的和。\((-1)\) 的指数上的 \(1/0\) 和状态里的 \(0/1\) 是相反的,因为 minmax 容斥的公式本来就是 \((-1)^{|t|+1}\)

注意统计答案时要使用高精度。

参考代码

// problem: HDU4624
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // dysyn1314

namespace Bigdouble{
	const int K=50;
	typedef long long ll;
	struct db{ll zs,xs[K+5];db(){zs=0;memset(xs,0,sizeof(xs));}};
	db makedb(ll fz,ll fm){
		db res;
		res.zs=fz/fm,fz%=fm,fz*=10;
		for(int i=1;i<=K;++i)res.xs[i]=fz/fm,fz%=fm,fz*=10;
		return res;
	}
	db operator + (db a,db b){
		db res;ll jw=0;
		for(int i=K;i>=1;--i)res.xs[i]=a.xs[i]+b.xs[i]+jw,jw=res.xs[i]/10,res.xs[i]%=10;
		res.zs=a.zs+b.zs+jw;
		return res;
	}
	db operator - (db a,db b){
		db res;
		for(int i=K;i>=2;--i){
			if(a.xs[i]<b.xs[i])a.xs[i-1]--,a.xs[i]+=10;
			res.xs[i]=a.xs[i]-b.xs[i];
		}
		if(a.xs[1]<b.xs[1])a.zs--,a.xs[1]+=10;
		res.xs[1]=a.xs[1]-b.xs[1];
		res.zs=a.zs-b.zs;
		return res;
	}
	db operator * (db a,ll b){
		db res;
		ll jw=0;
		for(int i=K;i>=1;--i)res.xs[i]=a.xs[i]*b+jw,jw=res.xs[i]/10,res.xs[i]%=10;
		res.zs=a.zs*b+jw;
		return res;
	}
	void printdb(db a,int k=15){
		if(a.xs[k+1]>=5)a.xs[k]++;
		int t=k;
		while(a.xs[t]>=10){
			a.xs[t]-=10;
			if(t!=1)a.xs[--t]++;
			else{a.zs++;break;}
		}
		cout<<a.zs<<".";
		for(int i=1;i<=k;++i)cout<<a.xs[i];
	}
}
using namespace Bigdouble;
const int MAXN=55;
ll dp[MAXN][MAXN*MAXN][MAXN][2];
db ans[MAXN];
int main() {
	dp[0][0][0][0]=1;
	for(int i=0;i<50;++i){
		for(int j=0;j<=i*(i+1)/2;++j){
			for(int k=0;k<=i;++k){
				for(int t=0;t<=1;++t){
					dp[i+1][j][0][t^1]+=dp[i][j][k][t];
					dp[i+1][j+k+1][k+1][t]+=dp[i][j][k][t];
				}
			}
		}
	}
	for(int n=1;n<=50;++n){
		for(int j=0;j<n*(n+1)/2;++j){
			db tmp=makedb(n*(n+1)/2,n*(n+1)/2-j);
			ll sum=0;
			for(int k=0;k<=n;++k)sum+=dp[n][j][k][0];
			ans[n]=ans[n]-(tmp*sum);
			sum=0;
			for(int k=0;k<=n;++k)sum+=dp[n][j][k][1];
			ans[n]=ans[n]+(tmp*sum);
		}
	}
	//for(int n=1;n<=50;++n)printf("%d\n",n),printdb(ans[n]),puts("");return 0;
	int t=read();while(t--){
		int n=read();
		printdb(ans[n]);puts("");
	}
	return 0;
}
原文地址:https://www.cnblogs.com/dysyn1314/p/12371810.html