[cf 1140] G. Double Tree

题意

给两棵同构的树,将同构节点之间连边,构成一张图。所有边有边权。给出一些询问,求某两点间的最短路。

题解

设两棵树分别为(T, T'),同构节点们为(x, x')

注意到每个询问的答案一定是从(u)在某一棵树上走,然后走到另一棵树的同构节点,再在另外一棵树上走,这样的过程重复个若干次。

在一棵树上走一定是走简单路径;走到同构节点并非就是走那条直接相连的边,而是最短路。

先考虑最短路这么求?我们要求(n)对同构节点的最短路。

可以等价转化:

1.(forall_{e(x, x', w_x)}adde(0, x, w_x))

2.(forall_{e(x, y, w_1, w_2)} adde(x, y, w_1 + w_2))

然后会发现这是对的……很神奇,就可以直接一遍sssp就好啦。

然后可以直接把最短路当做边权了。

那如何处理一整个问题?

记录(dp_{x, y, u, v})代表从树(u)的节点(x)向上跳(2 ^ y)步且最终到达第v棵树上的最短路。其中(u, v)取值都是({0, 1})。发现可以把(dp_{x, y})看成一个(2 * 2)的矩阵。

为了方便,还要记录(pd)数组代表的是从上向下的最短路矩阵。

最后询问的时候倍增跳一跳,矩阵重定义运算一下,然后按顺序合并即可,注意合并的顺序。

复杂度(O((n + Q) log n))

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;

inline int read () {
	static int x;
	scanf("%lld", &x);
	return x;
}
inline int readl () {
	static ll x;
	scanf("%lld", &x);
	return x;
}

const int N = 3e5 + 10, M = 2e6 + 10, H = 19;
int n; ll D[N];
struct Graph {
	int n, tot;
	int lnk[N], nxt[M], son[M]; ll w[M];
	void init (int _n) {
		n = _n, tot = 1;
		memset(lnk, 0, sizeof lnk);
	}
	void add (int x, int y, ll z) {
		nxt[++tot] = lnk[x], lnk[x] = tot, son[tot] = y, w[tot] = z;
	}
	void adde (int x, int y, ll z) {
		add(x, y, z), add(y, x, z); 
	}
	void sssp () {
		static ll dis[N];
		static bool vis[N];
		static queue <int> q;
		memset(dis, 60, sizeof dis), dis[0] = 0;
		memset(vis, 0, sizeof vis), vis[0] = 1;
		for ( ; !q.empty(); q.pop()); q.push(0);
		for ( ; !q.empty(); q.pop()) {
			int x = q.front();
			vis[x] = 0;
			for (int j = lnk[x]; j; j = nxt[j])
				if (dis[son[j]] > dis[x] + w[j]) {
					dis[son[j]] = dis[x] + w[j];
					if (!vis[son[j]]) vis[son[j]] = 1, q.push(son[j]);
				}
		}
		for (int i = 1; i <= n; ++i) D[i] = dis[i];
	}
} G;
struct Matrix {
	ll a[2][2];
	Matrix operator * (const Matrix &o) {
		return {min(a[0][0] + o.a[0][0], a[0][1] + o.a[1][0]),
				min(a[0][0] + o.a[0][1], a[0][1] + o.a[1][1]),
				min(a[1][0] + o.a[0][0], a[1][1] + o.a[1][0]),
				min(a[1][0] + o.a[0][1], a[1][1] + o.a[1][1])};
	}
};
struct Tree {
	int n, tot;
	int lnk[N], nxt[N << 1], son[N << 1]; ll w1[N << 1], w2[N << 1];
	int dep[N], fa[N][H + 1]; Matrix dp[N][H + 1], pd[N][H + 1];
	void init (int _n) {
		n = _n, tot = 1, dep[0] = 0;
		memset(lnk, 0, sizeof lnk);
	}
	void add (int x, int y, ll z1, ll z2) {
		nxt[++tot] = lnk[x], lnk[x] = tot, son[tot] = y, w1[tot] = z1, w2[tot] = z2;
	}
	void adde (int x, int y, ll z1, ll z2) {
		add(x, y, z1, z2), add(y, x, z1, z2);
	}
	void dfs (int x, int p) {
		fa[x][0] = p, dep[x] = dep[p] + 1;
		for (int j = lnk[x]; j; j = nxt[j]) if (son[j] != p) {
			dfs(son[j], x);
			dp[son[j]][0] = {w1[j], min(w1[j] + D[x], D[son[j]] + w2[j]),
							min(w1[j] + D[son[j]], D[x] + w2[j]), w2[j]};
			pd[son[j]][0] = {w1[j], min(w1[j] + D[son[j]], D[x] + w2[j]),
							min(w1[j] + D[x], D[son[j]] + w2[j]), w2[j]};
		}
	}
	void build () {
		for (int j = 1; j <= H; ++j)
			for (int i = 1; i <= n; ++i) {
				fa[i][j] = fa[fa[i][j - 1]][j - 1];
				dp[i][j] = dp[i][j - 1] * dp[fa[i][j - 1]][j - 1];
				pd[i][j] = pd[fa[i][j - 1]][j - 1] * pd[i][j - 1];
			}
	}
	int lca (int x, int y) {
		if (dep[x] < dep[y]) swap(x, y);
		int dif = dep[x] - dep[y];
		for (int j = H; ~j; --j)
			if (dif >> j & 1) x = fa[x][j];
		if (x == y) return x;
		for (int j = H; ~j; --j)
			if (fa[x][j] != fa[y][j]) x = fa[x][j], y = fa[y][j];
		return fa[x][0];
	}
	ll query (int x, int y, int u, int v) {
		static Matrix ans1, ans2, ans;
		ans1 = {0, D[x], D[x], 0}, ans2 = {0, D[y], D[y], 0};
		int z = lca(x, y), dif;
		dif = dep[x] - dep[z];
		for (int j = H; ~j; --j) if (dif >> j & 1)
			ans1 = ans1 * dp[x][j], x = fa[x][j];
		dif = dep[y] - dep[z];
		for (int j = H; ~j; --j) if (dif >> j & 1)
			ans2 = pd[y][j] * ans2, y = fa[y][j];
		ans = ans1 * ans2;
		return ans.a[u][v];
	}
} T;

signed main () {
	n = read(), G.init(n), T.init(n);
	for (int i = 1; i <= n; ++i) G.adde(0, i, readl());
	for (int i = 1; i < n; ++i) {
		int x = read(), y = read();
		ll z1 = readl(), z2 = readl();
		G.adde(x, y, z1 + z2), T.adde(x, y, z1, z2);
	}
	G.sssp();
	T.dfs(1, 0);
	T.build();
	for (int _ = read(), x, y; _; --_) {
		x = read() + 1, y = read() + 1;
		printf("%lld
", T.query(x >> 1, y >> 1, x & 1, y & 1));
	}
	return 0;
}

不知为何全搞成long long才过。

原文地址:https://www.cnblogs.com/psimonw/p/10727401.html