[HNOI2015]实验比较

题目描述

D 被邀请到实验室,做一个跟图片质量评价相关的主观实验。实验用到的图片集一共有 N 张图片,编号为 1 到 N。实验分若干轮进行,在每轮实验中,小 D会被要求观看某两张随机选取的图片, 然后小D 需要根据他自己主观上的判断确定这两张图片谁好谁坏,或者这两张图片质量差不多。 用符号“<”、“>”和“=”表示图片 x和y(x、y为图片编号)之间的比较:如果上下文中x 和 y 是图片编号,则 x<y 表示图片 x“质量优于”y,x>y 表示图片 x“质量差于”y,x=y表示图片 x和 y“质量相同”;也就是说,这种上下文中,“<”、“>”、“=”分别是质量优于、质量差于、质量相同的意思;在其他上下文中,这三个符号分别是小于、大于、等于的含义。图片质量比较的推理规则(在 x和y是图片编号的上下文中):(1)x < y等价于 y > x。(2)若 x < y 且y = z,则x < z。(3)若x < y且 x = z,则 z < y。(4)x=y等价于 y=x。(5)若x=y且 y=z,则x=z。 实验中,小 D 需要对一些图片对(x, y),给出 x < y 或 x = y 或 x > y 的主观判断。小D 在做完实验后, 忽然对这个基于局部比较的实验的一些全局性质产生了兴趣。在主观实验数据给定的情形下,定义这 N 张图片的一个合法质量序列为形如“x1 R1 x2 R2 x3 R3 …xN-1 RN-1 xN”的串,也可看作是集合{ xi Ri xi+1|1<=i<=N-1},其中 xi为图片编号,x1,x2,…,xN两两互不相同(即不存在重复编号),Ri为<或=,“合法”是指这个图片质量序列与任何一对主观实验给出的判断不冲突。 例如: 质量序列3 < 1 = 2 与主观判断“3 > 1,3 = 2冲突(因为质量序列中 3<1 且1=2,从而3<2,这与主观判断中的 3=2 冲突;同时质量序列中的 3<1 与主观判断中的 3>1 冲突) ,但与主观判断“2 = 1,3 < 2  不冲突;因此给定主观判断“3>1,3=2时,1<3=2 和1<2=3 都是合法的质量序列,3<1=2 和1<2<3都是非法的质量序列。由于实验已经做完一段时间了,小D 已经忘了一部分主观实验的数据。对每张图片 i,小 D 都最多只记住了某一张质量不比 i 差的另一张图片 Ki。这些小 D 仍然记得的质量判断一共有 M 条(0 <= M <= N),其中第i 条涉及的图片对为(KXi, Xi),判断要么是KXi   < Xi  ,要么是KXi = Xi,而且所有的Xi互不相同。小D 打算就以这M 条自己还记得的质量判断作为他的所有主观数据。现在,基于这些主观数据,我们希望你帮小 D 求出这 N 张图片一共有多少个不同的合法质量序列。我们规定:如果质量序列中出现“x = y”,那么序列中交换 x和y的位置后仍是同一个序列。因此: 1<2=3=4<5 和1<4=2=3<5 是同一个序列, 1 < 2 = 3 和 1 < 3 = 2 是同一个序列,而1 < 2 < 3 与1 < 2 = 3是不同的序列,1<2<3和2<1<3 是不同的序列。由于合法的图片质量序列可能很多, 所以你需要输出答案对10^9 + 7 取模的结果

题解

有一个关键条件:每个点最多有一个不比它差的点,对应到一棵树上每个节点最多只有一个父亲节点,所以这个东西本质是一颗树,我们把相同的点缩到一起,把树建出来,问题就变成了一个树上dp问题。

解法1

我们可以设dp[u][i]表示以u为根节点的子树中有i种不同的值的方案数。

转移:因为涉及到有一些元素可能会存在价值相同的情况,所以我们假设当前dp到uv子树合并时,u内不同元素有j个,v内不同元素有k个,准备合并成i个。

首先这些元素中最小的肯定在j中,而且k中不可能有和它相同的元素,所以先把那一个安排上。

然后就钦定一下这剩余的j-1个元素的位置,选择方案为C(i-1,j-1)

然后考虑v中的k个元素往那i个元素里放,此时i中还剩着i-j个空位,这些空位肯定是给那k个的,那么最后还是剩下了k-(i-j)个需要合并。

所以我们再钦定一下j中有多少个元素是用来被合并的,答案为C(j-1,k-(i-j)),可以发现这些方案都是合法的。

复杂度O(n^3)

代码

