题解-HEOI2013 SAO

容易发现这个图形成一个树形结构,虽然边是有向的,但我们不妨把他看成无向的然后跑树形dp。

(f[u,i]) 代表在 (u) 的子树内 (u)(i) 名。然后转移考虑合并子树。枚举 (i,j,k) 表示原来 (u)(i) 名,(v)(j) 名,然后 (u) 在加入了 (v) 后变成了 (k) 名。于是就是从 (k-1) 名中选 (i-1) 名,后面同理,有转移 (f'[u,k]=sum_isum_j {k-1choose i-1}{sz_u+sz_v-kchoose sz_u-i}f[u,i]f[v,j])。考虑 (k) 的范围,这里就会有两种情况;

  1. (v) 一定要在 (u) 后,这时显然只需 (kle i-1+j) 即可。

  2. (v) 一定要在 (u) 前,这时只需 (kge i+j) 即可。

然后直接硬转移是 (O(n^3)) 的,其实发现 (j) 这一部分是独立的,直接前缀和优化掉即可。

#include<bits/stdc++.h>
#define mp make_pair
#define pb push_back
typedef long long ll;
using namespace std;
const int maxn=1005;
const ll mod=1e9+7;
template<typename T>
void read(T &x){
	T flag=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar())if(ch=='-')flag=-1;
	for(x=0;isdigit(ch);ch=getchar())x=x*10+ch-'0';
	x*=flag;
}
int T,n;
ll f[maxn][maxn],sum[maxn][maxn],tmp[maxn];
string s;
int sz[maxn];
vector<pair<int,int>>vec[maxn];
ll fac[maxn],ifac[maxn];
ll ksm(ll a,ll b){
	ll ret=1;
	for(;b;b>>=1,a=a*a%mod)if(b&1)ret=ret*a%mod;
	return ret;
}
ll C(int n,int m){
	if(n<0||m<0||n<m)return 0;
	return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void dfs(int u,int fa){
	sz[u]=1;
	f[u][1]=1;
	for(auto x:vec[u]){
		int v=x.first;
		if(v==fa)continue;
		dfs(v,u);
		for(int i=1;i<=sz[u];i++)tmp[i]=f[u][i],f[u][i]=0;
		if(x.second==0){
			for(int i=1;i<=sz[u];i++){
				for(int k=i;k<sz[v]+i;k++){
					f[u][k]=(f[u][k]+tmp[i]*C(k-1,i-1)%mod*C(sz[u]+sz[v]-k,sz[u]-i)%mod*((sum[v][sz[v]]-sum[v][k-i]+mod)%mod)%mod)%mod;
				}
			}
		}else{
			for(int i=1;i<=sz[u];i++){
				for(int k=sz[u]+sz[v];k>i;k--){
					f[u][k]=(f[u][k]+tmp[i]*C(k-1,i-1)%mod*C(sz[u]+sz[v]-k,sz[u]-i)%mod*((sum[v][k-i]))%mod)%mod;
				}
			}
		}
		sz[u]+=sz[v];
	}
	for(int i=1;i<=sz[u];i++)sum[u][i]=(sum[u][i-1]+f[u][i])%mod;
}
void solve(){
	memset(sum,0,sizeof(sum));
	memset(f,0,sizeof(f));
	read(n);
	for(int i=1;i<=n;i++)vec[i].clear();
	for(int i=1,u,v;i<n;i++){
		read(u);cin>>s;read(v);
		u++;v++;
		if(s=="<")vec[u].pb(mp(v,0)),vec[v].pb(mp(u,1));
		else vec[u].pb(mp(v,1)),vec[v].pb(mp(u,0));
	}
	dfs(1,0);
	printf("%lld
",sum[1][n]);
}
int main(){
	fac[0]=ifac[0]=1;
	for(int i=1;i<=1000;i++)fac[i]=fac[i-1]*i%mod;
	ifac[1000]=ksm(fac[1000],mod-2);
	for(int i=999;i>=1;i--)ifac[i]=ifac[i+1]*(i+1)%mod;
	read(T);
	while(T--)solve();
	return 0;
}
原文地址:https://www.cnblogs.com/zcr-blog/p/15032279.html