题解:[GXOI/GZOI2019]旧词

这个题目其实早就做了,只是突然发现还没发,那就凑一下GZOI

题意:给定$x,y$求

$$sum_{ileq x}dep(lca(i,y))^k$$

首先我们先来看这个题目的简化版 

https://www.luogu.org/problem/P4211

求  $$sum_{ileq x}dep(lca(i,y))$$

我们来看$dep$的实际意义——从 i 点到根有多少个点(包括 i )。

我们从整体上考虑,发现对于一个询问:所有的 $lca$ 都在 $y$ 到根的路径上。从而有一些点,它们对很多的 $lca$ 的深度都有贡献,而这个贡献等于在这个点下面的 $lca$ 的个数,所以我们可以把每个 $lca$ 到根的路径上的每个点的权值都加一。然后从 $y$ 向上走到根,沿路统计的权值就是答案了。

这里,我们可以把所有的询问离线下来,按照 $x$  排序,然后每个节点就向上跳把所有的上面的点染上颜色,然后查询的时候只需要向上找,看有多少染上颜色的节点,并且计算贡献,这里我们只需要用树链剖分维护一下就行了

然后是我非常丑的代码

#include <bits/stdc++.h>
using namespace std;

#define re register
#define ll long long
#define gc getchar()
inline ll read()
{
 	re ll x(0),f(1);re char c(gc);
    while(c>'9'||c<'0')f=c=='-'?-1:1,c=gc;
    while(c>='0'&&c<='9')x=x*10+c-48,c=gc;
    return f*x;
}

const ll N=50500,mod=201314;
ll n,Q,k,h[N],cnt,qs;
struct edge{ll next,to;}e[N];

void add(ll u,ll v){e[++cnt]=(edge){h[u],v},h[u]=cnt;}
#define QXX(u) for(ll i=h[u],v;v=e[i].to,i;i=e[i].next)

ll dep[N],fa[N],son[N],siz[N],top[N],rev[N],seq[N],tot;

