[CSP-S模拟测试]:射手座之日(dsu on tree)

题目传送门(内部题103)


输入格式

  第一行一个数$n$,表示结点的个数。
  第二行$n–1$个数,第$i$个数是$p[i+1]$。$p[i]$表示结点$i$的父亲是$p[i]$。数据保证$p[i]<i$。
  第三行$n$个数,$a[1],a[2],...,a[n]$,表示关卡表。数据保证这是一个排列。
  第四行$n$个数,$x[1],x[2],...,x[n]$,表示结点的权值。


输出格式

  输出一个数表示答案。即对于所有可能的回合,你们能获得的总收益是多少。


数据范围与提示

  对于$20\%$的数据,满足$nleqslant 100$。
  对于$40\%$的数据,满足$nleqslant 2,000$。
  对于$60\%$的数据,满足$nleqslant 50,000$。
  对于另外$20\%$的数据,排列$a[i]$是用如下的算法生成的:从一号点开始对树做$dfs$,到达一个节点的时候输出这个结点。
  对于$100\%$的数据,满足$1leqslant nleqslant 200,000,0leqslant x[i]leqslant 100,000,p[i]<i$,$a[i]$是一个排列。


题解

其实另外$20\%$的数据就是正解的一个引导。

显然对于这个部分分无非就是输出每一个点的子节点的$size$相互之间的乘积和乘上这个点的$x$值即可。

那么考虑正解。

用$dsu on tree$思想,每个节点继承儿子中最重的节点的信息,其他儿子暴力合并;用两个数组记录当前位置是否是一个极长区间的左(右)端点并记录另一个端点的位置,加入一个节点时尝试合并左右信息并统计合并后的方案数即可。

时间复杂度:$Theta(nlog n)$。

期望得分:$100$分。

实际得分:$100$分。


代码时刻

#include<bits/stdc++.h>
using namespace std;
struct rec{int nxt,to;}e[200000];
int head[200001],cnt;
int n;
int a[200001],val[200001],size[200001],son[200001],l[200001],r[200001],vis[200001],dfn[200001],now;
long long ans;
void add(int x,int y)
{
	e[++cnt].nxt=head[x];
	e[cnt].to=y;
	head[x]=cnt;
}
void dfs(int x)
{
	size[x]=1;
	for(int i=head[x];i;i=e[i].nxt)
	{
		dfs(e[i].to);
		size[x]+=size[e[i].to];
		if(size[son[x]]<size[e[i].to])son[x]=e[i].to;
	}
}
long long insert(int x)
{
	vis[a[x]]=now;
	if(vis[a[x]+1]!=now)l[a[x]+1]=r[a[x]+1]=0;
	if(vis[a[x]-1]!=now)l[a[x]-1]=r[a[x]-1]=0;
	long long L=l[a[x]-1],R=r[a[x]+1],len=l[a[x]-1]+r[a[x]+1]+1;
	l[a[x]]=L+1;r[a[x]]=R+1;l[a[x]+R]=r[a[x]-L]=len;
	return len*(len+1)/2-L*(L+1)/2-R*(R+1)/2;
}
long long ask(int x)
{
	long long res=insert(x);
	for(int i=head[x];i;i=e[i].nxt)
		res+=ask(e[i].to);
	return res;
}
int dfs(int x,int opt)
{
	now=dfn[x]=x;
	long long res=0,num=0,flag=0;
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to!=son[x])res-=dfs(e[i].to,0);
	if(son[x])
	{
		flag=dfs(son[x],1);
		num+=flag;
		now=dfn[x]=dfn[son[x]];
	}
	num+=insert(x);
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to!=son[x])num+=ask(e[i].to);
	ans+=(res+num-flag)*val[x];
	return num;
}
int main()
{
	scanf("%d",&n);
	for(int i=2;i<=n;i++){int x;scanf("%d",&x);add(x,i);}
	for(int i=1;i<=n;i++){int x;scanf("%d",&x);a[x]=i;}
	for(int i=1;i<=n;i++)scanf("%d",&val[i]);
	dfs(1);
	dfs(1,1);
	printf("%lld",ans);
	return 0;
}

rp++

原文地址:https://www.cnblogs.com/wzc521/p/11767788.html