Educational Codeforces Round 89 (Rated for Div. 2)

Description

给一个长度为(n)(a)数组与一个长度为(m)(b)数组,求把(a)数组划分为(m)段使得对每个(i)都有第(i)段最小值为(b_i)的方案数((mod) (998244353)

Solution

(f[i])表示(a)数组中划分到第(i)位(只考虑(a_i)(b)中某元素相等的(i)),(a_i)(b_k)相等,(a_i)为第(k)段最小值。
转移是(f[i]+=f[j]*calc(j,i)),(a_j=b_{k-1}),(calc(j,i))计算的是将([j+1,i-1])中的元素分为两部分,满足前一部分属于第(k-1)段,后一部分属于第(k)段的可行方案数(仍保证第(k-1)段最小值为(a_j),第(k)段最小值为(a_i)
发现方程中的(j)位置只需要取最靠后的满足(a_j=b_{k-1})的即可,因为在把([j+1,i-1])分为两半时,若中间有位置(m)满足(a_m=a_j),由于(a_m=a_j<a_i),位置(m)必被归类于第(k-1)段,那完全可以直接用位置(m)进行转移
(好吧我承认思路有那么一点点奇怪(讲得似乎也有那么一点点奇怪),对比正解存在一定冗余。正解好像是(O(n)),我做法中离散化、预处理(ST)表都为(nlogn),计算(calc)用的是倍增或二分,计算一次是(logn)的复杂度,总复杂度(O(nlogn))。)

Code

#include <bits/stdc++.h>
 
#define Mod 998244353
 
using namespace std;
 
typedef long long ll;
 
 
inline int read() {
	int out = 0;
	bool flag = false;
	register char cc = getchar();
	while (cc < '0' || cc > '9') {
		if (cc == '-') flag = true;
		cc = getchar();
	}
	while (cc >= '0' && cc <= '9') {
		out = (out << 3) + (out << 1) + (cc ^ 48);
		cc = getchar();
	}
	return flag ? -out : out;
} 
 
inline void write(int x) {
	if (x < 0) putchar('-'), x = -x;
	if (x == 0) putchar('0');
	else {
		int num = 0;
		char cc[20];
		while (x) cc[++num] = x % 10 + 48, x /= 10;
		while (num) putchar(cc[num--]);
	}
	putchar(' ');
}
 
 
int n, m, a[200010], b[200010], c[400010], pre[200010], lst[200010], tot, Log[200010], Min[20][200010], f[200010];
 
inline int MIN(const int &l, const int &r) {
	int t = Log[r - l + 1];
	if (Min[t][l] < Min[t][r - (1 << t) + 1]) return Min[t][l];
	else return Min[t][r - (1 << t) + 1];
}
 
 
 
inline int calc(const int &l, const int &r) {
	int x = l, y = r;
	for (int i = 18; i >= 0; i--)
		if (x + (1 << i) < r && MIN(l, x + (1 << i)) >= a[l]) x += 1 << i;
	for (int i = 18; i >= 0; i--)
		if (y - (1 << i) > l && MIN(y - (1 << i), r) >= a[r]) y -= 1 << i;
	if (x < y - 1) return 0;
	return x - y + 2;
}
 
 
int main() {
	n = read(), m = read();
	for (int i = 1; i <= n; i++) c[++tot] = a[i] = read();
	for (int i = 1; i <= m; i++) c[++tot] = b[i] = read();
	sort(c + 1, c + tot + 1);
	tot = unique(c + 1, c + tot + 1) - c - 1;
	for (int i = 1; i <= n; i++) a[i] = lower_bound(c + 1, c + tot + 1, a[i]) - c;
	for (int i = 1; i <= m; i++) b[i] = lower_bound(c + 1, c + tot + 1, b[i]) - c;
	for (int i = 2; i <= n; i++) Log[i] = Log[i >> 1] + 1;
	for (int i = 1; i <= n; i++) Min[0][i] = a[i];
	for (int k = 1; (1 << k) <= n; k++) {
		for (int i = 1; i + (1 << k) - 1 <= n; i++) {
			if (Min[k - 1][i] < Min[k - 1][i + (1 << (k - 1))]) Min[k][i] = Min[k - 1][i];
			else Min[k][i] = Min[k - 1][i + (1 << (k - 1))];
		}
	}
	for (int i = 1; i <= m; i++) lst[b[i]] = b[i - 1];
	int o = INT_MAX;
	for (int i = 1; i <= n; i++) {
		o = min(o, a[i]);
		if (a[i] == b[1] && o == a[i]) f[i] = 1;
	}
	for (int i = 1; i <= n; i++) {
		//cout << lst[a[i]] << ' ' << pre[lst[a[i]]] << endl;
		if (!f[i]) f[i] = 1ll * f[pre[lst[a[i]]]] * calc(pre[lst[a[i]]], i) % Mod;
		//cout << f[i] << endl;
		pre[a[i]] = i;
	}
	int ans = 0;
	for (int i = 1; i <= n; i++) if (a[i] == b[m] && MIN(i, n) >= a[i]) {
		ans = f[i];
		//if (ans >= Mod) ans -= Mod;
	}
	cout << ans << endl;
	return 0;
} 
原文地址:https://www.cnblogs.com/Urushibara-Ruka/p/13121615.html