hdu-5792 World is Exploding(容斥+树状数组)

题目链接:

World is Exploding

Time Limit: 2000/1000 MS (Java/Others)  

  Memory Limit: 65536/65536 K (Java/Others)


Problem Description
Given a sequence $A$ of length $n$, count how many quadruples $(a,b,c,d)$ satisfy: $a \ne b \ne c \ne d$, $1 \le a < b \le n$, $1 \le c < d \le n$, $A_a < A_b$, $A_c > A_d$.
 
Input
The input consists of multiple test cases. 
Each test case begin with an integer n in a single line.

The next line contains $n$ integers $A_1, A_2, \ldots, A_n$.
$1 \le n \le 50000$
$0 \le A_i \le 10^9$
 
Output
For each test case,output a line contains an integer.
 
Sample Input
 
4
2 4 1 3
4
1 2 3 4
 
Sample Output
 
1
0
 
题意:
 
问符合题目给的四元组有多少个;
 
思路:
 
容斥:先算出满足 $A_a<A_b$ 且 $A_c>A_d$ 的 (a,b,c,d) 的个数(即递增对的个数乘以递减对的个数),再减去 a==c、a==d、b==c、b==d 的个数,就是答案了。因为不可能同时有两组下标相等(例如 a==c 且 b==d 会要求 $A_a<A_b$ 与 $A_a>A_b$ 同时成立),所以不会重复扣除。然后用树状数组
求 pres[i]、preb[i]、nexs[i]、nexb[i],分别表示第 i 个数前面比它小、前面比它大、后面比它小、后面比它大的数的个数,具体见代码。
 
AC代码:
 
/************************************************
┆  ┏┓   ┏┓ ┆   
┆┏┛┻━━━┛┻┓ ┆
┆┃       ┃ ┆
┆┃   ━   ┃ ┆
┆┃ ┳┛ ┗┳ ┃ ┆
┆┃       ┃ ┆ 
┆┃   ┻   ┃ ┆
┆┗━┓    ┏━┛ ┆
┆  ┃    ┃  ┆      
┆  ┃    ┗━━━┓ ┆
┆  ┃  AC代马   ┣┓┆
┆  ┃           ┏┛┆
┆  ┗┓┓┏━┳┓┏┛ ┆
┆   ┃┫┫ ┃┫┫ ┆
┆   ┗┻┛ ┗┻┛ ┆      
************************************************ */ 
 
 
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <bits/stdc++.h>
#include <stack>
 
using namespace std;
 
// Loop variable i over the closed range [j, n]. Arguments are parenthesized
// so that expression arguments expand safely; n is re-evaluated each pass.
#define For(i,j,n) for(int i=(j);i<=(n);i++)
// Fill the whole object/array `ss` with byte value `b` (ss must be an array,
// not a pointer, for sizeof to cover it). No trailing semicolon, so a use
// such as `if (c) mst(a,0); else ...` expands to valid code.
#define mst(ss,b) memset((ss),(b),sizeof(ss))

typedef  long long LL;
 
// Read the next integer from `in` (stdin by default) into `num`.
// Non-digit characters before the number are skipped; the value is negated
// when the character immediately preceding the first digit is '-'. One
// character after the last digit is consumed.
// Returns false (and sets num to 0) if end of input is reached before any
// digit. EOF has to be tested explicitly: it compares below '0', so a plain
// "skip non-digits" loop would never terminate on truncated input.
template<class T> bool read(T&num, FILE* in = stdin) {
    int ch = getc(in);  // int, not char, so EOF stays distinguishable
    bool neg = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        neg = (ch == '-');
        ch = getc(in);
    }
    num = 0;
    if (ch == EOF) return false;
    for (; ch >= '0' && ch <= '9'; ch = getc(in)) num = num * 10 + (ch - '0');
    if (neg) num = -num;
    return true;
}
int stk[70], tp;
// Print integer `p` in decimal followed by '\n' to `out` (stdout by default),
// using the global digit stack `stk`/`tp`. Negative values get a leading '-';
// each digit is taken from p % 10 with its sign stripped, so the most
// negative value of a signed type is printed without overflowing.
template<class T> inline void print(T p, FILE* out = stdout) {
    if (p == 0) { fputs("0\n", out); return; }
    if (p < 0) fputc('-', out);
    while (p != 0) {
        int digit = static_cast<int>(p % 10);
        stk[++tp] = digit < 0 ? -digit : digit;
        p /= 10;
    }
    while (tp) fputc(stk[tp--] + '0', out);
    fputc('\n', out);
}
 
