HDU6964 I Love Counting(2021HDU多校第二场1004)(平衡树/树状数组+二维数点+字典树)

题意:

给出一个序列。

每次询问一个区间内有多少个不同的数异或a<=b。

题解:

首先有个前置知识,就是不带区间的情况下有多少个不同的数异或a<=b,这是一个经典的字典树上DP的模型,找到对应的子树统计信息即可,这里不再赘述。

然后考虑区间,如果把不同的数这个条件去掉,可以直接上可持久化字典树。

但是任何一个可持久化数据结构都无法处理不同这个条件。

做法一

比赛的时候想了一个莫队套字典树的做法,就是在莫队的过程中维护一颗字典树,这个思路比较好想,时间复杂度(O(nlognsqrt{n}))

比赛中居然有人用这个时间复杂度卡过去了?

但是也有一种莫队好像可以把log去掉,不得不说是真的nb

做法二

对字典树上的每个节点维护子树内所有数的前驱。这里我用的Splay树维护每个节点的前驱集合。

然后从左往右更新数组,先把与当前元素有关的所有节点的Splay更新,然后处理以当前下标为右端点的询问,与每个询问相关的子树数量是log级的,对这些子树的Splay查询比询问的左端点大的前驱数量。

求和就是答案,这样搞时间复杂度是(O(nlognlogn))的,但是对每个节点维护一颗Splay,好像复杂度并不能均摊,在HDU上稳T,在luogu上跑1.67s。

做法三

在做法二的基础上用空间换时间。对每个节点维护两个链表,一个表示与这个节点相关的所有数和它们的位置,一个表示与这个节点相关的询问。

然后遍历所有节点,先把所有节点的询问按右端点从小到大排序,然后从左往右遍历询问,同时在节点的元素链表里维护一个指针,每次把出现位置小于等于当前询问右端点的前驱全部更新到Splay树里,然后在Splay上询问比左端点大的数的数量。

这里由于不断要对Splay树做插入删除的操作,导致内存爆炸,还要手写一个垃圾回收。

时间复杂度(O(nlognlogn)),但是只用一颗Splay树,常数得到进一步优化,在HDU上1.9s AC。

做法四

在做法二的基础上,用树状数组代替Splay树,树状数组常数是真的很小,问题得以解决。在洛谷上1.25s,在HDU上1.4s。这好像也是std的做法。

总结

这道题没用什么很高级的思想,就是不断的通过一些小技巧优化常数,真的学到许多,是一道很好的数据结构练习题。

代码

这里贴上做法三和做法四的代码

做法三:

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
const int M=1e5*30;