#include<iostream>
#include<cstdio>
#include<set> 
#define N 5009
using namespace std;
typedef long long ll; 
set<int>s[N]; 
ll jie[N],ni[N],g[N],dp[N][N],ans;
bool vis[N],tag;
int deep[N],size[N],tot,head[N],du[N],n,m,tott,f[N];
char opt[5];
const int mod=1e9+7;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{int n,to;}e[N];
struct node{int x,y;}b[N];
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%mod;
        x=x*x%mod;y>>=1;
    }
    return ans;
}
int find(int x){return f[x]=f[x]==x?x:find(f[x]);}
inline void add(int u,int v){
    if(s[u].find(v)!=s[u].end())return;
    e[++tot].n=head[u];e[tot].to=v;head[u]=tot;du[v]++;
    s[u].insert(v);
}
inline ll C(int n,int m){return jie[n]*ni[n-m]%mod*ni[m]%mod;}
void dfs(int u){
    size[u]=deep[u]=1;vis[u]=1;
    dp[u][1]=1;
    for(int i=head[u];i;i=e[i].n){
        int v=e[i].to;
        if(vis[v]){tag=1;continue;}
        dfs(v);
        int x=deep[u];
        deep[u]=max(deep[u],deep[v]+1);
        for(int j=deep[u];j<=size[u]+size[v];++j)
          for(int k=x;k<=size[u];++k)
            for(int l=size[v];l>=deep[v]&&l+k>=j;--l){
               (g[j]+=dp[u][k]*dp[v][l]%mod*C(j-1,k-1)%mod*C(k-1,l-(j-k))%mod)%=mod;    
               
          }
        size[u]+=size[v];
        for(int j=1;j<=size[u];++j)dp[u][j]=g[j],g[j]=0;
    }
}
int main(){
    n=rd();m=rd();
    jie[0]=1;for(int i=1;i<=n;++i)jie[i]=jie[i-1]*i%mod;ni[n]=power(jie[n],mod-2);
    for(int i=n-1;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod;
    for(int i=1;i<=n;++i)f[i]=i;
    int x,y;
    for(int i=1;i<=m;++i){
        x=rd();scanf("%s",opt);y=rd();
        if(opt[0]=='<')b[++tott]=node{x,y};
        else{
            int xx=find(x),yy=find(y);
            if(xx!=yy)f[xx]=yy;
        }
    }
    for(int i=1;i<=tott;++i){
        int x=find(b[i].x),y=find(b[i].y);
        add(x,y);
    }
    for(int i=1;i<=n;++i)if(!du[find(i)])add(0,find(i));
    dfs(0);
    for(int i=1;i<=n;++i)if(!vis[find(i)])tag=1;
    if(tag){
        puts("0");return 0;
    }
    for(int i=deep[0];i<=size[0];++i)(ans+=dp[0][i])%=mod;
    cout<<ans;
    return 0;
} 

解法2

我们仍然设dp[u][i]表示以u为根节点的子树中有i种不同的值的方案数。

我们发现上面的dp方法遇到的一个非常棘手的问题就是处理某些元素相同的情况。

然后我们加一个f[u][i]表示有不超过i种不同的值的方案数。

这个比较好dp。f[u][i]=∏f[v][i-1]。最后再把dp数组前缀和一下。

然后考虑dp[i]=∑C(i,j)f[j]

二项式反演可得f[i]=∑dp[j]*C(i,j)*-1i-j

dp和容斥复杂度都是n^2,所以总复杂度为O(n^2)

代码

#include<iostream>
#include<cstdio>
#include<set> 
#include<cstring>
#define N 5009
using namespace std;
typedef long long ll; 
set<int>s[N]; 
ll jie[N],ni[N],g[N],dp[N][N],ans;
bool vis[N],tag;
int deep[N],size[N],tot,head[N],du[N],n,m,tott,f[N];
char opt[5];
const int mod=1e9+7;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{int n,to;}e[N];
struct node{int x,y;}b[N];
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%mod;
        x=x*x%mod;y>>=1;
    }
    return ans;
}
int find(int x){return f[x]=f[x]==x?x:find(f[x]);}
inline void add(int u,int v){
    if(s[u].find(v)!=s[u].end())return;
    e[++tot].n=head[u];e[tot].to=v;head[u]=tot;du[v]++;
    s[u].insert(v);
}
inline ll C(int n,int m){return jie[n]*ni[n-m]%mod*ni[m]%mod;}
void dfs(int u){
    size[u]=deep[u]=1;vis[u]=1;
    for(int i=1;i<=n+1;++i)dp[u][i]=1;
    for(int i=head[u];i;i=e[i].n){
        int v=e[i].to;
        if(!vis[v])dfs(v);
        int x=deep[u];
        deep[u]=max(deep[u],deep[v]+1);
        size[u]+=size[v]; 
        for(int j=1;j<=n+1;++j)
            (dp[u][j]*=dp[v][j-1])%=mod;       
    }
    for(int i=1;i<=n+1;++i)(dp[u][i]+=dp[u][i-1])%=mod;
}
int main(){
    n=rd();m=rd();
    jie[0]=1;for(int i=1;i<=n+1;++i)jie[i]=jie[i-1]*i%mod;ni[n+1]=power(jie[n+1],mod-2);
    for(int i=n;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod;
    for(int i=1;i<=n;++i)f[i]=i;
    int x,y;
    for(int i=1;i<=m;++i){
        x=rd();scanf("%s",opt);y=rd();
        if(opt[0]=='<')b[++tott]=node{x,y};
        else{
            int xx=find(x),yy=find(y);
            if(xx!=yy)f[xx]=yy;
        }
    }
    for(int i=1;i<=tott;++i){
        int x=find(b[i].x),y=find(b[i].y);
        add(x,y);
    }
    for(int i=1;i<=n;++i)if(!du[find(i)])add(0,find(i));
    dfs(0);
    for(int i=1;i<=n;++i)if(!vis[find(i)])tag=1;
    if(tag){puts("0");return 0;}
    for(int i=1;i<=n+1;++i){
      for(int j=0;j<=i;++j)
        if((i-j)&1)ans=(ans-C(i,j)*dp[0][j]%mod+mod)%mod;
        else (ans+=C(i,j)*dp[0][j]%mod)%=mod;
    }
    cout<<ans;
    return 0;
} 
原文地址:https://www.cnblogs.com/ZH-comld/p/10360077.html