luogu P4842 城市旅行

嘟嘟嘟


好题,好题


刚开始突发奇想写了一个(O(n ^ 2))暴力,结果竟然过了?!后来才知道是上传题的人把单个数据点开成了10s……
不过不得不说我这暴力写的挺好看的。删边模仿链表删边,加边的时候遍历其中一棵树,使两棵树染上相同的颜色,这样判联通就能达到(O(1))了。
所以我决定先放一个暴力代码

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 5e4 + 5;
const int maxe = 2e5 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, m, a[maxn];
struct Edge
{
  int nxt, to;
}e[maxe];
int head[maxn], ecnt = -1;
In void addEdge(int x, int y)
{
  e[++ecnt] = (Edge){head[x], y};
  head[x] = ecnt;
}

In ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}

int fa[maxn], dep[maxn], col[maxn], Col = 0;
In void dfs(int now, int _f, int Col)
{
  fa[now] = _f, dep[now] = dep[_f] + 1;
  col[now] = Col;
  for(int i = head[now], v; ~i; i = e[i].nxt)
    {
      if((v = e[i].to) == _f) continue;
      dfs(v, now, Col);
    }
}

In bool cut(int x, int y)
{
  for(int i = head[x], v, pre = 0; ~i; pre = i, i = e[i].nxt)
    {
      if((v = e[i].to) == y)
	{
	  if(!pre) head[x] = e[i].nxt;
	  else e[pre].nxt = e[i].nxt;
	  return 1;
	}
    }
  return 0;
}
In void Cut(int x, int y)
{
  if(!cut(x, y)) return;
  cut(y, x);
  dfs(y, 0, ++Col);
}
In void Link(int x, int y)
{
  if(col[x] == col[y]) return;
  addEdge(x, y), addEdge(y, x);
  dfs(y, x, col[x]);
}
In void add(int x, int y, int d)
{
  if(col[x] ^ col[y]) return;
  while(x ^ y)
    {
      if(dep[x] < dep[y]) swap(x, y);
      a[x] += d; x = fa[x];
    }
  a[x] += d;
}
int l[maxn], r[maxn], cnt1 = 0, cnt2 = 0;
In void query(int x, int y)
{
  if(col[x] ^ col[y]) {puts("-1"); return;}
  cnt1 = cnt2 = 0;
  while(x ^ y)
    {
      if(dep[x] >= dep[y]) l[++cnt1] = x, x = fa[x];
      else r[++cnt2] = y, y = fa[y];
    }
  l[++cnt1] = x;
  ll ret1 = 0, ret2 = (1LL * cnt1 + cnt2) * (cnt1 + cnt2 + 1) / 2;
  for(int i = 1; i <= cnt1; ++i) ret1 += 1LL * a[l[i]] * i * (cnt1 + cnt2 - i + 1);
  for(int i = 1; i <= cnt2; ++i) ret1 += 1LL * a[r[i]] * i * (cnt1 + cnt2 - i + 1);
  ll Gcd = gcd(ret1, ret2);
  write(ret1 / Gcd), putchar('/'), write(ret2 / Gcd), enter;
}

int main()
{
  Mem(head, -1);
  n = read(), m = read();
  for(int i = 1; i <= n; ++i) a[i] = read();
  for(int i = 1; i < n; ++i)
    {
      int x = read(), y = read();
      addEdge(x, y), addEdge(y, x);
    }
  dfs(1, 0, ++Col);
  for(int i = 1; i <= m; ++i)
    {
      int op = read(), x = read(), y = read();
      if(op == 1) Cut(x, y);
      else if(op == 2) Link(x, y);
      else if(op == 3)
	{
	  int d = read();
	  add(x, y, d);
	}
      else query(x, y);
    }
  return 0;
}

