BJOI2017 树的难题

落谷Loj

Description

给定 (n) 个点的无根树。(m) 种颜色,每种颜色权值为 (c_i)

定义树上路径权值为路径颜色序列,将其分为每一段极大的相同颜色序列,每一段颜色设为 (i),权值即 (sum c_i)

求边数在 ([l, r]) 范围的简单路径中路径权值最大值。

Solution

一般来说边数在 ([l, r]) 的一些树上信息很容易想到就是点分治了。

(c(x)) 从根到 (x) 的路径上第一条边的颜色,(w(x)) 为从根到 (x) 这条路径权值, (d(x)) 为根到 (x) 经过的边数。

考虑合并两个到根的链的点 (x, y)

  • (c(x) ot = c(y)),那么答案就是 (w(x) + w(y))
  • (c(x) = c(y)),那么要减掉重复的颜色 (w(x) + w(y) - c_{w(x)})

首先先来解决第一种情况 (c(x) ot= c(y))

显然两个点 (x, y),若 (d(x) = d(y))(w(x) > w(y)),那么 (y) 就没用了。因为把 (y) 换成 (x) 会更好。所以不妨存个桶,(b_i) 表示经过边数为 (i) 的点的最大 (w)

那么对于 (d(x) = i) 来说,他寻求拼合的另外一条链的边数在 ([l - i, r - i]),即我们要求这个区间的 (max(b_i))。发现当 (i) 从小到大循环的过程中,这个区间左右端点都是递减的,即一个滑动窗口问题。

然后考虑加入第二种情况:这个东西解决起来不难想,就是把根所连接的所有子树联通快按他们之间连的那条边排一下序,这样同一个颜色就肯定在一个区间了。所以不同颜色求一遍答案,相同颜色求一遍答案再整体减去重复的颜色即可。注意当遍历到新的颜色的时候,要把之前的相同颜色合并到不同颜色里。

滑动窗口这个东西可以用线段树 (/) 单调队列来做,但是你发现第一次查询是 (O(r - l)) 复杂度的,最坏情况下暴力扫可能会被卡成 (O(n^2)) 的,所以需要一种特殊技巧(传说叫单调队列按秩合并)。

由于这是一道码农题,所以我锻炼一下写一下两种做法。


算法 1 线段树

维护两个权值线段树,一个为异色联通块,一个为同色联通块。

对于每个联通块,先查询。查询完整体塞到同色联通块里。

碰到新的颜色,即把同色的合并到异色里,线段树合并即可。

时间复杂度

(O(n log ^ 2 n))

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;

const int N = 200005, INF = 2e9;

int n, m, L, R, c[N], ans = -INF;
int maxPart, rt, now, S, sz[N], d[N], maxDep;
int head[N], numE = 0, tot;
bool vis[N];

struct E{
	int next, v, w;
} e[N << 1];

struct Son{
	int x, c;
	bool operator < (const Son &b) const {
		return c < b.c;
	}
} sons[N];

struct T{
	int l, r, dat;
};

struct SegTree{
	int rt0, rt1, dat[N], idx;
	T t[N * 20];

	void change(int &p, int l, int r, int x, int c) {
		if (!p) t[p = ++idx] = (T) { 0, 0, -INF };
		t[p].dat = max(t[p].dat, c);
		if (l == r) return;
		int mid = (l + r) >> 1;
		if (x <= mid) change(t[p].l, l, mid, x, c);
		else change(t[p].r, mid + 1, r, x, c);
	}

	int query(int p, int l, int r, int x, int y) {
		if (x > y) return -INF;
		if (!p) return -INF;
		if (x <= l && r <= y) return t[p].dat;
		int mid = (l + r) >> 1, res = -INF;
		if (x <= mid) res = max(res, query(t[p].l, l, mid, x, y));
		if (mid < y) res = max(res, query(t[p].r, mid + 1, r, x, y));
		return res;
	}

	// 把 q 合并到 p 上
	void merge(int &p, int &q, int l, int r) {
		if (!p) { p = q; return; } 
		if (!q) return;
		t[p].dat = max(t[p].dat, t[q].dat);
		if (l == r) return;
		int mid = (l + r) >> 1;
		merge(t[p].l, t[q].l, l, mid);
		merge(t[p].r, t[q].r, mid + 1, r);
	}
} t;

