// [WC2013] 糖果公园 (Candy Park), 70-point partial solution (70分)

#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <ctime>
#include <algorithm>

// Loop helpers: REP runs i over [0, n); FER walks node j's adjacency list
// (lst[j] is the head, each edge links to the next one through ->n).
#define REP(i, n) for (i = 0; i < (n); ++i)
#define FER(i, j) for (i = lst[j]; i; i = i->n)
// 64-bit integer type and the matching printf format string; WIN32 builds
// use a different conversion specifier.
#define int64 long long
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
// Not referenced anywhere in the visible code.
#define oo 0x13131313
// Capacity of the node- and color-indexed global arrays. Nodes and colors are
// used 1-based, so n and m must stay below this.
#define maxn 90002
// Used both as the chunk width of the per-node color-count tables (BLOCK
// chunks of BLOCK counters, so at most BLOCK * BLOCK colors) and as the size
// threshold of the tree partition built in dfs() (groups are compared with
// BLOCK and 2 * BLOCK).
#define BLOCK 300

// `now` holds the clock() reading taken at program start; the timing lines
// printed to stderr are measured from it.
using namespace std; double now;

// Reads the next non-negative decimal integer from stdin into x.
// Every non-digit character before the number is skipped (so a leading '-' is
// ignored), and the single character that ends the number is consumed.
// If end of input is reached before any digit, x is set to 0. Previously the
// skip loop never checked for EOF and spun forever on truncated input.
template<class T> void read(T &x)
{
	int c = getchar();  // int rather than char so EOF cannot look like data
	while (c != EOF && (c < '0' || c > '9'))
		c = getchar();
	x = 0;
	for (; c >= '0' && c <= '9'; c = getchar())
		x = x * 10 + (c - '0');
}

// ---- Global state ---------------------------------------------------------
// n nodes, m colors, Q queries. c[u] is the color of node u. v[] is indexed by
// color and w[] by occurrence count: a path's answer adds v[col] * w[k] for the
// k-th node of color col on it (see ask()). Mod[k] is meant to cache k % BLOCK
// for a 0-based color index k (filled in input()).
int n, m, Q, w[maxn], c[maxn]; short Mod[maxn]; int64 v[maxn];
// pa: parent; dep: depth; dfn: preorder number (Dfn is the running counter).
// size, ufs and ca describe the tree partition built by dfs(): size is a
// component size, ufs a union-find parent (the component root once find() has
// run on every node), and ca the anchor node whose answer row is precomputed.
int pa[maxn], size[maxn], ufs[maxn], dep[maxn], dfn[maxn], Dfn, ca[maxn];
// F / f: the anchor and the current ancestor of it while init() and bfs() fill
// a row. mark, Mark, tot and a are scratch space for the brute-force ask().
int F, f, mark[maxn], tot, a[maxn], Mark;
// buf is the pool for answer rows (buft = next free slot); ans[x] is the row of
// anchor x, indexed by node id; Ans is the row currently being filled.
// NOTE(review): this pool alone holds maxn * BLOCK 64-bit values.
int64 buf[maxn * BLOCK], *buft = buf, *ans[maxn], *Ans;

// Adjacency lists: lst[u] heads u's edge list, adj is the next free entry of
// edges[], and fr[u] is the edge in pa[u]'s list that leads down to u.
struct edge { int t; edge *n; } edges[maxn * 2], *adj = edges, *lst[maxn], *fr[maxn];
// One chunk of BLOCK per-color counters, allocated from blocks[] via btot.
struct block { int a[BLOCK]; } blocks[maxn + BLOCK], *btot = blocks;
// Per-node color-count table split into BLOCK chunks: sum[u][col] is the number
// of nodes with color col on the path from the root to u (sum[0] is all zero).
// A node shares every unchanged chunk with its parent (see inherit()).
// NOTE(review): under `using namespace std;` the global names `size` and
// `array` may become ambiguous with std::size / std::array in C++11/17 builds
// if those declarations are pulled in transitively by the included headers.
struct array { block *a[BLOCK]; int operator[](int); } sum[maxn];

// Returns the counter stored for 1-based color b. Colors are grouped in chunks
// of BLOCK consecutive indices: chunk (b - 1) / BLOCK is the block pointed to
// by a[...], and the slot inside it is Mod[b - 1].
int array::operator[](int b)
{
	const int idx = b - 1;  // 0-based color index
	const block *chunk = a[idx / BLOCK];
	return chunk->a[Mod[idx]];
}

// Makes `a` a copy of `b` with the counter for 1-based color `pos` raised by
// one. Only the chunk holding that color is cloned (taken from the blocks[]
// pool by advancing btot); every other chunk pointer is shared with `b`.
void inherit(array &a, array &b, int pos)
{
	const int idx = pos - 1;           // 0-based color index
	block **slot = &a.a[idx / BLOCK];  // chunk pointer that will be replaced
	a = b;                             // share all of b's chunks...
	block *fresh = btot++;             // ...except the one being modified,
	*fresh = **slot;                   //    which is cloned from b's chunk
	*slot = fresh;
	++fresh->a[Mod[idx]];
}

