loj2538 「PKUWC2018」Slay the Spire 【dp】

题目链接

loj2538

题解

比较明显的是,由于强化牌倍数大于(1),肯定是能用强化牌尽量用强化牌
如果强化牌大于等于(k),就留一个位给攻击牌

所以我们将两种牌分别排序,企图计算(F(i,j))表示(i)张强化牌选出最强的(j)张的所有方案的倍数和
(G(i,j))表示从(i)张攻击牌选出最强(j)张的所有方案的伤害和

那么

[ans = sumlimits_{i = 0}^{k - 1} F(i,i)G(m - i,k - i) + sumlimits_{i = k}^{m} F(i,k - 1)G(m - i,1) ]

所以我们只需计算出(F)(G)
(F)为例,我们枚举选出最后一张牌是什么
那么设(f[i][j])表示用了(i)张强化牌,最后一张是(j)的倍数和
同样设(g[i][j])表示用了(i)张攻击牌,最后一张是(j)的伤害和
那么有

[f[i][j] = w_jsumlimits_{x = 0}^{j - 1}f[i - 1][x] ]

[g[i][j] = w_j{j - 1 choose i - 1} + sumlimits_{x = 0}^{j - 1}g[i - 1][x] ]

可用前缀和优化为(O(n^2))

那么

[F(x,y) = sumlimits_{i = 0}^{n}f[y][i]{n - i choose x - y} ]

[G(x,y) = sumlimits_{i = 0}^{n}g[y][i]{n - i choose x - y} ]

总复杂度(O(Tn^2))

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define mp(a,b) make_pair<int,int>(a,b)
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,int>
#define LL long long int
using namespace std;
const int maxn = 3005,maxm = 100005,INF = 1000000000,P = 998244353;
inline int read(){
	int out = 0,flag = 1; char c = getchar();
	while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
	while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
	return out * flag;
}
int n,m,K,fac[maxn],fv[maxn],inv[maxn];
int w1[maxn],w2[maxn];
int f[maxn][maxn],g[maxn][maxn];
int s[maxn];
void init(){
	fac[0] = fac[1] = inv[0] = inv[1] = fv[0] = fv[1] = 1;
	for (int i = 2; i <= 3000; i++){
		fac[i] = 1ll * fac[i - 1] * i % P;
		inv[i] = 1ll * (P - P / i) * inv[P % i] % P;
		fv[i] = 1ll * fv[i - 1] * inv[i] % P;
	}
}
inline int C(int n,int m){
	if (n < m) return 0;
	return 1ll * fac[n] * fv[m] % P * fv[n - m] % P;
}
inline int F(int x,int y){
	if (x > n || x < y) return 0;
	int re = 0;
	for (int i = 0; i <= n; i++)
		re = (re + 1ll * f[y][i] * C(n - i,x - y) % P) % P;
	return re;
}
inline int G(int x,int y){
	if (x > n || x < y) return 0;
	int re = 0;
	for (int i = 0; i <= n; i++)
		re = (re + 1ll * g[y][i] * C(n - i,x - y) % P) % P;
	return re;
}
inline bool cmp(const int& a,const int& b){
	return a > b;
}
void work(){
	sort(w1 + 1,w1 + 1 + n,cmp);
	sort(w2 + 1,w2 + 1 + n,cmp);
	for (int i = 0; i <= n; i++)
		for (int j = 0; j <= n; j++)
			f[i][j] = g[i][j] = 0;
	f[0][0] = 1;
	s[0] = 1;
	for (int i = 1; i <= n; i++) s[i] = s[i - 1];
	for (int i = 1; i <= n; i++){
		for (int j = i; j <= n; j++){
			f[i][j] = 1ll * w1[j] * s[j - 1] % P;
		}
		for (int j = 0; j < i; j++) s[j] = 0;
		for (int j = i; j <= n; j++) s[j] = (s[j - 1] + f[i][j]) % P;
	}
	s[0] = 0;
	for (int i = 1; i <= n; i++) s[i] = s[i - 1];
	for (int i = 1; i <= n; i++){
		for (int j = i; j <= n; j++){
			g[i][j] = (1ll * w2[j] * C(j - 1,i - 1) % P + s[j - 1]) % P;
		}
		for (int j = 0; j < i; j++) s[j] = 0;
		for (int j = i; j <= n; j++) s[j] = (s[j - 1] + g[i][j]) % P;
	}
	int ans = 0;
	for (int i = 0; i <= min(n,m); i++){
		int j = m - i; if (j < 0 || j > n) continue;
		if (i < K){
			ans = (ans + 1ll * F(i,i) * G(j,K - i) % P) % P;
		}
		else{
			ans = (ans + 1ll * F(i,K - 1) * G(j,1) % P) % P;
		}
	}
	printf("%d
",ans);
}
int main(){
	init();
	int T = read();
	while (T--){
		n = read(); m = read(); K = read();
		REP(i,n) w1[i] = read();
		REP(i,n) w2[i] = read();
		work();
	}
	return 0;
}

原文地址:https://www.cnblogs.com/Mychael/p/9216259.html