写在前面
考场上看到题目果断强推柿子,推到一半发现推的关系式回溯不了,于是果断写链。
结果感觉自己关系式没推错,然而没过考场上发的恶臭不堪的大样例。
最终35pts走人。
回来gyh说是树形DP。说了半天啥也没懂。
今天考场外边再来做此题,终归还是瞎搞,然而搞过了。
Solution
瞎搞。
统计出每个右括号的贡献值,即以该右括号为结束位置的合法括号串总数。
最终每个节点的答案即为根节点到其简单路径上所有节点的贡献值之和。
这样就不用回溯了。
考虑怎么统计贡献值。
每个右括号的贡献值等于在之前的合法区间内与该右括号所在的一对括号同等级的括号总数 \(+1\)(\(+1\) 是因为包括它本身)。
举个例子,像这样:
()()(())
那么这四个右括号的贡献就分别是 \(1,2,1,3\)。
-
如果一个左括号紧跟着一个右括号,那么这个左括号就继承右括号的贡献值。
-
如果一个右括号紧跟着一个左括号,那么这个右括号就等于左括号的贡献值 \(+1\)。
但是要是这一对左右括号之间还有别的合法括号串(如举的例子中的第4个右括号及其所对的左括号)呢?
考虑对于每个位置,开一个变量 last_pos
存储“上一等级的最近的左括号所在位置”。
-
如果存在两个相邻的左括号,那么后面的左括号的
last_pos
即为前一个左括号的位置。 -
同等级括号处理之后,继承上一个括号的
last_pos
。 -
前一个位置是 空 或 无法和左括号匹配上的右括号 的左括号的
last_pos
是前一个位置,不管是不是左括号。 -
内部还有合法括号串的一对括号,那么这对括号中的右括号的
last_pos
是前一个位置的last_pos
(即上一个左括号位置)的last_pos
。
那什么是不合法的区间呢?
如果存在一个没有匹配的右括号,那么这个右括号就将整个统计贡献的累积值断掉了,将贡献清零,并将其 last_pos
设为它本身即可。
至于如何判断嘛,开栈判断!这里其实所谓的栈只是一个统计层数的变量。变量减到 \(-1\) 这个括号就是不合法的了。
记得判断完不合法之后把当前等级清零。
Code:
愉快的代码时间
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int Maxn = 5e5 + 5;
struct e{
int to, next;
};
e b[Maxn];
int head[Maxn], ecnt;
void add(int u, int v)
{
ecnt++;
b[ecnt].to = v;
b[ecnt].next = head[u];
head[u] = ecnt;
}
struct node{
char data;
int father;
int last_pos;
LL counter;
LL ans;
int stk;
};
node a[Maxn];
int n;
LL ans;
void work(int t)
{
int fa = a[t].father;
if(a[t].data == '(')
{
a[t].stk = a[fa].stk + 1;
a[t].ans = a[fa].ans;
if(a[fa].data == ')')
{
a[t].counter = a[fa].counter;
a[t].last_pos = a[fa].last_pos;
}
else
{
a[t].last_pos = fa;
a[t].counter = 0;
}
}
else if(a[t].data == ')')
{
a[t].stk = a[fa].stk - 1;
if(a[t].stk == -1)
{
a[t].stk = 0;
a[t].counter = 0;
a[t].ans = a[fa].ans;
a[t].last_pos = t;
}
else
{
if(a[fa].data == '(')
{
a[t].last_pos = a[fa].last_pos;
a[t].counter = a[fa].counter + 1;
a[t].ans = a[fa].ans + a[t].counter;
}
else if(a[fa].data == ')')
{
a[t].last_pos = a[a[fa].last_pos].last_pos;
a[t].counter = a[a[fa].last_pos].counter + 1;
a[t].ans = a[fa].ans + a[t].counter;
}
}
}
ans ^= (a[t].ans * (LL)t);
for(int i = head[t]; i; i = b[i].next)
{
work(b[i].to);
}
}
int main()
{
scanf("%d", &n);
string str;
cin >> str;
for(int i = 1; i <= n; ++i)
{
a[i].data = str[i - 1];
}
for(int i = 2; i <= n; ++i)
{
scanf("%d", &a[i].father);
add(a[i].father, i);
}
work(1);
printf("%lld", ans);
return 0;
}