// Union-find lookup: returns the root of x's set and repoints every node on
// the path from x directly at that root (path compression).
int find(int x)
{
	int root = x;
	while (ufs[root] != root)
		root = ufs[root];
	while (ufs[x] != x)
	{
		const int next = ufs[x];
		ufs[x] = root;
		x = next;
	}
	return root;
}

// DFS from u (parent fa).
// On the way down it builds sum[u] (sum[fa] plus one more node of color c[u]),
// dep[u] and the preorder number dfn[u]. For each child it records pa[child]
// and fr[child], the edge of u's list that leads to that child.
// On the way up it partitions the tree into components for LCA() and init():
//  * each child's component (the child plus whatever it absorbed) is packed
//    left to right into groups led by the group's first child
//    (ufs[child] = leader). A new group starts whenever adding the next
//    component would exceed 2 * BLOCK nodes, and a group that receives a
//    second component gets ca[leader] = u;
//  * u then becomes its own component (ca[u] = ufs[u] = u, size 1) and absorbs
//    the last open group if that group has fewer than BLOCK nodes.
// NOTE(review): recursion depth equals the tree height (up to n), which may
// exhaust the call stack on path-shaped inputs.
void dfs(int u, int fa)
{
	// f: leader of the currently open group (-1 = none yet); shadows the global f.
	edge *e; int f = -1;
	inherit(sum[u], sum[fa], c[u]), dep[u] = dep[fa] + 1, dfn[u] = ++Dfn;
	FER(e, u) if (e->t != fa)
	{
		dfs(e->t, u), fr[e->t] = e, pa[e->t] = u;
		// Start a new group, or merge this child's component into the open one.
		if (!~f || size[f] + size[e->t] > BLOCK << 1)
			f = e->t;
		else
			size[f] += size[e->t], ca[f] = u, ufs[e->t] = f;
	}
	ca[u] = ufs[u] = u, size[u] = 1;
	// A small leftover group joins u's own component.
	if (~f && size[f] < BLOCK)
		size[u] += size[f], ufs[f] = u;
}

void bfs(int S)
{
	static int q[maxn]; int h, t; edge *e;
	for (q[h = t = S] = 0; h; h = q[h])
	{
		Ans[h] = Ans[pa[h]] + v[c[h]] * w[sum[F][c[h]] - (sum[f][c[h]] << 1) + sum[h][c[h]] + (c[f] == c[h])];
		FER(e, h) if (e->t != pa[h]) q[t = q[t] = e->t] = 0;
	}
}

// Builds everything the fast query path in main() needs:
//  1. sum[0] (the table for the empty path) gets BLOCK fresh, all-zero chunks;
//  2. dfs(1, 0) fills sum[], dep[], dfn[], pa[], fr[] and the tree partition;
//  3. for every component, the answer row ans[F] of its anchor F = ca[root] is
//     built. ans[F][y] is the answer for the path F -> y: a node of color col
//     that is the k-th such node on the path contributes v[col] * w[k].
// The elapsed time is printed to stderr once dfs() finishes.
void init()/* precompute the per-anchor answer rows */
{
	int i, j; edge *e;
	REP(i, BLOCK) sum->a[i] = btot++;
	dfs(1, 0);
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
	// find() compresses paths, so after this loop ufs[x] is x's component root
	// for every node x (LCA() and main() rely on that).
	REP(i, n) if (find(i + 1) == i + 1)
	{
		// Globals read by bfs(): F = anchor of the row, f = current ancestor of F.
		F = f = ca[i + 1];
		if (ans[f]) continue;  // several components can share one anchor
		// Reserve n entries, offset by one so the row is indexed by node id 1..n.
		buft += n, Ans = ans[f] = buft - n - 1;
		Ans[f] = v[c[f]] * w[1];
		// Every node below the anchor.
		FER(e, f) if (e->t != pa[f]) bfs(e->t);
		// Walk up to the root. Extend the row to each ancestor f, and to the
		// subtrees of f's children listed after j in f's adjacency list.
		// Children listed before j were entered earlier by dfs() and so have
		// smaller dfn; main() orders each query so that dfn[x] <= dfn[y], and
		// therefore never looks those nodes up.
		for (j = f; fr[j]; j = f)
		{
			f = pa[j];
			Ans[f] = Ans[j] + v[c[f]] * w[sum[F][c[f]] - sum[f][c[f]] + 1];
			for (e = fr[j]->n; e; e = e->n)
				if (e->t != pa[f]) bfs(e->t);
		}
	}
}

