YbtOJ#791子集最值【三维偏序】

正题

题目链接:http://www.ybtoj.com.cn/contest/123/problem/1


题目大意

给出\(3\)个长度为\(n\)的排列\(A,B,C\)。然后一个下标集合\(S\)的三元组是

\[(max\{A_i\},max\{B_i\},max\{C_i\})(i\in S) \]

求所有下标集合不同的三元组数量
\(1\leq n\leq 10^5\)


解题思路

所有下标集合的三元组都能用一个\(|S|\leq 3\)的集合代替,所以我们只考虑\(|S|\leq 3\)的就好了。

\(|S|=1\)的个数就是\(n\),直接累加即可。

\(|S|=2\)的话,那就代表某个下标霸占了两个最大值,而另一个一定是另一个下标的,如果是\(a,b\)最大,那么我们就要找满足\(a_i> a_j,b_i> a_j,c_i< c_j\)的方案,用三维偏序就好了。

然后\(a,c\)\(b,c\)的情况也都要做

\(|S|=3\)的话很麻烦,考虑容斥,总方案\(\binom n 3\)减去有一个下标是至少两个的最大值。
同样和上面,先考虑\(a,b\),假设下标\(i\)满足\(a_i>a_j,b_i>b_j\)的情况有\(k\)种,那么就好有\(\binom{k}{2}\)种情况使得\(i\)占据了至少两个最大值。
同理\(a,c\)\(b,c\)也要做,这是二维偏序,直接树状数组就好了。

但是发现对于\(i\)占据了三个最大值的情况我们统计了三次,需要加回多余的两次,那么统计\(a_i>a_j,b_i>b_j,c_i>c_j\)的个数\(k\),然后加回\(k(k-1)\)的方案就好了,这个也要三维偏序

代码里三维偏序用的是\(CDQ\)分治+树状数组

时间复杂度\(O(n\log^2 n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=1e5+10;
struct node{
	ll a,b,c;
}w[N],a[N],b[N];
ll n,ans,sum,t[N],g[N];
void Change(ll x,ll val){
	while(x<=n){
		t[x]+=val;
		x+=lowbit(x);
	}
	return;
}
ll Ask(ll x){
	ll ans=0;
	while(x){
		ans+=t[x];
		x-=lowbit(x);
	}
	return ans;
}
void Merge(ll l,ll mid,ll r){
	ll p=l,q=mid+1;
	for(ll i=1;i<=r-l+1;i++){
		if(p<=mid&&w[p].b<=w[q].b||q>r)b[i]=w[p],p++;
		else b[i]=w[q],q++;
	}
	for(ll i=1;i<=r-l+1;i++)w[l+i-1]=b[i];
	return;
}
void CDQ(ll l,ll r,bool op){
	if(l==r)return;
	ll mid=(l+r)>>1;
	CDQ(l,mid,op);CDQ(mid+1,r,op);
	ll p=l,tmp;
	for(ll i=mid+1;i<=r;i++){
		while(p<=mid&&w[p].b<w[i].b)
			Change(w[p].c,1),p++;
		sum+=(tmp=Ask(w[i].c));
		g[w[i].a]+=(op?tmp:0);
	}
	for(ll i=l;i<p;i++)Change(w[i].c,-1);
	Merge(l,mid,r);return;
}
bool cmp(node x,node y)
{return x.a<y.a;}
void solve(){
	sort(w+1,w+1+n,cmp);
	for(ll i=1;i<=n;i++){
		ll tmp=Ask(w[i].b);
		ans-=tmp*(tmp-1)/2;
		Change(w[i].b,1);
	}
	memset(t,0,sizeof(t));
	return;
}
signed main()
{
	freopen("subset.in","r",stdin);
	freopen("subset.out","w",stdout);
	scanf("%lld",&n);ans=n;
	for(ll i=1;i<=n;i++)scanf("%lld",&a[i].a);
	for(ll i=1;i<=n;i++)scanf("%lld",&a[i].b);
	for(ll i=1;i<=n;i++)scanf("%lld",&a[i].c);
		
	for(ll i=1;i<=n;i++)
		w[i].a=a[i].a,w[i].b=a[i].b,w[i].c=n-a[i].c+1;
	sort(w+1,w+1+n,cmp);CDQ(1,n,0);
	for(ll i=1;i<=n;i++)
		w[i].a=a[i].a,w[i].b=a[i].c,w[i].c=n-a[i].b+1;
	sort(w+1,w+1+n,cmp);CDQ(1,n,0);
	for(ll i=1;i<=n;i++)
		w[i].a=a[i].b,w[i].b=a[i].c,w[i].c=n-a[i].a+1;
	sort(w+1,w+1+n,cmp);CDQ(1,n,0);
	ans+=sum;ans+=n*(n-1)*(n-2)/6;
	
	for(ll i=1;i<=n;i++)w[i].a=a[i].a,w[i].b=a[i].b;solve();
	for(ll i=1;i<=n;i++)w[i].a=a[i].b,w[i].b=a[i].c;solve();
	for(ll i=1;i<=n;i++)w[i].a=a[i].a,w[i].b=a[i].c;solve();
	
	for(ll i=1;i<=n;i++)
		w[i].a=a[i].a,w[i].b=a[i].b,w[i].c=a[i].c;
	sort(w+1,w+1+n,cmp);
	CDQ(1,n,1);
	for(ll i=1;i<=n;i++)ans+=g[i]*(g[i]-1);
	printf("%lld\n",ans);
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14441885.html