int fa[M],ch[M][2],val[M],cnt[M],sz[M],tot,hs[M],ts;
struct Splay {
	int rt=0;
	void maintain (int x) {
		sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
	}
	bool get (int x) {
		return x==ch[fa[x]][1];
	}
	void clear (int x) {
		ch[x][0]=ch[x][1]=fa[x]=val[x]=sz[x]=cnt[x]=0;
	}
	void rotate (int x) {
		int y=fa[x];
		int z=fa[y];
		int chk=get(x);
		ch[y][chk]=ch[x][chk^1];
		if (ch[x][chk^1]) {
			fa[ch[x][chk^1]]=y;
		}
		ch[x][chk^1]=y;
		fa[y]=x;
		fa[x]=z;
		if (z) {
			ch[z][y==ch[z][1]]=x;
		}
		maintain(x);
		maintain(y);
	}
	void splay (int x) {
		for (int f=fa[x];f=fa[x];rotate(x)) {
			if (fa[f]) {
				rotate(get(x)==get(f)?f:x);
			}
		}
		rt=x;
	}
	void ins (int k) {
		if (!rt) {
			if (ts==0) {
				val[++tot]=k;
				cnt[tot]++;
				rt=tot;
			}
			else {
				val[hs[ts]]=k;
				cnt[hs[ts]]++;
				rt=hs[ts];
				ts--;
			}
			maintain(rt);
			return;
		}
		int cur=rt,f=0;
		while (1) {
			if (val[cur]==k) {
				cnt[cur]++;
				maintain(cur);
				maintain(f);
				splay(cur);
				break;
			}
			f=cur;
			cur=ch[cur][val[cur]<k];
			if (!cur) {
				if (ts==0) {
					val[++tot]=k;
					cnt[tot]++;
					fa[tot]=f;
					ch[f][val[f]<k]=tot;
					maintain(tot);
					maintain(f);
					splay(tot);
				}
				else {
					val[hs[ts]]=k;
					cnt[hs[ts]]++;
					fa[hs[ts]]=f;
					ch[f][val[f]<k]=tot;
					maintain(hs[ts]);
					maintain(f);
					splay(hs[ts]);
					ts--;
				}
				
				break;
			}
		}
	}
	int rk (int k) {
		int res=0;
		int cur=rt;
		while (1) {
			if (!cur) {
				return res+1;
			} 
			if (k<val[cur]) {
				cur=ch[cur][0];
			}
			else 
			{
				res+=sz[ch[cur][0]];
				if (k==val[cur]) {
					splay(cur);
					return res+1;
				}
				res+=cnt[cur];
				cur=ch[cur][1];
			}
		}
	}
	int pre () {
		int cur=ch[rt][0];
		if (!cur) return cur;
		while (ch[cur][1]) {
			cur=ch[cur][1];
		}
 		splay(cur);
		return cur;	
	}
	void del (int k) {
		rk(k);
		if (cnt[rt]>1) {
			cnt[rt]--;
			maintain(rt);
			return;
		}
		if (!ch[rt][0]&&!ch[rt][1]) {
			clear(rt);
			rt=0;
			return;
		}
		if (!ch[rt][0]) {
			int cur=rt;
			rt=ch[rt][1];
			fa[rt]=0;
			clear(cur);
			return;
		}
		if (!ch[rt][1]) {
			int cur=rt;
			rt=ch[rt][0];
			fa[rt]=0;
			clear(cur);
			return;
		}
		int cur=rt;
		int x=pre();
		fa[ch[cur][1]]=x;
		ch[x][1]=ch[cur][1];
		clear(cur);
		maintain(rt);
	}
};

Splay * splay;

