「NOI2020」命运

我竟然一直写了个假的线段树合并!

考场上连(O(n^2))暴力都不会。。我是SB
感觉(O(2^m mlog n))的暴力容斥的部分分就是用来迷惑人的。。
(dp_{u,j})表示从(u)开始往上(j)条边全是(0)且第(j+1)条边为(1)的方案数。
(其实定义成到深度为(j)的边全是(0)更好,我不知道我为何要自找麻烦。。就按这个接着说吧)
每个点有一个(lim_u)表示(u)往上最多允许有几条连续的(0)边。每次新增条件的时候就更新一下它。

[ dp_{u,j}=left{ egin{aligned} 0 & & j > lim_u\ prod_{v in son_u} (dp_{v,j+1}+dp_{v,0}) & & 0 leq j leq lim_u end{aligned} ight. ]

然后如果(O(n^2))实现的比较好可以获得(72)分,代码如下:

#include<bits/stdc++.h>
using namespace std;
#define fp(i,l,r) for(register int (i)=(l);(i)<=(r);++(i))
#define fd(i,l,r) for(register int (i)=(l);(i)>=(r);--(i))
#define fe(i,u) for(register int (i)=front[(u)];(i);(i)=e[(i)].next)
#define mem(a) memset((a),0,sizeof (a))
#define O(x) cerr<<#x<<':'<<x<<endl
#define int long long
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x*f;
}
void wr(int x){
	if(x<0)putchar('-'),x=-x;
	if(x>=10)wr(x/10);
	putchar('0'+x%10);
}
const int MAXN=5e5+10,mod=998244353;
inline void tmod(int &x){x%=mod;}
struct tEdge{
	int v,next;
}e[MAXN<<1];
int front[MAXN],tcnt,dep[MAXN],n,m,lim[MAXN];
vector<int>dp[MAXN];
inline void adde(int u,int v){
	e[++tcnt]={v,front[u]};front[u]=tcnt;
}
void dfs1(int u,int f){
	dep[u]=dep[f]+1;lim[u]=dep[u]-1;
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dfs1(v,u);
	}
}
void dfs(int u,int f){
	lim[u]=min(lim[u],lim[f]+1);dp[u].resize(lim[u]+1);fp(i,0,lim[u])dp[u][i]=1;
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dfs(v,u);
		fd(j,lim[u],lim[v])tmod(dp[u][j]*=dp[v][0]);
		fd(j,min(lim[u],lim[v]-1),0)tmod(dp[u][j]*=dp[v][j+1]+dp[v][0]);
	}
}
main(){
	freopen("destiny.in","r",stdin);freopen("destiny.out","w",stdout);
	n=read();
	fp(i,1,n-1){
		int x=read(),y=read();adde(x,y);adde(y,x);
	}
	dfs1(1,0);m=read();
	fp(i,1,m){
		int x=read(),y=read();lim[y]=min(lim[y],dep[y]-dep[x]-1);
	}
	dfs(1,0);
	printf("%lld
",dp[1][0]);
	return 0;
}

然后可以线段树合并来优化了,就是每个节点我们先区间赋值为(1)然后再与儿子节点线段树合并。
然后我写了一份(48)分的代码。。

#include<bits/stdc++.h>
using namespace std;
#define fp(i,l,r) for(register int (i)=(l);(i)<=(r);++(i))
#define fd(i,l,r) for(register int (i)=(l);(i)>=(r);--(i))
#define fe(i,u) for(register int (i)=front[(u)];(i);(i)=e[(i)].next)
#define mem(a) memset((a),0,sizeof (a))
#define O(x) cerr<<#x<<':'<<x<<endl
#define int long long
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x*f;
}
void wr(int x){
	if(x<0)putchar('-'),x=-x;
	if(x>=10)wr(x/10);
	putchar('0'+x%10);
}
const int MAXN=5e5+10,mod=998244353;
inline void tmod(int &x){x%=mod;}
struct tEdge{
	int v,next;
}e[MAXN<<1];
int front[MAXN],tcnt,dep[MAXN],n,m,lim[MAXN],val[MAXN*80],nn,D;
signed ls[MAXN*80],rs[MAXN*80],cnt,rt[MAXN];
inline void adde(int u,int v){
	e[++tcnt]={v,front[u]};front[u]=tcnt;
}
void dfs1(int u,int f){
	lim[u]=dep[u];
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dep[v]=dep[u]+1;dfs1(v,u);
	}
}
inline void pushdown(signed o){
	if(!val[o])return;
	if(!ls[o])ls[o]=++cnt;if(!rs[o])rs[o]=++cnt;
	val[ls[o]]+=val[o];val[rs[o]]+=val[o];val[o]=0;
}
void add(signed &o,int l,int r,int ql,int qr,int v){
	if(!o)o=++cnt;if(ql<=l&&qr>=r){tmod(val[o]+=v);return;}
	int mid=l+r>>1;pushdown(o);
	if(ql<=mid)add(ls[o],l,mid,ql,qr,v);
	if(qr>mid)add(rs[o],mid+1,r,ql,qr,v);
}
void merge(signed &o1,signed o2,int l,int r){
	if(!o1)return;if(!o2){o1=0;return;}
	if(l==r){tmod(val[o1]*=val[o2]);return;}
	int mid=l+r>>1;pushdown(o1);pushdown(o2);
	merge(ls[o1],ls[o2],l,mid);merge(rs[o1],rs[o2],mid+1,r);
}
int ask(signed o,int l,int r,int p){
	if(l==r)return val[o];
	int mid=l+r>>1;pushdown(o);
	if(p<=mid)return ask(ls[o],l,mid,p);
	else return ask(rs[o],mid+1,r,p);
}
void dfs(int u,int f){
	int d=n-dep[u],l=lim[u];
	add(rt[u],0,n,d,d+l,1);
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dfs(v,u);add(rt[v],0,n,d,d+l,ask(rt[v],0,n,d-1));
		merge(rt[u],rt[v],0,n);
	}
}
main(){
	freopen("destiny.in","r",stdin);freopen("destiny.out","w",stdout);
	n=read();nn=n*2;
	fp(i,1,n-1){
		int x=read(),y=read();adde(x,y);adde(y,x);
	}
	dfs1(1,0);m=read();
	fp(i,1,m){
		int x=read(),y=read();lim[y]=min(lim[y],dep[y]-dep[x]-1);
	}
	dfs(1,0);
	printf("%lld
",ask(rt[1],0,n,n));
	return 0;
}

这是因为我的merge函数会把我区间覆盖后得到的节点不断下放,时间复杂度退化为(O(n^2))。。
正确写法应该是每次看它的左右儿子是否都为空,若如此则不再递归。当然这样的话我们还得让线段树同时支持区间加法和乘法。
满分代码如下(不过常数很大。。)

#include<bits/stdc++.h>
using namespace std;
#define fp(i,l,r) for(register int (i)=(l);(i)<=(r);++(i))
#define fd(i,l,r) for(register int (i)=(l);(i)>=(r);--(i))
#define fe(i,u) for(register int (i)=front[(u)];(i);(i)=e[(i)].next)
#define mem(a) memset((a),0,sizeof (a))
#define O(x) cerr<<#x<<':'<<x<<endl
#define ll long long
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x*f;
}
void wr(int x){
	if(x<0)putchar('-'),x=-x;
	if(x>=10)wr(x/10);
	putchar('0'+x%10);
}
const int MAXN=5e5+10,mod=998244353;
inline void rmod(int &x){x+=x>>31&mod;}
struct tEdge{
	int v,next;
}e[MAXN<<1];
int front[MAXN],tcnt,dep[MAXN],n,m,lim[MAXN],val[MAXN*60],mul[MAXN*60];
int ls[MAXN*60],rs[MAXN*60],cnt,rt[MAXN];
inline void adde(int u,int v){
	e[++tcnt]={v,front[u]};front[u]=tcnt;
}
void dfs1(int u,int f){
	lim[u]=dep[u];
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dep[v]=dep[u]+1;dfs1(v,u);
	}
}
inline void mmul(int o,ll v){mul[o]=mul[o]*v%mod;val[o]=val[o]*v%mod;}
inline void madd(int o,int v){rmod(val[o]+=v-mod);}
inline void pushdown(int o){
	if(mul[o]==1&&!val[o])return;
	if(!ls[o])ls[o]=++cnt,mul[cnt]=1;if(!rs[o])rs[o]=++cnt,mul[cnt]=1;
	if(mul[o]!=1)mmul(ls[o],mul[o]),mmul(rs[o],mul[o]),mul[o]=1;
	if(val[o])madd(ls[o],val[o]),madd(rs[o],val[o]),val[o]=0;
}
void add(int &o,int l,int r,int ql,int qr,int v){
	if(!o)o=++cnt,mul[cnt]=1;if(ql<=l&&qr>=r){madd(o,v);return;}
	int mid=l+r>>1;pushdown(o);
	if(ql<=mid)add(ls[o],l,mid,ql,qr,v);
	if(qr>mid)add(rs[o],mid+1,r,ql,qr,v);
}
void merge(int &o1,int o2){
	if(!ls[o1]&&!rs[o1])swap(o1,o2);
	if(!ls[o2]&&!rs[o2]){mmul(o1,val[o2]);return;}
	pushdown(o1);pushdown(o2);
	merge(ls[o1],ls[o2]);merge(rs[o1],rs[o2]);
}
int ask(int o,int l,int r,int p){
	if(l==r)return val[o];
	int mid=l+r>>1;pushdown(o);
	if(p<=mid)return ask(ls[o],l,mid,p);
	else return ask(rs[o],mid+1,r,p);
}
void dfs(int u,int f){
	int d=n-dep[u],l=lim[u];
	add(rt[u],0,n,d,d+l,1);
	fe(i,u){
		int v=e[i].v;if(v==f)continue;
		dfs(v,u);add(rt[v],0,n,d,d+l,ask(rt[v],0,n,d-1));
		merge(rt[u],rt[v]);
	}
}
main(){
	freopen("destiny.in","r",stdin);freopen("destiny.out","w",stdout);
	n=read();
	fp(i,1,n-1){
		int x=read(),y=read();adde(x,y);adde(y,x);
	}
	dfs1(1,0);m=read();
	fp(i,1,m){
		int x=read(),y=read();lim[y]=min(lim[y],dep[y]-dep[x]-1);
	}
	dfs(1,0);
	printf("%d
",ask(rt[1],0,n,n));
	return 0;
}
原文地址:https://www.cnblogs.com/WinterSpell/p/14303703.html