void inline add(int u, int v, int w) {
	e[++numE] = (E) { head[u], v, w };
	head[u] = numE;
}

void getRoot(int u, int last) {
	sz[u] = 1;
	int s = 0;
	for (int i = head[u]; i; i = e[i].next) {
		int v = e[i].v;
		if (vis[v] || v == last) continue;
		getRoot(v, u);
		sz[u] += sz[v];
		s = max(s, sz[v]);
	}
	s = max(s, S - sz[u]);
	if (s < maxPart) maxPart = s, rt = u;
}

void dfs(int u, int last, int col, int w, int dep) {
	maxDep = max(maxDep, dep);
	if (L <= dep && dep <= R) ans = max(ans, w);
	ans = max(ans, w + t.query(t.rt0, 1, n, max(L - dep, 1), min(R - dep, n)));
	ans = max(ans, w + t.query(t.rt1, 1, n, max(L - dep, 1), min(R - dep, n)) - now);
	d[dep] = max(d[dep], w);
	for (int i = head[u]; i; i = e[i].next) {
		int v = e[i].v;
		if (v == last || vis[v]) continue;
		dfs(v, u, e[i].w, w + (col == e[i].w ? 0 : c[e[i].w]), dep + 1);
	}
}

void solve(int x) {
	if (S == 1) return;
	maxPart = 2e9, getRoot(x, 0), vis[rt] = true;
	tot = 0; t.idx = t.rt0 = t.rt1 = 0;
	for (int i = head[rt]; i; i = e[i].next) {
		int v = e[i].v;
		if (vis[v]) continue;
		sons[++tot] = (Son) { v, e[i].w };
	}
	sort(sons + 1, sons + 1 + tot);
	for (int i = 1; i <= tot; i++) {
		maxDep = 0, now = c[sons[i].c];
		dfs(sons[i].x, rt, sons[i].c, c[sons[i].c], 1);
		for (int j = 1; j <= maxDep; j++) {
			t.change(t.rt1, 1, n, j, d[j]);
			d[j] = -INF;
		}
		if (i < tot && sons[i].c != sons[i + 1].c) {
			t.merge(t.rt0, t.rt1, 1, n);
			t.rt1 = 0;
		}
	}
	for (int i = head[rt]; i; i = e[i].next) {
		int v = e[i].v;
		if (vis[v]) continue;
		S = sz[v], solve(v);
	}
}

int main() {
	scanf("%d%d%d%d", &n, &m, &L, &R);
	for (int i = 1; i <= n; i++) d[i] = -2e9;
	for (int i = 1; i <= m; i++) scanf("%d", c + i);
	for (int i = 1, u, v, w; i < n; i++) {
		scanf("%d%d%d", &u, &v, &w);
		add(u, v, w); add(v, u, w);
	}
	S = n;
	solve(1);
	printf("%d
", ans);
	return 0;
}

算法 2 单调队列

这个按秩合并非常神奇。具体的排序顺序是这样的:

  • 把相同颜色的构成一个区间。
  • 不同颜色顺序按每个颜色块里的最大深度排序
  • 相同颜色之间按照最大深度排序

然后查询的时候,分类讨论:

  • 两个链属于不同颜色,整体拿出一个块来查询。由于颜色递增,所以第一次插入不会超过 (O(maxDep))
  • 属于相同颜色,查询。由于颜色递增,所以第一次插入不会超过 (O(maxDep))

时间复杂度

考虑这个 (sort) ,以每个点为根最多一次,所以排序总复杂度是 (O(n log n)) 的,所以总复杂度是 (O(n log n))

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;

const int N = 200005, INF = 2e9;

int n, m, L, R, c[N], ans = -INF;
int maxPart, rt, now, S, sz[N], d[N], maxDep, mxDep[N];
int head[N], numE = 0, tot, val[N], nowVal[N], zDep, q[N];
bool vis[N];

struct E {
    int next, v, w;
} e[N << 1];

struct Son {
    int x, c, d;
    bool operator<(const Son &b) const {
        if (c != b.c)
            return mxDep[c] < mxDep[b.c];
        else
            return d < b.d;
    }
} sons[N];

struct T {
    int l, r, dat;
};

void inline add(int u, int v, int w) {
    e[++numE] = (E){ head[u], v, w };
    head[u] = numE;
}

void getRoot(int u, int last) {
    sz[u] = 1;
    int s = 0;
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (vis[v] || v == last)
            continue;
        getRoot(v, u);
        sz[u] += sz[v];
        s = max(s, sz[v]);
    }
    s = max(s, S - sz[u]);
    if (s < maxPart)
        maxPart = s, rt = u;
}

