【考试题】爬树

爬树

题解极为不负责任,啥也没看懂,看了好久好久的 (std) 才懂

首先发现如果有一段的 (-1) 的话可以用组合数推出来方案

令长度为 (len) ,给定的左右上界是 ([l,r])

那么方案就是至多选择 (r-l+1) 个位置将权值加一

[sum_{i=0} ^{r-l+1}inom {len} i ]

这东西用一些恒等变形可以转化一下

所以我们直接维护一段上的 (-1) 数量和上下界

然后对于每次的 ([a,b]) 的限制直接考虑对于最左侧和最右侧的信息即可

如果两边不行那么就没有方案

然后这题目主算法是树剖+线段树维护信息:(-1) 的段数,两侧的上下界,

当然,想的东西不算很多

主要是难写,巨难写

几个注意的点:

(1.) 树剖对于 (x o lca)(lca o y) 的做法是不一样的

这里需要分别写两个函数

(2.) 不能把结构体和 (0) 直接 (push\_up) ,得记录是不是加上过了

(3.) 写代码的时候要全神贯注,不能手残啥的

如果像我这种手残脑子还不在的,比如:

(fac[i]=mul(fac[i-1],i-1))

就铁退役了

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define reg register
namespace yspm{
	inline int read()
	{
		int res=0,f=1; char k;
		while(!isdigit(k=getchar())) if(k=='-') f=-1;
		while(isdigit(k)) res=res*10+k-'0',k=getchar();
		return res*f;
	}
	const int N=1e5+10,inf=0x3f3f3f3f,mod=1e9+7;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int del(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return x*y-x*y/mod*mod;}
	int fac[N*10],inv[N*10];
	inline int C(int n,int m){return (m>=0&&n>=m)?mul(fac[n],mul(inv[m],inv[n-m])):0;}
	inline int ksm(int x,int y)
	{
		int res=1; for(;y;y>>=1,x=mul(x,x)) if(y&1) res=mul(res,x);
		return res;
	} 
	struct node{
		int to,nxt;
	}e[N<<1];
	int head[N],sz[N],cnt,tim,dfn[N],ord[N],son[N],dep[N],fa[N],top[N],v[N],n,m;
	inline void adde(int u,int v)
	{
		e[++cnt].to=v; e[cnt].nxt=head[u];
		return head[u]=cnt,void();
	}
	inline void dfs1(int x,int fat)
	{
		fa[x]=fat; sz[x]=1; dep[x]=dep[fa[x]]+1; 
		for(reg int i=head[x];i;i=e[i].nxt)
		{
			int t=e[i].to; if(t==fat) continue;
			dfs1(t,x); sz[x]+=sz[t]; 
			if(sz[t]>sz[son[x]]) son[x]=t;
		}
		return ;
	}
	inline void dfs2(int x,int topf)
	{
		top[x]=topf; dfn[x]=++tim; ord[tim]=x; 
		if(!son[x]) return ; dfs2(son[x],topf);
		for(int i=head[x];i;i=e[i].nxt)
		{
			int t=e[i].to; if(t==son[x]||t==fa[x]) continue;
			dfs2(t,t);
		} return ;
	} 
	struct point{
		int len1,len2,len,llen,rlen,l1,l2,r1,r2,l,r,sum,cnt;
	}f1[N<<2],f2[N<<2];
	inline point push_up(point a,point b)
	{
		point ans; 
		ans.len=a.len+b.len;//区间 
		ans.cnt=a.cnt+b.cnt-(a.rlen>0&&b.llen>0);//-1 的段数,如果两边是-1,那就合并 
		ans.llen=a.llen+(a.len==a.llen?b.llen:0);//左侧 -1 的长度 
		ans.rlen=b.rlen+(b.len==b.rlen?a.rlen:0);//右侧 -1 的长度 
		if(a.cnt) ans.len1=a.len1+(a.cnt==1&&a.rlen?b.llen:0); else ans.len1=b.len1;
		if(b.cnt) ans.len2=b.len2+(b.cnt==1&&b.llen?a.rlen:0); else ans.len2=a.len2;
		//左边段和右边段的长度 
		ans.l=a.len==a.llen?b.l:a.l;
		ans.r=b.rlen==b.len?a.r:b.r;
		//上下界 
		if(a.cnt) 
		{
			ans.l1=a.l1;
			if(a.cnt==1&&a.rlen) ans.r1=(~b.l)?b.l:inf; 
			else ans.r1=a.r1;
		}
		else 
		{
			ans.l1=b.llen?a.r:b.l1;
			ans.r1=b.r1;
		}
		if(b.cnt)
		{
			ans.r2=b.r2;
			if(b.cnt==1&&b.llen) ans.l2=(~a.r)?a.r:-inf;
			else ans.l2=b.l2; 
		}
		else 
		{
			ans.l2=a.l2;
			ans.r2=a.rlen?b.l:a.r2;
		}//两段上下界 
		if(a.llen==a.len||b.llen==b.len) ans.sum=mul(a.sum,b.sum);
		else
		{
			int t1=b.l-a.r,t2=a.rlen+b.llen;
			ans.sum=mul(mul(a.sum,b.sum),C(t1+t2,t2));
		}//方案统计,这里就是中间的段合并起来 
		return ans;
	}
	inline void push_up(int p)
	{
		f1[p]=push_up(f1[p<<1],f1[p<<1|1]); 
		f2[p]=push_up(f2[p<<1|1],f2[p<<1]); 
		return ;
	}
	inline void build(int p,int l,int r)
	{
		if(l==r) 
		{
			f1[p].l=f1[p].r=v[ord[l]];
			f1[p].l1=f1[p].l2=-inf; f1[p].r1=f1[p].r2=inf;
			f1[p].len=1; 
			f1[p].len1=f1[p].len2=f1[p].cnt=f1[p].llen=f1[p].rlen=(v[ord[l]]==-1)?1:0;
			f1[p].sum=1; 
			return f2[p]=f1[p],void();	
		}int mid=(l+r)>>1;
		build(p<<1,l,mid); build(p<<1|1,mid+1,r);
		return push_up(p);
	}
	inline void update(int p,int l,int r,int pos,int val)
	{
		if(l==r)
		{
			f1[p].l=f1[p].r=val; 
			f1[p].l1=f1[p].l2=-inf; f1[p].r1=f1[p].r2=inf;
			f1[p].len=1; f1[p].len1=f1[p].len2=f1[p].cnt=f1[p].llen=f1[p].rlen=(val==-1)?1:0;
			f1[p].sum=1; f2[p]=f1[p]; 
			return ;
		}int mid=(l+r)>>1;
		if(pos<=mid) update(p<<1,l,mid,pos,val);
		else update(p<<1|1,mid+1,r,pos,val); 
		return push_up(p);  
	} 
	inline point ask(int p,int l,int r,int st,int ed,bool fl)
	{
		if(st<=l&&r<=ed) return fl?f2[p]:f1[p];
		int mid=(l+r)>>1;
		if(st>mid) return ask(p<<1|1,mid+1,r,st,ed,fl);
		if(ed<=mid) return ask(p<<1,l,mid,st,ed,fl);
		if(fl) return push_up(ask(p<<1|1,mid+1,r,st,ed,fl),ask(p<<1,l,mid,st,ed,fl));
		else return push_up(ask(p<<1,l,mid,st,ed,fl),ask(p<<1|1,mid+1,r,st,ed,fl));
	}
	inline point query(int x,int y)
	{
		point ans,tmp;
		bool fl=0,fr=0;
		while(top[x]!=top[y])
		{
			if(dep[top[x]]>dep[top[y]]) 
			{
				if(!fl) fl=1,ans=ask(1,1,n,dfn[top[x]],dfn[x],1);
				else ans=push_up(ans,ask(1,1,n,dfn[top[x]],dfn[x],1));
				x=fa[top[x]];
			}
			else
			{
				if(!fr) fr=1,tmp=ask(1,1,n,dfn[top[y]],dfn[y],0);
				else tmp=push_up(ask(1,1,n,dfn[top[y]],dfn[y],0),tmp);
				y=fa[top[y]];
			}
		} 
		if(dep[x]<dep[y]) 
		{
			if(!fl) ans=ask(1,1,n,dfn[x],dfn[y],0);
			else ans=push_up(ans,ask(1,1,n,dfn[x],dfn[y],0));
		}
		else
		{
			if(!fl) ans=ask(1,1,n,dfn[y],dfn[x],1);
			else ans=push_up(ans,ask(1,1,n,dfn[y],dfn[x],1));
		}if(fr) ans=push_up(ans,tmp);
		return ans;
	}
	inline int calc(int x,int y,int a,int b)
	{
		point s=query(x,y);
		int ans=s.sum;
		if(s.cnt)
		{
			int r1,r2;
			if(s.cnt==1) 
			{ 
				if(!s.llen&&!s.rlen) ans=mul(ans,ksm(C(s.r1-s.l1+s.len1,s.len1),mod-2));
				r1=min(s.r1,b)-max(s.l1,a); r2=s.len1;
				ans=mul(ans,C(r1+r2,r2));
			}
			else
			{
				if(!s.llen) ans=mul(ans,ksm(C(s.r1-s.l1+s.len1,s.len1),mod-2));
				if(!s.rlen) ans=mul(ans,ksm(C(s.len2+s.r2-s.l2,s.len2),mod-2));
				r1=min(s.r1,b)-max(s.l1,a); r2=s.len1;
				ans=mul(ans,C(r1+r2,r2));
				r1=min(s.r2,b)-max(s.l2,a); r2=s.len2;
				ans=mul(ans,C(r1+r2,r2));
			}
		} return ans;
	}
	signed main()
	{
		freopen("tree.in","r",stdin);
		freopen("tree.out","w",stdout); 
		fac[1]=inv[1]=fac[0]=inv[0]=1;
		for(reg int i=2;i<N*10;++i) fac[i]=mul(fac[i-1],i),inv[i]=del(mod,mul(inv[mod%i],mod/i));
		for(reg int i=1;i<N*10;++i) inv[i]=mul(inv[i],inv[i-1]); 
		n=read(); m=read(); 
		for(reg int i=1;i<=n;++i) v[i]=read();
		for(reg int i=1;i<n;++i) 
		{
			int x=read(),y=read();
			adde(x,y); adde(y,x); 
		} dfs1(1,0); dfs2(1,1); build(1,1,n); 
		while(m--)
		{
			if(read()-1)
			{
				int x=read(),y=read(),a=read(),b=read();
				printf("%lld
",calc(x,y,a,b)); 
			}
			else
			{
				int pos=read(),val=read();
				update(1,1,n,dfn[pos],val);
			}
		} 
		return 0;
	}
}
signed main(){return yspm::main();}

本来是应该放到 (October) 泛做的

但是印象太为深刻(题解不负责而且巨难写)就单拎出来了

原文地址:https://www.cnblogs.com/yspm/p/13758714.html