小清新数据结构题

XII.小清新数据结构题

太 清 新 了

话说就我一个人看到这道题后兴冲冲的以为暴力LCT就能过然后发现LCT如果维护子树信息的话只有根节点处的信息是正确的吗(没错,就我一个)

闲话少说,正片开始。

法一:推一种式子,然后LCT/树剖维护

我们设\(val_i\)\(i\)节点的值,然后\(sum_i\)根节点为\(1\)\(i\)为根的子树的子树和。

则根为\(1\)时的答案即为\(\sum\limits_{i=1}^{n}(sum_i)^2\)。设其为\(res\)

我们看一下当位置\(x\)\(val\)增大\(\Delta\)后,\(res\)会如何变化(其中\(path_x\)意为\(1\rightarrow x\)的路径,而\(dis_x\)为这一路径上的节点数量):

\(\begin{aligned}\Delta_{res}&=\sum\limits_{i\in path_x}(sum_i+\Delta)^2-(sum_i)^2\\&=\sum\limits_{i\in path_x}(sum_i)^2+2\Delta sum_i+\Delta^2-(sum_i)^2\\&=\sum\limits_{i\in path_x}2\Delta sum_i+\Delta^2\\&=2\Delta\sum\limits_{i\in path_x}sum_i+dis_x\Delta^2\end{aligned}\)

\(\sum\limits_{i\in path_x}sum_i\)\(dis_x\)都可以很方便地使用LCT或树剖维护。这里我选用LCT,毕竟这题的LCT如果用无根LCT(即根固定为\(1\)的LCT)的话,直接access一下即可打包出来这条路径,非常方便。

\(\sum\limits_{i\in path_x}sum_i\)的变化,实际上仅仅是在\(val_x\)增加\(\Delta\)时,每个\(sum_i\)全都增加\(\Delta\)而已。打个tag就解决了。

维护完这些东西后,我们便可以在点权变化的时候同时维护\(res\)就可以拿到\(30\%\)