void dfs(int u, int last, int col, int w, int dep) {
    maxDep = max(maxDep, dep);
    if (L <= dep && dep <= R)
        ans = max(ans, w);
    d[dep] = max(d[dep], w);
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (v == last || vis[v])
            continue;
        dfs(v, u, e[i].w, w + (col == e[i].w ? 0 : c[e[i].w]), dep + 1);
    }
}

void dfs0(int u, int last, int dep) {
    maxDep = max(maxDep, dep);
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (v == last || vis[v])
            continue;
        dfs0(v, u, dep + 1);
    }
}

int inline work(int a[], int len1, int b[], int len2) {
    int res = -INF;

    int hh = 0, tt = -1;
    int l = max(1, L - 1), r = R - 1;
    if (l > r)
        return res;
    len1 = min(len1, R - 1);
    for (int i = min(r, len2); i >= l; i--) {
        while (hh <= tt && b[q[tt]] < b[i]) tt--;
        q[++tt] = i;
    }
    if (hh <= tt)
        res = max(res, a[1] + b[q[hh]]);
    for (int i = 2; i <= len1; i++) {
        if (q[hh] == r)
            hh++;
        r--;
        if (l > 1) {
            --l;
            while (hh <= tt && b[q[tt]] < b[l]) tt--;
            q[++tt] = l;
        }
        if (hh <= tt)
            res = max(res, a[i] + b[q[hh]]);
    }

    return res;
}

void solve(int x) {
    if (S == 1)
        return;
    maxPart = 2e9, getRoot(x, 0), vis[rt] = true;
    tot = 0;
    for (int i = head[rt]; i; i = e[i].next) {
        int v = e[i].v;
        if (vis[v])
            continue;
        maxDep = 0;
        dfs0(v, rt, 1);
        mxDep[e[i].w] = max(mxDep[e[i].w], maxDep);
        sons[++tot] = (Son){ v, e[i].w, maxDep };
    }
    sort(sons + 1, sons + 1 + tot);
    zDep = 0;
    int nowDep = 0;
    for (int i = 1; i <= tot; i++) {
        maxDep = 0, now = c[sons[i].c];
        mxDep[e[i].w] = 0;
        dfs(sons[i].x, rt, sons[i].c, c[sons[i].c], 1);
        zDep = max(zDep, maxDep);
        nowDep = max(nowDep, maxDep);
        ans = max(ans, work(d, maxDep, nowVal, nowDep) - c[sons[i].c]);
        for (int j = 1; j <= maxDep; j++) {
            nowVal[j] = max(nowVal[j], d[j]);
            d[j] = -INF;
        }
        if (i == tot || sons[i].c != sons[i + 1].c) {
            ans = max(ans, work(nowVal, nowDep, val, zDep));
            for (int j = 1; j <= nowDep; j++) {
                val[j] = max(val[j], nowVal[j]);
                nowVal[j] = -INF;
            }
            nowDep = 0;
        }
    }
    for (int i = 1; i <= zDep; i++) val[i] = -INF;
    for (int i = head[rt]; i; i = e[i].next) {
        int v = e[i].v;
        if (vis[v])
            continue;
        S = sz[v], solve(v);
    }
}

int main() {
    scanf("%d%d%d%d", &n, &m, &L, &R);
    for (int i = 1; i <= n; i++) d[i] = nowVal[i] = val[i] = -2e9;
    for (int i = 1; i <= m; i++) scanf("%d", c + i);
    for (int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    S = n;
    solve(1);
    printf("%d
", ans);
    return 0;
}
原文地址:https://www.cnblogs.com/dmoransky/p/12670423.html