好,那么我们切入正题。 不对,我先吐槽一下:第2和第4两个点我的LCT跑的比暴力还慢,然后别的点和暴力差不多,结果总时间竟然比暴力还慢……调了半天也不知道为啥,我这可活什么劲。
好,那现在真的切入正题了。 有时候维护LCT跟线段树差不多,比如这道题,核心就是pushdown和pushup怎么写。 线段树是连个子区间合并,那么这个LCT就是两条链首尾相连合并成一条链。
关于pushup,我实在写不动了,就扔出一篇博客:[城市旅行题解](https://www.luogu.org/blog/user25308/solution-p4842) 思路就是算出左子树的答案$ans_l$,左子树在整棵树中的贡献$w_l$,右子树同理,那么整棵树的答案就是$w_l + w_r = ans_l + Delta x_l + ans_r + Delta x_r$,其中两个$Delta$是能手算出来的。
关于pushdown,除了期望,我和那篇题解都一样。因为我数学没那位老哥那么巨,小学也没学过奥数,推了一阵子搞出个这么个东西:$d * (frac{n ^ 3 + 2n ^ 2 + n}{2} - sum i ^ 2)$。 然后发现没办法$O(1)$求…… 你以为我就去抄题解了吗?那不可能,别忘了,咱这是信竞,不是数竞,后面那个$sum$直接预处理出来不就完了嘛。
对了,子树大小可能在运算的时候会爆int,别忘强制转换成long long。 ```c++ #include #include #include #include #include #include #include #include #include #include using namespace std; #define enter puts("") #define space putchar(' ') #define Mem(a, x) memset(a, x, sizeof(a)) #define In inline typedef long long ll; typedef double db; const int INF = 0x3f3f3f3f; const db eps = 1e-8; const int maxn = 5e4 + 5; inline ll read() { ll ans = 0; char ch = getchar(), last = ' '; while(!isdigit(ch)) last = ch, ch = getchar(); while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar(); if(last == '-') ans = -ans; return ans; } inline void write(ll x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar(x % 10 + '0'); }

int n, m;
ll SUM[maxn];
struct Tree
{
int ch[2], fa, siz, rev;
ll val, lzy, sum, lsum, rsum, ans;
}t[maxn];

define ls t[now].ch[0]

define rs t[now].ch[1]

define S t[now].siz

In ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}

In void c_rev(int now)
{
swap(ls, rs); swap(t[now].lsum, t[now].rsum);
t[now].rev ^= 1;
}
In void c_add(int now, ll d)
{
t[now].lzy += d; t[now].val += d;
t[now].sum += d * S;
t[now].lsum += ((d * S * (S + 1)) >> 1);
t[now].rsum += ((d * S * (S + 1)) >> 1);
t[now].ans += d * (((1LL * S * S * S + 1LL * S * (S << 1) + S) >> 1) - SUM[S]);
}
In void pushdown(int now)
{
if(t[now].rev)
{
if(ls) c_rev(ls);
if(rs) c_rev(rs);
t[now].rev = 0;
}
if(t[now].lzy)
{
if(ls) c_add(ls, t[now].lzy);
if(rs) c_add(rs, t[now].lzy);
t[now].lzy = 0;
}
}
In void pushup(int now)
{
t[now].siz = t[ls].siz + t[rs].siz + 1;
t[now].sum = t[ls].sum + t[rs].sum + t[now].val;
t[now].lsum = t[ls].lsum + t[rs].lsum + (t[rs].sum + t[now].val) * (t[ls].siz + 1);
t[now].rsum = t[rs].rsum + t[ls].rsum + (t[ls].sum + t[now].val) * (t[rs].siz + 1);
t[now].ans = t[ls].ans + t[rs].ans + t[ls].lsum * (t[rs].siz + 1) + t[rs].rsum * (t[ls].siz + 1) + t[now].val * (t[ls].siz + 1) * (t[rs].siz + 1);
}
In bool n_root(int x)
{
return t[t[x].fa].ch[0] == x || t[t[x].fa].ch[1] == x;
}
In void rotate(int x)
{
int y = t[x].fa, z = t[y].fa, k = (t[y].ch[1] == x);
if(n_root(y)) t[z].ch[t[z].ch[1] == y] = x; t[x].fa = z;
t[y].ch[k] = t[x].ch[k ^ 1], t[t[y].ch[k]].fa = y;
t[x].ch[k ^ 1] = y, t[y].fa = x;
pushup(y), pushup(x);
}
int st[maxn], top = 0;
In void splay(int x)
{
int y = x; st[top = 1] = y;
while(n_root(y)) st[++top] = y = t[y].fa;
while(top) pushdown(st[top--]);
while(n_root(x))
{
int y = t[x].fa, z = t[y].fa;
if(n_root(y)) rotate(((t[y].ch[0] == x) ^ (t[z].ch[0] == y)) ? x : y);
rotate(x);
}
}

In void access(int x)
{
int y = 0;
while(x)
{
splay(x); t[x].ch[1] = y;
pushup(x);
y = x; x = t[x].fa;
}
}
In void make_root(int x)
{
access(x), splay(x);
c_rev(x);
}
In int find_root(int x)
{
access(x), splay(x);
while(t[x].ch[0]) pushdown(x), x = t[x].ch[0];
return x;
}
In void split(int x, int y)
{
make_root(x);
access(y), splay(y);
}
In void Link(int x, int y)
{
make_root(x);
if(find_root(y) ^ x) t[x].fa = y;
}
In void Cut(int x, int y)
{
make_root(x);
if(find_root(y) == x && t[x].fa == y && !t[x].ch[1])
t[y].ch[0] = t[x].fa = 0, pushup(y);
}
In void update(int x, int y, int d)
{
make_root(x);
if(find_root(y) ^ x) return;
split(x, y); c_add(y, d);
pushup(y);
}
In void query(int x, int y)
{
make_root(x);
if(find_root(y) ^ x) {puts("-1"); return;}
split(x, y);
ll Siz = t[y].siz, tp = (Siz * (Siz + 1)) >> 1, Gcd = gcd(t[y].ans, tp);
write(t[y].ans / Gcd), putchar('/'), write(tp / Gcd), enter;
}

int main()
{
n = read(), m = read();
for(int i = 1; i <= n; ++i)
{
t[i].val = read(), t[i].siz = 1;
t[i].sum = t[i].lsum = t[i].rsum = t[i].ans = t[i].val;
SUM[i] = SUM[i - 1] + i * i;
}
for(int i = 1; i < n; ++i)
{
int x = read(), y = read();
Link(x, y);
}
for(int i = 1; i <= m; ++i)
{
int op = read(), x = read(), y = read();
if(op == 1) Cut(x, y);
else if(op == 2) Link(x, y);
else if(op == 4) query(x, y);
else
{
int d = read();
update(x, y, d);
}
}
return 0;
}

原文地址:https://www.cnblogs.com/mrclr/p/10769932.html