现在我们考虑根不是\(1\)了,它变到了\(x\)。则新的\(res\)(设为\(res'\))又会怎么变化呢?

我们有\(res'=\sum\limits_{i=1}^{n}(sum_i')^2\),其中\(sum'\)为新的\(sum\)值。

但是,对于大多数情况,都仍有\(sum_x'=sum_x\)——准确的说,除了\(x\)\(1\)路径上的点,其它的\(sum_x\)都没有发生变化。我们仍然设这条路径为\(path_x\),这里我们从\(1\)\(x\)将路径上的点依次编号为\(p_0,p_1,\dots,p_k\),并令\(p_0=1,p_k=x\)

如果我们画出图来观察一下,就会惊讶地发现,必然有

\[\large sum_{p_0}=sum_{p_0}'+sum_{p_1}=sum_{p_1}'+sum_{p_2}=\dots=sum_{p_{k-1}}'+sum_{p_k}=sum_{p_k}' \]

而这上面所有的东西,全都等于整棵树的权值和

因此我们有

\(\begin{aligned}\Delta_{res}&=\sum\limits_{i=1}^{n}(sum_i')^2-\sum\limits_{i=1}^{n}(sum_i)^2\\&=\sum\limits_{i=0}^{k}(sum_{p_i}')^2-\sum\limits_{i=0}^{k}(sum_{p_i})^2\\&=\sum\limits_{i=0}^{k}(sum_{p_i}')^2-\sum\limits_{i=0}^{k}(sum_{p_i})^2\\&=(sum_{p_k}')^2+\sum\limits_{i=0}^{k-1}(sum_{p_0}-sum_{p_{i+1}})^2-(sum_{p_0})^2-\sum\limits_{i=1}^{k}(sum_{p_i})^2\\&=\sum\limits_{i=1}^{k}(sum_{p_0}-sum_{p_i})^2-\sum\limits_{i=1}^{k}(sum_{p_i})^2\\&=\sum\limits_{i=1}^{k}(sum_{p_0})^2-2sum_{p_0}sum_{p_i}\\&=k(sum_{p_0})^2-2sum_{p_0}\sum\limits_{i=1}^{k}sum_{p_i}\\&=k(sum_1)^2-2sum_1\sum\limits_{i=1}^{k}sum_{p_i}\\&=(k+2)(sum_1)^2-2sum_1\sum\limits_{i=0}^{k}sum_{p_i}\end{aligned}\)

套用我们上面对\(path_x\)\(dis_x\)的定义(注意到这里则有\(dis_x=k+1\)),我们得到

\(\Delta_{res}=(dis_x+1)(sum_1)^2-2sum_1\sum\limits_{i\in path_x}sum_i\)

刚好是我们之前LCT维护的东西,因此直接搬过来用即可。

复杂度\(O(n\log n)\)

代码:

#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
#define lson t[x].ch[0]
#define rson t[x].ch[1]
int n,m,val[200100];
ll res;
struct LCT{
    int fa,ch[2],val,tag,sz;
    ll sum;
}t[200100];
inline int identify(int x){
    if(x==t[t[x].fa].ch[0])return 0;
    if(x==t[t[x].fa].ch[1])return 1;
    return -1;
}
inline void ADD(int x,int y){
    t[x].val+=y,t[x].tag+=y,t[x].sum+=1ll*t[x].sz*y;
}
inline void pushdown(int x){
    if(lson)ADD(lson,t[x].tag);
    if(rson)ADD(rson,t[x].tag);
    t[x].tag=0;
}
inline void pushup(int x){
    t[x].sum=t[x].val,t[x].sz=1;
    if(lson)t[x].sum+=t[lson].sum,t[x].sz+=t[lson].sz;
    if(rson)t[x].sum+=t[rson].sum,t[x].sz+=t[rson].sz;
}
inline void rotate(int x){
    register int y=t[x].fa;
    register int z=t[y].fa;
    register int dirx=identify(x);
    register int diry=identify(y);
    register int b=t[x].ch[!dirx];
    if(diry!=-1)t[z].ch[diry]=x;t[x].fa=z;
    if(b)t[b].fa=y;t[y].ch[dirx]=b;
    t[y].fa=x,t[x].ch[!dirx]=y;
    pushup(y),pushup(x);
}
inline void pushall(int x){
    if(identify(x)!=-1)pushall(t[x].fa);
    pushdown(x);
}
inline void splay(int x){
    pushall(x);
    while(identify(x)!=-1){
        register int fa=t[x].fa;
        if(identify(fa)==-1)rotate(x);
        else if(identify(x)==identify(fa))rotate(fa),rotate(x);
        else rotate(x),rotate(x);
    }
}
inline void access(int x){
    for(register int y=0;x;x=t[y=x].fa)splay(x),rson=y,pushup(x);
}
inline void makeroot(int x){
    access(x),splay(x);
}
vector<int>v[200100];
void dfs(int x,int fa){
	t[x].val=val[x];
	for(auto y:v[x])if(y!=fa)t[y].fa=x,dfs(y,x),t[x].val+=t[y].val;
	res+=1ll*t[x].val*t[x].val;
	pushup(x);
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),v[x].push_back(y),v[y].push_back(x);
	for(int i=1;i<=n;i++)scanf("%d",&val[i]);
	dfs(1,0);
	for(int i=1,x,y,z;i<=m;i++){
		scanf("%d%d",&x,&y),makeroot(y);
		if(x==1)scanf("%d",&z),z-=val[y],res+=1ll*z*z*t[y].sz+2ll*z*t[y].sum,ADD(y,z),val[y]+=z;
		else pushall(1),printf("%lld\n",res+1ll*t[1].val*(1ll*(t[y].sz+1)*t[1].val-2ll*t[y].sum));
	}
	return 0;
}

法二:推另一种式子,然后点分树维护

我们仍然令\(val_x\)为单点权值,但这里的\(sum_x\)对于一次询问,以被询问节点为树根时,子树权值和。再令\(all=\sum val_x\)

则我们要求的是\(\sum\limits_{i=1}^n(sum_i)^2\)

看着不太爽吧?毕竟动态点分治更侧重于路径的维护(这一点跟LCT类似,但是动态点分治比起LCT还要更“路径”一点——它几乎维护不了子树信息)。

我们尝试将其变化成\(all\sum\limits_{i=1}^nsum_i-\sum\limits_{i=1}^nsum_i(all-sum_i)\)

后一半,我们发现乘起来的两部分,可以被抽象为由一条边连接着的两个子树的权值和乘一起——这恰恰证明了后面的东西与根无关,因为这个\(\sum\)它枚举了每一条边。

我们进一步可以把它拆成两个集合\(A\)\(B\)表示两半子树,则它实际上等价于\(\sum(\sum\limits_{i\in A}val_i)(\sum\limits_{j\in B}val_j)=\sum\sum\limits_{i\in A}\sum\limits_{j\in B}val_ival_j\)

我们发现,对于每一对\(i,j\),它们会在两点间路径上的每一条边处被计算一次。

因此上面实际上也可以被转成\(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}val_ival_j\operatorname{dis}(i,j)\)。这里就只与路径信息有关了。并且,因为这个值与根无关,我们实际上还可以做的更多。

现在话归前一半。\(all\)很容易维护,关键是求出\(\sum\limits_{i=1}^{n}sum_i\)

它可以等价于\(\sum\limits_{i=1}^{n}val_i\Big(\operatorname{dis}(i,root)+1\Big)\),因为每个点的点权会贡献给它所有的祖先——这一数量等于\(\operatorname{dis}(i,root)+1\)

我们将其拆开,便得到了\(\sum\limits_{i=1}^{n}val_i\operatorname{dis}(i,root)+\sum\limits_{i=1}^nval_x=all+\sum\limits_{i=1}^{n}val_i\operatorname{dis}(i,root)\)

我们发现后面这一半东西就可以轻松用点分治维护了。我们设\(\operatorname{calc}(x)=\sum\limits_{i=1}^{n}val_i\operatorname{dis}(i,x)\)

则最终该式即被转换成\(all\operatorname{calc}(x)-\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}val_ival_j\operatorname{dis}(i,j)\)