int tr[M][2],tol;
vector<int> g[maxn];//数字i的二进制形式
void zh (int x) { 
	if (g[x].size()) return;
	int u=x;
	while (x) {
		g[u].push_back(x%2);
		x/=2;
	}
	while (g[u].size()<17) g[u].push_back(0);
	int uu=0;
	for (int i=16;i>=0;i--) {
		if (!tr[uu][g[u][i]]) tr[uu][g[u][i]]=++tol;
		uu=tr[uu][g[u][i]];
	}
} 
int Pre[maxn];//保存每个数的前驱 
int n,a[maxn];
struct qnode {
	int id,a,b,l,r;
	bool operator < (const qnode &x) const {
		return r<x.r;
	}
};
int ans[maxn];
vector<pair<int,int> > ys[M];//对每个节点维护一个元素数组 
vector<qnode> xy[M];//对每个节点维护一个询问数组 
void insert (int u,int x,int dep,int i) {
	if (u) {
		ys[u].push_back(make_pair(a[i],i));
	}
	if (dep<0) return;
	insert(tr[u][g[x][dep]],x,dep-1,i);
}
void query (int u,int a,int b,int dep,int l,int r,int id) {
	//在字典树上找到对应的子树
	if (dep<0) {
		//return splay[u].rk(r+1)-1;
		xy[u].push_back({id,a,b,l,r});//这个询问涉及到的节点 
	}
	if (g[a][dep]==1&&g[b][dep]==0) {
        if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==1&&g[b][dep]==1) {
    	if (tr[u][1]) xy[tr[u][1]].push_back({id,a,b,l,r});
        if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==0&&g[b][dep]==1) {
    	if (tr[u][0]) xy[tr[u][0]].push_back({id,a,b,l,r});
        if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==0&&g[b][dep]==0) {
       if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
    }
}

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
    while (c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
int main () {
	//100000*20*20*10
	splay=new Splay();
	n=read();
	for (int i=1;i<=n;i++) a[i]=read();
	for (int i=1;i<=n;i++) {
		zh(a[i]);
	}
	int m;
	m=read();
	for (int i=1;i<=m;i++) {
		int l,r,A,B;
		//scanf("%d%d%d%d",&l,&r,&A,&B);
		l=read();r=read();A=read();B=read();
		zh(A);
		zh(B);
		query(0,A,B,16,l,r,i);
	}
	for (int i=1;i<=n;i++) {
		insert(0,a[i],16,i);
	}
	for (int i=1;i<=tol;i++) {
		//遍历每个节点
		sort(xy[i].begin(),xy[i].end());
		int l=0;
		for (qnode it:xy[i]) {
			while (l<ys[i].size()&&ys[i][l].second<=it.r) {
				int x=ys[i][l].first;
				if (Pre[x]) splay->del(Pre[x]);
				Pre[x]=ys[i][l].second;
				splay->ins(Pre[x]);
				l++;
			}
			ans[it.id]+=splay->rk(1e9)-splay->rk(it.l);
		} 
		for (pair<int,int> it:ys[i]) splay->del(Pre[it.first]),Pre[it.first]=0;
	} 
	for (int i=1;i<=m;i++) printf("%d
",ans[i]); 
	
}


做法四:


#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
const int M=1e5*30;

int c[maxn];
int lowbit (int x) {
	return x&-x;
}
void up (int p,int v) {
	for (int i=p;i<maxn;i+=lowbit(i)) c[i]+=v;
}
int getsum (int p) {
	int ans=0;
	for (int i=p;i;i-=lowbit(i)) ans+=c[i];
	return ans;
}



int tr[M][2],tol;
vector<int> g[maxn];//数字i的二进制形式
void zh (int x) { 
	if (g[x].size()) return;
	int u=x;
	while (x) {
		g[u].push_back(x%2);
		x/=2;
	}
	while (g[u].size()<17) g[u].push_back(0);
	int uu=0;
	for (int i=16;i>=0;i--) {
		if (!tr[uu][g[u][i]]) tr[uu][g[u][i]]=++tol;
		uu=tr[uu][g[u][i]];
	}
} 
int Pre[maxn];//保存每个数的前驱 
int n,a[maxn];
struct qnode {
	int id,a,b,l,r;
	bool operator < (const qnode &x) const {
		return r<x.r;
	}
};
int ans[maxn];
vector<pair<int,int> > ys[M];//对每个节点维护一个元素数组 
vector<qnode> xy[M];//对每个节点维护一个询问数组 
void insert (int u,int x,int dep,int i) {
	if (u) {
		ys[u].push_back(make_pair(a[i],i));
	}
	if (dep<0) return;
	insert(tr[u][g[x][dep]],x,dep-1,i);
}
void query (int u,int a,int b,int dep,int l,int r,int id) {
	//在字典树上找到对应的子树
	if (dep<0) {
		//return splay[u].rk(r+1)-1;
		xy[u].push_back({id,a,b,l,r});//这个询问涉及到的节点 
	}
	if (g[a][dep]==1&&g[b][dep]==0) {
        if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==1&&g[b][dep]==1) {
    	if (tr[u][1]) xy[tr[u][1]].push_back({id,a,b,l,r});
        if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==0&&g[b][dep]==1) {
    	if (tr[u][0]) xy[tr[u][0]].push_back({id,a,b,l,r});
        if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
    }
    else if (g[a][dep]==0&&g[b][dep]==0) {
       if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
    }
}

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
    while (c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
int main () {
	//100000*20*20*10
	n=read();
	for (int i=1;i<=n;i++) a[i]=read();
	for (int i=1;i<=n;i++) {
		zh(a[i]);
	}
	int m;
	m=read();
	for (int i=1;i<=m;i++) {
		int l,r,A,B;
		//scanf("%d%d%d%d",&l,&r,&A,&B);
		l=read();r=read();A=read();B=read();
		zh(A);
		zh(B);
		query(0,A,B,16,l,r,i);
	}
	for (int i=1;i<=n;i++) {
		insert(0,a[i],16,i);
	}
	for (int i=1;i<=tol;i++) {
		//遍历每个节点
		sort(xy[i].begin(),xy[i].end());
		int l=0;
		for (qnode it:xy[i]) {
			while (l<ys[i].size()&&ys[i][l].second<=it.r) {
				int x=ys[i][l].first;
				if (Pre[x]) up(Pre[x],-1);
				Pre[x]=ys[i][l].second;
				up(Pre[x],1);
				l++;
			}
			ans[it.id]+=getsum(n+1)-getsum(it.l-1);
		} 
		for (pair<int,int> it:ys[i]) if (Pre[it.first])up(Pre[it.first],-1),Pre[it.first]=0;
	} 
	for (int i=1;i<=m;i++) printf("%d
",ans[i]); 
	
}
原文地址:https://www.cnblogs.com/zhanglichen/p/15049007.html