// Problem-wide constants. In the code shown, only N (array capacity:
// largest n plus slack) is referenced.
const LL mod = 1e9 + 7;
const double PI = acos(-1.0);
const int inf = 1e9;
const int N = 5e4 + 10;
const int maxn = 1e3 + 14;
const double eps = 1e-8;

// Per-test working storage.
//   n          - current sequence length
//   fa[k]      - first sorted slot whose value equals that of sorted slot k
//   pres/preb  - per original position: strictly smaller/larger values before it
//   nexs/nexb  - per original position: strictly smaller/larger values after it
//   sum        - Fenwick tree cells
int n;
int fa[N], pres[N], preb[N], nexs[N], nexb[N], sum[N];
// One input element: its value `a` and its 1-based original index `id`.
struct node
{
    int a;
    int id;
};
node po[N];
// Sort order for the "smaller" pass: ascending by value, with equal values
// kept in ascending original-index order.
int cmp(node x,node y)
{
    if (x.a != y.a) return x.a < y.a;
    return x.id < y.id;
}
// Sort order for the "larger" pass: descending by value, with equal values
// still kept in ascending original-index order.
int cmp1(node x,node y)
{
    return x.a != y.a ? x.a > y.a : x.id < y.id;
}

// Value of the lowest set bit of x (0 when x == 0); the Fenwick step size.
int lowbit(int x)
{
    return x & (~x + 1);
}

// Fenwick tree point update: add 1 at 1-based position x in the global tree
// `sum`, which covers positions 1..n. Non-positive x is ignored: with x == 0,
// lowbit(0) == 0 and the loop would never advance.
inline void update(int x)
{
    if (x <= 0) return;
    for (; x <= n; x += lowbit(x)) sum[x]++;
}
// Fenwick tree prefix query: total of the counts added by update() at
// positions 1..x. Non-positive x yields 0; looping while x > 0 (instead of
// x != 0) keeps a negative argument from stepping toward INT_MIN.
int query(int x)
{
    int s=0;
    for (; x > 0; x -= lowbit(x)) s += sum[x];
    return s;
}
int main()
{      
        while(scanf("%d",&n)!=EOF)
        {
            For(i,1,n)read(po[i].a),po[i].id=i;
            sort(po+1,po+n+1,cmp);
            po[0].a=-1;
            mst(sum,0);
            For(i,1,n)
            {
                if(po[i].a==po[i-1].a)fa[i]=fa[i-1];
                else fa[i]=i;
                pres[po[i].id]=query(po[i].id)-(i-fa[i]);
                nexs[po[i].id]=fa[i]-1-pres[po[i].id];
                update(po[i].id);
            }
            mst(sum,0);
            sort(po+1,po+n+1,cmp1);
            For(i,1,n)
            {
                if(po[i].a==po[i-1].a)fa[i]=fa[i-1];
                else fa[i]=i;
                preb[po[i].id]=query(po[i].id)-(i-fa[i]);
                nexb[po[i].id]=fa[i]-1-preb[po[i].id];
                update(po[i].id);
            }
            sort(po+1,po+n+1,cmp1);
            LL ans1=0,ans2=0,ans;
            For(i,1,n)
            {
                ans1=ans1+pres[i];
                ans2=ans2+preb[i];
            }
            ans=ans1*ans2;
            For(i,1,n)
            {
                ans=ans-nexs[i]*nexb[i];//a==c
                ans=ans-preb[i]*nexb[i];//a==d
                ans=ans-pres[i]*nexs[i];//b==c
                ans=ans-pres[i]*preb[i];//b==d
            }         
            cout<<ans<<endl;
        }
        return 0;
}

  

原文地址:https://www.cnblogs.com/zhangchengc919/p/5732776.html