我们考虑设后面一大坨为\(ALL\)

考虑当\(val_i\)增大\(\Delta\)后,\(ALL\)会发生什么变化:

\(\begin{aligned}\Delta_{ALL}&=\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}val_ival_j\operatorname{dis}(i,j)-\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}(val_i+\Delta)val_j\operatorname{dis}(i,j)\\&=\Delta\sum\limits_{j=1}^{n}val_j\operatorname{dis}(i,j)\\&=\Delta\operatorname{calc}(i)\end{aligned}\)

则只需要在修改时顺手维护掉\(ALL\)即可。

则最终答案即为\(all\operatorname{calc}(x)-ALL\)

至此本题解决。

明显该算法复杂度为\(O(n\log n)\)——假如你用RMQ求\(\operatorname{dis}\)的话。但是其常数远大于LCT——大约是其\(3\)倍。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,val[200100],fa[200100],dep[200100],in[200100],tot,mn[400100][20],LG[400100];
namespace Tree{
    vector<int>v[200100];
    int sz[200100],SZ,msz[200100],ROOT;
    bool vis[200100];
    void getsz(int x,int fa){
        sz[x]=1;
        for(auto y:v[x])if(!vis[y]&&y!=fa)getsz(y,x),sz[x]+=sz[y];
    }
    void getroot(int x,int fa){
        sz[x]=1,msz[x]=0;
        for(auto y:v[x])if(!vis[y]&&y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
        msz[x]=max(msz[x],SZ-sz[x]);
        if(msz[x]<msz[ROOT])ROOT=x;
    }
    void solve(int x){
        getsz(x,0); 
        vis[x]=true;
        for(auto y:v[x]){
            if(vis[y])continue;
            ROOT=0,SZ=sz[y],getroot(y,0),fa[ROOT]=x,solve(ROOT);
		}
	}
    void getural(int x,int fa){
		mn[++tot][0]=x,in[x]=tot;
		for(auto y:v[x])if(y!=fa)dep[y]=dep[x]+1,getural(y,x),mn[++tot][0]=x;
    }
}
int MIN(int i,int j){
    return dep[i]<dep[j]?i:j;
}
int LCA(int i,int j){
	if(i>j)swap(i,j);
	int k=LG[j-i+1];
	return MIN(mn[i][k],mn[j-(1<<k)+1][k]);
}
int DIS(int i,int j){
	return dep[i]+dep[j]-dep[LCA(in[i],in[j])]*2;
}
namespace cdt{
	ll sf[200100],pa[200100],all,sz[200100],ALL;
	ll ask(int x){
		ll res=0;
		for(int u=x;u;u=fa[u]){
			res+=sf[u];
			res+=1ll*DIS(u,x)*sz[u];
			if(fa[u])res-=pa[u],res-=1ll*DIS(fa[u],x)*sz[u];
		}
		return res;
	}
	void change(int x,int delta){
		ALL+=1ll*ask(x)*delta;
		for(int u=x;u;u=fa[u]){
			sz[u]+=delta;
			sf[u]+=DIS(u,x)*delta;
			if(fa[u])pa[u]+=DIS(fa[u],x)*delta;
		}
		val[x]+=delta,all+=delta;
	}
	ll solve(int x){
		return all*(all+ask(x))-ALL;
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),Tree::v[x].push_back(y),Tree::v[y].push_back(x);
	Tree::msz[0]=0x3f3f3f3f,Tree::SZ=n,Tree::getroot(1,0),Tree::solve(Tree::ROOT);
	Tree::getural(1,0);
	for(int i=2;i<=tot;i++)LG[i]=LG[i>>1]+1;
	for(int j=1;j<=LG[tot];j++)for(int i=1;i+(1<<j)-1<=tot;i++)mn[i][j]=MIN(mn[i][j-1],mn[i+(1<<(j-1))][j-1]);
	for(int i=1,x;i<=n;i++)scanf("%d",&x),cdt::change(i,x);
    for(int i=1,x,y,z;i<=m;i++){
   		scanf("%d%d",&x,&y);
		if(x==1)scanf("%d",&z),cdt::change(y,z-val[y]);
		else printf("%lld\n",cdt::solve(y));
	}
	return 0;
} 

原文地址:https://www.cnblogs.com/Troverld/p/14605843.html