bzoj 5006(洛谷 4547) [THUWC2017]Bipartite 随机二分图——期望DP

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=5006

   https://www.luogu.org/problemnew/show/P4547

算一种可行方案,只要确定出 n 条边即可;概率就是这 n 条边存在的概率,其他边视作无要求,概率贡献都是1;这样的话,一种方案对答案的贡献就是其概率。

考虑把第二组边和第三组边分成概率分别为 1/2 的两条独立的边。对于第二组边再加一条能把4个点都连起来的 1/4 的边,对于第三组边再加一条能把4个点都连起来的 -1/4 的边。

因为算一个方案的概率的时候只看选中的边的概率乘积,所以上述方案可以让概率计算正确。

可以设 dp[ s0 ][ s1 ] 表示左部的点集 s0 和右部的点集 s1 匹配的期望方案数。然后记忆化搜索。(也可以刷表,还会快,但不太会写)

为了避免方案数因为加边顺序而算重,可以规定一个顺序,比如转移到 (s0,s1) 的状态 (d0,d1) 的 d0 一定不含 s0 的 lowbit 之类的。

于是枚举 d0 , d1 ,但会TLE。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#define mkp make_pair
#define pii pair<int,int>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=20,M=345,mod=1e9+7;
int n,m,bin[N],dy[(1<<15)+5],iv2,iv4; bool s[N][N];
map<pii,int> s2,mp;
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
void upd(int &x){x>=mod?x-=mod:0;}
int dfs(int s0,int s1)
{
  pii S=mkp(s0,s1);if(mp.count(S))return mp[S];
  int lbt=(s0&-s0);
  int x=dy[lbt],y,d0=s0^(lbt),p0=s1,d1;
  while(p0)//every bit of s1
    {
      lbt=(p0&-p0);
      y=dy[lbt]; d1=s1^lbt; p0^=lbt;
      if(s[x][y])
    mp[S]=(mp[S]+(ll)iv2*dfs(d0,d1))%mod;
    }
  p0=d0; x=(s0&-s0);
  while(p0)//every bit of (s0-lowbit)
    {
      lbt=(p0&-p0); d0=(x|lbt); p0^=lbt;//d0:two bits of s0
      int p1=s1;//every bit of s1
      while(p1)
    {
      lbt=(p1&-p1); y=lbt; p1^=lbt;
      int p2=p1;//d1:every bit of p1
      while(p2)
        {
          lbt=(p2&-p2); d1=(y|lbt); p2^=lbt;
          pii d=mkp(d0,d1);
          if(s2.count(d))
        mp[S]=(mp[S]+(ll)iv4*dfs(s0^d0,s1^d1)*s2[d])%mod+mod,upd(mp[S]);
        }
    }
    }
  return mp[S];
}
int main()
{
  int m,t,x1,y1,x2,y2;
  n=rdn();m=rdn();
  bin[0]=1;dy[1]=0;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1,dy[bin[i]]=i;
  iv2=pw(2,mod-2); iv4=pw(4,mod-2); mp[mkp(0,0)]=bin[n];
  for(int i=1;i<=m;i++)
    {
      t=rdn();x1=rdn()-1;y1=rdn()-1;
      if(t)x2=rdn()-1,y2=rdn()-1;
      s[x1][y1]=1; if(!t)continue;
      s[x2][y2]=1;
      s2[mkp(bin[x1]|bin[x2],bin[y1]|bin[y2])]=(t==1?1:-1);
    }
  printf("%d
",dfs(bin[n]-1,bin[n]-1));
  return 0;
}
View Code

于是改成枚举每条边,并且把两个点集压进一个数而不是两个数里。但还是TLE。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#define mkp make_pair
#define pii pair<int,int>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=20,M=345,mod=1e9+7;
int n,m,bin[N],iv2,iv4;
struct Ed{
  int x,y,w;
  Ed(int x=0,int y=0,int w=0):x(x),y(y),w(w) {}
}ed[M];
map<pii,int> mp;
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
void upd(int &x){x>=mod?x-=mod:0;}
int dfs(int s0,int s1)
{
  pii S=mkp(s0,s1);if(mp.count(S))return mp[S];
  int lbt=(s0&-s0);
  for(int i=1;i<=m;i++)
    if((ed[i].x|s0)==s0&&(ed[i].y|s1)==s1&&(ed[i].x&lbt))
      mp[S]=(mp[S]+(ll)ed[i].w*dfs(s0^ed[i].x,s1^ed[i].y))%mod;
  return mp[S];
}
int main()
{
  int tp,t,x1,y1,x2,y2;
  n=rdn();tp=rdn();
  bin[0]=1;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
  iv2=pw(2,mod-2); iv4=pw(4,mod-2); mp[mkp(0,0)]=bin[n];
  for(int i=1;i<=tp;i++)
    {
      t=rdn();x1=bin[rdn()-1];y1=bin[rdn()-1];
      ed[++m]=Ed(x1,y1,iv2); if(!t)continue;
      x2=bin[rdn()-1];y2=bin[rdn()-1];
      ed[++m]=Ed(x2,y2,iv2);
      if((x1&x2)||(y1&y2))continue;///
      ed[++m]=Ed(x1|x2,y1|y2,t==1?iv4:mod-iv4);
    }
  printf("%d
",dfs(bin[n]-1,bin[n]-1));
  return 0;
}
View Code

主要是 map 的大小。在 mp[ s0 ] 里找有没有 s1 比在整个 (s0,s1) 里找有没有 (s0,s1) 快。

并且不要写很多 mp[ s0 ][ s1 ] ,可以用一个临时变量 ret 之类的代替。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=20,M=345,mod=1e9+7;
int n,m,bin[N<<1],iv2,iv4,base;
struct Ed{
  int s,w;
  Ed(int s=0,int w=0):s(s),w(w) {}
}ed[M];
map<int,int> mp[(1<<15)+5];
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
int dfs(int S)
{
  int s0=S>>n,s1=S&base;
  if(mp[s0].count(s1))return mp[s0][s1];
  int ret=0;
  for(int i=1;i<=m;i++)
    if((ed[i].s|S)==S&&(ed[i].s<<1)>S)
      ret=(ret+(ll)ed[i].w*dfs(S^ed[i].s))%mod;
  return mp[s0][s1]=ret;
}
int main()
{
  int tp,t,s1,s2;
  n=rdn();tp=rdn();
  bin[0]=1;for(int i=1,j=n<<1;i<=j;i++)bin[i]=bin[i-1]<<1;
  iv2=pw(2,mod-2); iv4=pw(4,mod-2); base=bin[n]-1; mp[0][0]=bin[n];
  for(int i=1;i<=tp;i++)
    {
      t=rdn();s1=bin[rdn()-1+n]|bin[rdn()-1];
      ed[++m]=Ed(s1,iv2); if(!t)continue;
      s2=bin[rdn()-1+n]|bin[rdn()-1];
      ed[++m]=Ed(s2,iv2);
      if(s1&s2)continue;///
      ed[++m]=Ed(s1|s2,t==1?iv4:mod-iv4);
    }
  printf("%d
",dfs(bin[n<<1]-1));
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10255966.html