void dfs(ll u)
{
	siz[u]=1;
	QXX(u)
	{
		dfs(v);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void Dfs(ll u)
{
	if(son[u])
	{
		ll v=son[u];
		seq[v]=++tot;
		rev[tot]=v;
		top[v]=top[u];
		Dfs(v);
	}
	QXX(u)
	{
		if(v==son[u]) continue;
		seq[v]=++tot;
		rev[tot]=v;
		top[v]=v;
		Dfs(v);
	}
}

struct node{ll id,x,z,ans;bool w;}q[N<<1];
bool operator < (node a,node b){return a.x<b.x;}

#define ls id<<1
#define rs id<<1|1
#define mid ((l+r)>>1)

ll sum[N<<2],tag[N<<2];
void pushup(ll id){sum[id]=(sum[ls]+sum[rs])%mod;}

void pushdown(ll id,ll l,ll r)
{
	if(tag[id])
	{
		tag[ls]+=tag[id];
		tag[rs]+=tag[id];
		sum[ls]+=tag[id]*(mid-l+1);
		sum[rs]+=tag[id]*(r-mid);
		sum[ls]%=mod;
		sum[rs]%=mod;
		tag[id]=0;
	}
}

void change(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R)
	{
		tag[id]++;
		sum[id]+=(r-l+1);
		sum[id]%=mod;
		return;
	}
	pushdown(id,l,r);
	if(mid>=L) change(ls,l,mid,L,R);
	if(mid<R) change(rs,mid+1,r,L,R);
	pushup(id);
}

void work(ll x)
{
	while(1)
	{
		if(top[x]!=x)
			change(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			change(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		if(x==0) return;
	}
}

ll query(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R) return sum[id]%mod;
	pushdown(id,l,r);
	ll ans=0;
	if(mid>=L) ans+=query(ls,l,mid,L,R);
	if(mid<R) ans+=query(rs,mid+1,r,L,R);
	return ans%mod;
}

ll ask(ll x)
{
	ll ans=0;
	while(1)
	{
		if(top[x]!=x)
			ans+=query(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			ans+=query(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		ans%=mod;
		if(x==0) return ans;
	}
}

ll ans[N];

int main()
{
	n=read(),Q=read();
	for(ll i=2;i<=n;++i)
	{
		ll x=read()+1;
		add(x,i);
		fa[i]=x;
	}
	top[1]=1;seq[1]=++tot;
	dfs(1);Dfs(1);
	for(ll i=1;i<=Q;++i)
	{
		ll l=read(),r=read()+1,z=read()+1;
		q[++qs]=(node){i,l,z,0,0};
		q[++qs]=(node){i,r,z,0,1};
	}
	sort(q+1,qs+1+q);
	ll t=0;
	while(q[t+1].x<1) ++t;
	for(ll i=1;i<=n;++i)
	{
		work(i);
		while(q[t+1].x<=i&&t<qs)
			++t,q[t].ans=ask(q[t].z);
		if(t==qs) break;
	}
	for(ll i=1;i<=qs;++i)
	{
		if(q[i].w==0) ans[q[i].id]-=q[i].ans;
		else ans[q[i].id]+=q[i].ans;
	}
	for(ll i=1;i<=Q;++i)
		cout<<(ans[i]+mod)%mod<<endl;
	return 0;
}  

然后我们回到本题,这里是多了一个 $k$ 次方

首先我们来看前面的每次$+1$是哪里来的 $dep[i]->dep[i+1]$所以这里实际上就是在做差分,那么我们把指数换成 $k$

$dep[i]^k->(dep[i]+1)^k$

那么,我们就预处理出来每一个 $dep^k$ 然后对于每个节点就相当于每次会增加 $dep[x]^k-(dep[x]-1)^$ 的贡献

然后我们就可以转化为,对于一个序列,每个点的值是 $a*b$ 其中 $b$ 是定值,但是每个节点不一样,每次操作就是做区间修改给 $a$ 加上1和区间查询

然后我们维护线段树的时候再多维护一个 $sum_b$ 就可以了

#include <bits/stdc++.h>
using namespace std;

#define re register
#define ll long long
#define gc getchar()
inline ll read()
{
 	re ll x(0),f(1);re char c(gc);
    while(c>'9'||c<'0')f=c=='-'?-1:1,c=gc;
    while(c>='0'&&c<='9')x=x*10+c-48,c=gc;
    return f*x;
}

const ll N=50050,mod=998244353;

ll n,Q,k,h[N],cnt,qs;
struct edge{ll next,to;}e[N];

void add(ll u,ll v){e[++cnt]=(edge){h[u],v},h[u]=cnt;}
#define QXX(u) for(ll i=h[u],v;v=e[i].to,i;i=e[i].next)

ll dep[N],fa[N],son[N],siz[N],top[N],seq[N],rev[N],tot;

void dfs(ll u)
{
	siz[u]=1;
	QXX(u)
	{
		dep[v]=dep[u]+1;
		dfs(v);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void Dfs(ll u,ll to)
{
	seq[u]=++tot,rev[tot]=u;
	top[u]=to;
	if(son[u])
		Dfs(son[u],top[u]);
	QXX(u)
	{
		if(v==son[u]) continue;
		Dfs(v,v);
	}
}

ll qpow(ll x,ll b)
{
	ll a=1;
	while(b)
	{
		if(b&1) a=(x*a)%mod;
		x=(x*x)%mod,b>>=1;
	}
	return a;
}

#define ls id<<1
#define rs id<<1|1
#define mid ((l+r)>>1)

ll su[N<<2],sum[N<<2],tag[N<<2],po[N];

void pushup(ll id)
{
	sum[id]=sum[ls]+sum[rs];
	su[id]=su[ls]+su[rs];
}
void pushdown(ll id,ll l,ll r)
{
	if(tag[id])
	{
		tag[ls]+=tag[id];
		tag[rs]+=tag[id];
		sum[ls]=(sum[ls]+tag[id]*su[ls])%mod;
		sum[rs]=(sum[rs]+tag[id]*su[rs])%mod;
		tag[id]=0;
	}
}
void built(ll id,ll l,ll r)
{
	if(l==r)
	{
		su[id]=(po[dep[rev[l]]]+mod-po[dep[rev[l]]-1])%mod;
		return;
	}
	built(ls,l,mid);
	built(rs,mid+1,r);
	pushup(id);
}
void change(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R)
	{
		tag[id]++;
		sum[id]+=su[id];
		sum[id]%=mod;
		return;
	}
	pushdown(id,l,r);
	if(mid>=L) change(ls,l,mid,L,R);
	if(mid<R) change(rs,mid+1,r,L,R);
	pushup(id);
}
ll query(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R) return sum[id]%mod;
	pushdown(id,l,r);
	ll ans=0;
	if(mid>=L) ans+=query(ls,l,mid,L,R);
	if(mid<R) ans+=query(rs,mid+1,r,L,R);
	return ans%mod;
}
void work(ll x)
{
	while(1)
	{
		if(top[x]!=x)
			change(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			change(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		if(x==0) return;
	}
}
ll ask(ll x)
{
	ll ans=0;
	while(1)
	{
		if(top[x]!=x)
			ans+=query(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			ans+=query(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		ans%=mod;
		if(x==0) return ans;
	}
}
struct node{ll id,x,y,ans;}q[N];
bool cmpx(node a,node b){return a.x<b.x;}
bool cmpi(node a,node b){return a.id<b.id;}

int main()
{
	n=read(),Q=read(),k=read();
	for(ll i=1;i<=n;++i)
		po[i]=qpow(i,k);
	for(ll i=2;i<=n;++i)
	{
		fa[i]=read();
		add(fa[i],i);
	}
	dep[1]=1;
	dfs(1),Dfs(1,1);
	built(1,1,tot);
	for(ll i=1;i<=Q;++i)
	{
		ll x=read(),y=read();
		q[i]=(node){i,x,y,0};
	}
	sort(q+1,q+1+Q,cmpx);
	ll t=0;
	while(q[t+1].x<1) ++t;
	for(ll i=1;i<=n;++i)
	{
		work(i);
		while(q[t+1].x<=i&&t<Q)
			++t,q[t].ans=ask(q[t].y);
		if(t==Q) break;
	}
	sort(q+1,q+1+Q,cmpi);
	for(ll i=1;i<=Q;++i)
		cout<<q[i].ans<<endl;
	return 0;
}

  

原文地址:https://www.cnblogs.com/zijinjun/p/11256866.html