// Brute-force answer for the path x -> y. It collects the colors of every node
// on the path into a[1..tot], sorts them, and adds v[col] * w[k] for the k-th
// occurrence of each color col. Uses the globals mark/Mark/tot/a as scratch.
int64 ask(int x, int y)
{
	int64 res = 0;
	int u, lca;
	++Mark;
	tot = 0;
	// Tag x and all of its ancestors.
	for (u = x; u; u = pa[u]) mark[u] = Mark;
	// Climb from y until reaching a tagged node; that node is the LCA.
	for (lca = y; mark[lca] < Mark; lca = pa[lca]) a[++tot] = c[lca];
	a[++tot] = c[lca];
	for (u = x; u != lca; u = pa[u]) a[++tot] = c[u];
	sort(a + 1, a + tot + 1);
	// Equal colors are now adjacent; check the bound before reading a[i].
	int i = 1;
	while (i <= tot)
	{
		const int col = a[i];
		for (int k = 1; i <= tot && a[i] == col; ++i, ++k)
			res += v[col] * w[k];
	}
	return res;
}

void Dfs(int u, int fa)
{
	edge *e; pa[u] = fa;
	for (e = lst[u]; e; e = e->n)
		if (e->t != fa) Dfs(e->t, u);
}

void input()
{
	int i; scanf("%d%d%d", &n, &m, &Q);
	REP(i, m) read(v[i + 1]);
	REP(i, n) read(w[i + 1]), Mod[i] = i % BLOCK;
	REP(i, n - 1)
	{
		int a, b; read(a), read(b);
		*adj = (edge){b, lst[a]}, lst[a] = adj++;
		*adj = (edge){a, lst[b]}, lst[b] = adj++;
	}
	REP(i, n) read(c[i + 1]);
	if (n <= 20000 && m <= 20000)
	{
		for (Dfs(1, 0); Q--; )
		{
			int t, x, y; read(t), read(x), read(y);
			t ? printf(fmt64"\n", ask(x, y)) : c[x] = y;
		}
		exit(0);
	}
}

// Lowest common ancestor of x and y, using the tree partition from dfs():
// within one component (equal ufs values) the deeper node steps to its parent;
// across components, the node whose component root is deeper jumps to that
// root's parent. Relies on ufs[] being fully compressed (every entry pointing
// at its component root), which init() ensures by calling find() on every node.
int LCA(int x, int y)
{
	while (x != y)
	{
		if (ufs[x] == ufs[y])
		{
			if (dep[x] > dep[y])
				x = pa[x];
			else
				y = pa[y];
		}
		else if (dep[ufs[x]] > dep[ufs[y]])
		{
			x = pa[ufs[x]];
		}
		else
		{
			y = pa[ufs[y]];
		}
	}
	return x;
}

// Entry point. Reads park.in, writes one answer per line to park.out, and
// prints elapsed times (seconds since `now`) to stderr.
int main()
{
	freopen("park.in", "r", stdin);
	freopen("park.out", "w", stdout);

	now = clock();
	// input() answers small instances itself and exits. Otherwise build the
	// partition and the per-anchor answer rows.
	input();
	init();
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
	for (; Q--; )
	{
		int t, x, y; read(t), read(x), read(y);
		// t == 0 is a modification. This version does not support those and
		// stops the program (NOTE(review): any remaining queries go unanswered).
		// The endpoints are then ordered so that dfn[x] <= dfn[y]: init() fills
		// ans rows only for targets that are not in earlier-visited sibling
		// subtrees, and those subtrees all have smaller dfn.
		if (!t) exit(0); if (dfn[y] < dfn[x]) swap(x, y);
		// f: anchor of x's component (its row ans[f] is precomputed); g: LCA(x, y).
		int f = ca[ufs[x]], g = LCA(x, y);
		// Start from the precomputed answer for the path f -> y. This local Ans
		// shadows the global row pointer of the same name.
		int64 Ans = ans[f][y];
		if (dep[f] < dep[g])
		{
			// f lies above g, so the nodes pa[g], ..., f are on the f -> y path
			// but not on x -> y; remove them. Each term uses the number of nodes
			// of color c[t] on the path t -> y (t included).
			for (t = pa[g]; t != pa[f]; t = pa[t])
				Ans -= v[c[t]] * w[sum[y][c[t]] - sum[t][c[t]] + 1];
			f = g;
		}
		// Add the nodes from x up to (but excluding) f. The w index is the number
		// of nodes of color c[t] on the path t -> g -> y.
		for (t = x; t != f; t = pa[t])
			Ans += v[c[t]] * w[sum[y][c[t]] - (sum[g][c[t]] << 1) + sum[t][c[t]] + (c[g] == c[t])];
		printf(fmt64"\n", Ans);
	}
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
}



// Original article (原文地址): https://www.cnblogs.com/xinyuyuanm/p/3019694.html