题目链接:https://codeforces.ml/gym/102832/problem/B
官方题解:https://zhuanlan.zhihu.com/p/279287505
参考了用户 smallling 的提交。
题目大意
给定 $n$ 个数,数的范围为 $[1,m)$,共 $d$ 次操作,操作分为两种:

1. 给定 $L,R,k$:每秒在区间 $[L,R]$ 中选定 $k$ 个数进行诅咒,被诅咒的数不变,未被诅咒的数 $+1$,问最多能维持多少秒使所有数都 $< m$。
2. 将第 $u$ 个数变为 $v$。
思路1:树套树
对于第一种操作,显然每秒都应诅咒当前最大的 $k$ 个数。假设最多能维持 $t$ 秒,那么条件可以写成 $\sum_{i=L}^{R} \max(t+a_{i}-(m-1),\,0) \leq k \cdot t$,其中 $t+a_{i}$ 是不被诅咒时维持 $t$ 秒后的大小,减去 $m-1$ 代表需要对这个数进行诅咒的次数,要保证总诅咒次数 $\leq k \cdot t$。
记 $x_{i} = (m-1)-a_{i}$,则原先的式子转化为 $\sum_{i=L}^{R} (t-x_{i}) \cdot [t \geq x_{i}] \leq k \cdot t$。
假设在 $[L,R]$ 中有 $y$ 个数满足 $x_{i} \leq t$,那么式子可以转化为 $y \cdot t - \sum_{x_{i} \leq t} x_{i} \leq k \cdot t$。
移项可得 $(y-k) \cdot t \leq \sum_{x_{i} \leq t} x_{i}$。
若将 $x_{i}$ 的值离散化排序后建权值线段树,可以同时求出 $\sum_{x_{i} \leq t} x_{i}$ 和 $y$;在权值线段树的每个节点上再建一棵按位置排序的平衡树,就能维护位置在 $[L,R]$ 内的条件。
现在在权值线段树上二分,判断是否要往左儿子走。
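如果要取整个左儿子,那么此时 $t \geq x_{mid}$($x_{mid}$ 为左儿子对应的最大取值)。判断 $(y-k) \cdot x_{mid} \leq \sum_{x_{i} \leq x_{mid}} x_{i}$ 是否成立:若成立则往右边走,否则往左边走。若 $y < k$,不等式左边为负,必然成立。
最终停在某个叶子时,$y$ 与 $\sum$ 只统计位于该叶子左侧的取值。对式子 $(y-k) \cdot t \leq \sum_{x_{i} \leq t} x_{i}$ 移项,可得 $t \leq \frac{\sum_{x_{i} \leq t} x_{i}}{y-k}$(需要 $y > k$),向下取整即为答案。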
note:在建立权值线段树时可以在右端新建一个权值为无穷大的节点保证之前的节点都能被考虑到,体现在代码中为 tree.build(1, 1, Discrete::blen + 1);
AC代码1:树套树
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define llinf 0x3f3f3f3f3f3f3f3f
#define mp make_pair
#define pii pair<int, int>
#define vi vector<int>
#define fi first
#define se second
#define pb push_back
#define SZ(x) (int)x.size()
#define ull unsigned long long
#define pll pair<ll, ll>
using namespace std;
const int MAXN = 2e5 + 5;
// Coordinate compression: collect raw values with insert(), call bi() once to
// sort and de-duplicate them, then translate between values and 1-based ids.
namespace Discrete {
// b[0..btol) holds the collected values; after bi(), b[0..blen) holds the
// distinct values in increasing order.
int b[MAXN], btol, blen;
// Queue value x for compression.
void insert(int x) {
    b[btol] = x;
    ++btol;
}
// Sort the collected values and keep exactly one copy of each.
void bi() {
    int *const first = b;
    int *const last = b + btol;
    std::sort(first, last);
    int *const new_end = std::unique(first, last);
    blen = static_cast<int>(new_end - first);
}
// 1-based id of the smallest distinct value >= x (blen + 1 when every value is < x).
int val2id(int x) {
    const int *const hit = std::lower_bound(b, b + blen, x);
    return static_cast<int>(hit - b) + 1;
}
// Distinct value whose 1-based id is x (valid for 1 <= x <= blen).
int id2val(int x) {
    const int offset = x - 1;
    return b[offset];
}
}
using Discrete::val2id;
using Discrete::id2val;
// Segment tree over compressed value ids (main builds it on [1, blen + 1]); every
// segment-tree node owns a treap keyed by array position, so a node can report how
// many of its values sit at positions [l, r] and what they sum to.
class DS {
public:
// Treap begin
// Node storage indexed by node id (0 = null): ch = children, dat = random heap
// priority, siz = subtree size, pos = key (array position), val = stored value,
// sum = subtree sum of val.
int ch[MAXN * 20][2], dat[MAXN* 20], siz[MAXN* 20], pos[MAXN* 20];
ll val[MAXN* 20], sum[MAXN* 20];
// Highest node id handed out so far.
int tot;
// Stack of recycled node ids (1-based, pool_cnt entries), filled by Delid.
int pool[MAXN* 20], pool_cnt;
// Reset the allocator counters (node arrays are not cleared here).
void init() {
tot = 0, pool_cnt = 0;
}
// Allocate a node id, preferring a recycled one over a fresh one.
inline int Newid() {
return pool_cnt ? pool[pool_cnt--] : ++tot;
}
// Recycle node rt (only this node, not its children), clear its fields and
// null out the caller's handle.
inline void Delid(int &rt) {
if (!rt) return;
pool[++pool_cnt] = rt;
dat[rt] = siz[rt] = pos[rt] = val[rt] = sum[rt] = 0;
ch[rt][0] = ch[rt][1] = 0;
rt = 0;
}
// Create a detached node with key v1 (position) and value v2.
inline int New_treapnode(int v1, ll v2) {
int nrt = Newid();
ch[nrt][0] = ch[nrt][1] = 0, dat[nrt] = rand(), siz[nrt] = 1, pos[nrt] = v1, val[nrt] = sum[nrt] = v2;
return nrt;
}
// Recompute siz and sum of rt from its children.
inline void push_up(int rt) {
siz[rt] = siz[ch[rt][0]] + siz[ch[rt][1]] + 1;
sum[rt] = sum[ch[rt][0]] + sum[ch[rt][1]] + val[rt];
}
// Split treap rt into x (keys <= vp) and y (keys > vp).
void split(int rt, int vp, int &x, int &y) {
if (!rt) x = y = 0;
else {
if (pos[rt] <= vp) {
x = rt;
split(ch[rt][1], vp, ch[rt][1], y);
} else {
y = rt;
split(ch[rt][0], vp, x, ch[rt][0]);
}
push_up(rt);
}
}
// Merge treaps x and y, assuming every key in x is <= every key in y; the node
// with the smaller priority becomes the root.
int merge(int x, int y) {
if (!x || !y) return x + y;
if (dat[x] < dat[y]) {
ch[x][1] = merge(ch[x][1], y);
push_up(x);
return x;
} else {
ch[y][0] = merge(x, ch[y][0]);
push_up(y);
return y;
}
}
// Treap end
// Segment-tree node: value-id range [l, r] plus the root of its position treap
// (0 = empty; relies on zero-initialisation of the global `tree` object).
struct node {
int l, r;
int root;
} T[MAXN << 2];
// Record the value-id range of every segment-tree node; treap roots are untouched.
void build(int rt, int l, int r) {
T[rt].l = l, T[rt].r = r;
if (l == r) {
return;
}
int mid = (l + r) >> 1;
build(rt << 1, l, mid), build(rt << 1 | 1, mid + 1, r);
}
// Insert (position vpos, value v) into the treap of every segment-tree node on
// the path from rt down to leaf seg_pos.
void Seg_insert(int rt, int seg_pos, int vpos, ll v) {
if (T[rt].root == 0) {
T[rt].root = New_treapnode(vpos, v);
} else {
int x, y;
split(T[rt].root, vpos, x, y);
int nz = New_treapnode(vpos, v);
T[rt].root = merge(x, merge(nz, y));
}
if (T[rt].l == T[rt].r) return;
int mid = (T[rt].l + T[rt].r) >> 1;
if (seg_pos <= mid) Seg_insert(rt << 1, seg_pos, vpos, v);
else Seg_insert(rt << 1 | 1, seg_pos, vpos, v);
}
// Remove the entry keyed vpos from the treap of every segment-tree node on the
// path to leaf seg_pos.
// NOTE(review): Delid frees only the root of the extracted subtree y, so this
// assumes each position occurs at most once per treap (main erases before re-inserting).
void Seg_erase(int rt, int seg_pos, int vpos) {
int x, y, z, trt;
split(T[rt].root, vpos - 1, x, trt);
split(trt, vpos, y, z);
Delid(y);
T[rt].root = merge(x, z);
if (T[rt].l == T[rt].r) return;
int mid = (T[rt].l + T[rt].r) >> 1;
if (seg_pos <= mid) Seg_erase(rt << 1, seg_pos, vpos);
else Seg_erase(rt << 1 | 1, seg_pos, vpos);
}
// Largest t with (y - k) * t <= sum_{x <= t} x over the values stored at
// positions [l, r] (see the derivation in the write-up), found by walking down
// the value segment tree.
ll query(int l, int r, int k) {
// (count, sum) of entries at positions [l, r] in segment node rt; the treap is
// split and merged back, so it keeps the same keys afterwards.
auto query_node = [&](int rt) {
int x, y, z, trt;
split(T[rt].root, l - 1, x, trt);
split(trt, r, y, z);
pair<ll, ll> ans = mp(siz[y], sum[y]);
T[rt].root = merge(x, merge(y, z));
return ans;
};
// Count and sum of the values in the left subtrees taken so far.
ll nk = 0, nsum = 0;
function<void(int)> qs = [&](int rt) {
if (T[rt].l == T[rt].r) {
return;
}
int mid = (T[rt].l + T[rt].r) >> 1;
pair<ll, ll> inforL = query_node(rt << 1);
// Take the whole left child when fewer than k values are taken (y - k < 0, so the
// inequality holds) or when it still holds at t = id2val(mid), the largest value
// in the left child.
if ((nk + inforL.first) < k ||
(nk + inforL.first - k) * id2val(mid) <= nsum + inforL.second) {
nk += inforL.first, nsum += inforL.second;
qs(rt << 1 | 1);
} else {
qs(rt << 1);
}
};
qs(1);
// nk / nsum now cover every value id strictly left of the leaf reached.
// NOTE(review): if k >= r - l + 1 the walk can end at the sentinel leaf with
// nk <= k (division by zero or a negative divisor); presumably the problem
// constraints exclude this — confirm.
return nsum / (nk - k);
}
} tree;
// qs[i] is the i-th operation: qs[i][0] is its type (1 = query, 2 = update);
// a query stores l, r, k in qs[i][1..3], an update stores position, new value in qs[i][1..2].
// a[i] is the current value of the i-th number.
int qs[MAXN][4], a[MAXN];
int main() {
    tree.init();
    int n, m, d;
    scanf("%d%d%d", &n, &m, &d);
    // Read everything first (offline) so that every value (m - 1) - a that can
    // ever appear is known before coordinate compression.
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]), Discrete::insert((m - 1) - a[i]);
    for (int i = 1; i <= d; i++) {
        int opt;
        scanf("%d", &opt);
        if (opt == 1) {
            qs[i][0] = 1;
            scanf("%d%d%d", &qs[i][1], &qs[i][2], &qs[i][3]);
        } else {
            qs[i][0] = 2;
            scanf("%d%d", &qs[i][1], &qs[i][2]);
            Discrete::insert((m - 1) - qs[i][2]);
        }
    }
    Discrete::bi();
    // The extra leaf blen + 1 is an "infinite" sentinel so that the descent in
    // DS::query can move past every real value.
    tree.build(1, 1, Discrete::blen + 1);
    for (int i = 1; i <= n; i++) {
        int po = val2id((m - 1) - a[i]);
        tree.Seg_insert(1, po, i, (m - 1) - a[i]);
    }
    for (int i = 1; i <= d; i++) {
        if (qs[i][0] == 1) {
            printf("%lld\n", tree.query(qs[i][1], qs[i][2], qs[i][3]));
        } else {
            // Point update: remove the old (position, value) entry, then insert the new one.
            const int p = qs[i][1];
            tree.Seg_erase(1, val2id((m - 1) - a[p]), p);
            a[p] = qs[i][2];
            tree.Seg_insert(1, val2id((m - 1) - a[p]), p, (m - 1) - a[p]);
        }
    }
}
思路2:整体二分
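同上面的式子,对答案 $t$ 做整体二分。在转向右半区间时,已经确定 $x_{i} \leq mid$ 的元素一定小于右半区间中的任意 $t$,因此记录这些元素的个数与 $x_{i}$ 之和,之后直接累加到判断式中。
注意判断时 $y \cdot t$ 可能超出 long long 的范围($t$ 最大可达 $10^{15}$),需要使用 __int128。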
树套树跑了4695ms,整体二分1029ms。
AC代码2:整体二分
#include <bits/stdc++.h>
#define ll long long
#define vi vector<int>
#define mp make_pair
#define inf 0x3f3f3f3f
#define llinf 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define SZ(x) (int)x.size()
#define pb push_back
#define fi first
#define se second
using namespace std;
const int MAXN = 1e5 + 5;
// Fenwick tree (binary indexed tree): point add plus prefix / range sums over
// positions 1..n.
class BIT {
public:
    // val[i] holds the sum of positions (i - lowbit(i), i]; n is the active size.
    ll val[MAXN], n;
    // Set the active size to _n and zero positions 1.._n.
    void init(int _n) {
        n = _n;
        for (int p = _n; p > 0; --p) val[p] = 0;
    }
    // Lowest set bit of x.
    inline int lowbit(int x) {
        return x & -x;
    }
    // Add x at position pos.
    void add(int pos, ll x) {
        while (pos <= n) {
            val[pos] += x;
            pos += lowbit(pos);
        }
    }
    // Sum of positions 1..pos (0 when pos <= 0).
    ll query(int pos) {
        ll acc = 0;
        for (; pos > 0; pos -= lowbit(pos)) acc += val[pos];
        return acc;
    }
    // Sum of positions l..r.
    ll query(int l, int r) {
        const ll prefix_r = query(r);
        const ll prefix_before_l = query(l - 1);
        return prefix_r - prefix_before_l;
    }
} bit_num, bit_sum;
// One event of the parallel binary search. Field order matters: main builds these
// with brace initialisers in the order {l, r, k, id, type, pre_num, pre_sum}.
struct Query {
    int l, r;    // position range (l == r for insert / remove events)
    int k;       // query: k; insert / remove: the element's value x = (m - 1) - a
    int id;      // query: index of the operation in the input (0 for other events)
    int type;    // 1 = insert element, 2 = query, 3 = remove element
    ll pre_num;  // query only: count of elements already known to lie below the search range
    ll pre_sum;  // query only: sum of x over those elements
};
Query q[MAXN * 3], q1[MAXN * 3], q2[MAXN * 3];
// a: current values; res: per-query search result indexed by operation id.
ll a[MAXN], res[MAXN];
// ty[i] is 1 exactly when operation i is a query.
int ty[MAXN];
int main() {
int n, d;
ll m;
scanf("%d%lld%d", &n, &m, &d);
int qcnt = 0;
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
for (int i = 1; i <= n; i++) {
q[++qcnt] = {i, i, (int) ((m - 1) - a[i]), 0, 1, 0, 0};
}
for (int i = 1; i <= d; i++) {
int opt;
scanf("%d", &opt);
if (opt == 1) {
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
q[++qcnt] = {l, r, k, i, 2, 0, 0};
ty[i] = 1;
} else {
int x;
ll y;
scanf("%d%lld", &x, &y);
q[++qcnt] = {x, x, (int) ((m - 1) - a[x]), 0, 3, 0, 0};
a[x] = y;
q[++qcnt] = {x, x, (int) ((m - 1) - a[x]), 0, 1, 0, 0};
}
}
bit_num.init(n), bit_sum.init(n);
function<void(ll, ll, int, int)> solve = [&](ll l, ll r, int ql, int qr) {
if (ql > qr) return;
if (l == r) {
for (int i = ql; i <= qr; i++) {
if (q[i].type == 2) res[q[i].id] = l;
}
return;
}
ll mid = (l + r) >> 1;
int cnt1 = 0, cnt2 = 0;
for (int i = ql; i <= qr; i++) {
if (q[i].type == 1) {
if (q[i].k <= mid) {
bit_num.add(q[i].l, 1), bit_sum.add(q[i].l, q[i].k);// bit.add(q[i].l, (mid - q[i].k));
q1[++cnt1] = q[i];
} else {
q2[++cnt2] = q[i];
}
} else if (q[i].type == 3) {
if (q[i].k <= mid) {
bit_num.add(q[i].l, -1), bit_sum.add(q[i].l, -q[i].k);//bit.add(q[i].l, (ll) q[i].k - mid);
q1[++cnt1] = q[i];
} else {
q2[++cnt2] = q[i];
}
} else {
__int128 sum = (__int128)(bit_num.query(q[i].l, q[i].r) + q[i].pre_num) * mid -
(__int128)(bit_sum.query(q[i].l, q[i].r) + q[i].pre_sum);
if (sum <= (__int128)mid * q[i].k) {
q[i].pre_num += bit_num.query(q[i].l, q[i].r);
q[i].pre_sum += bit_sum.query(q[i].l, q[i].r);
q2[++cnt2] = q[i];
} else q1[++cnt1] = q[i];
}
}
for (int i = 1; i <= cnt1; i++) {
if (q1[i].type == 1) bit_num.add(q1[i].l, -1), bit_sum.add(q1[i].l, -q1[i].k);
else if (q1[i].type == 3) bit_num.add(q1[i].l, 1), bit_sum.add(q1[i].l, q1[i].k);
}
for (int i = 1; i <= cnt1; i++) q[ql + i - 1] = q1[i];
for (int i = 1; i <= cnt2; i++) q[ql + cnt1 + i - 1] = q2[i];
solve(l, mid, ql, ql + cnt1 - 1);
solve(mid + 1, r, ql + cnt1, qr);
};
solve(0, 1e15, 1, qcnt);
for (int i = 1; i <= d; i++) {
if (ty[i]) printf("%lld
", res[i]-1);
}
}
/*
5 10 2
6 1 4 2 3
1 1 5 3
1 3 5 1
*/