[HDU 6318] Swaps and Inversions

[题目链接]

         http://acm.hdu.edu.cn/showproblem.php?pid=6318

[算法]

        交换一次相邻元素至多使逆序对数减少 1，因此答案为 逆序对数 × min(x, y)。先对序列离散化，再用线段树 / 树状数组统计逆序对数。

[代码]

         

#include<bits/stdc++.h>
using namespace std;
#define MAXN 100010

// Scalars shared with main(), re-used across test cases.
long long i;    // loop index
long long n;    // number of elements in the current sequence
long long x;    // first multiplier read per test case (answer uses min(x, y))
long long y;    // second multiplier read per test case
long long len;  // count of distinct values kept in tmp after deduplication
long long ans;  // running inversion count
long long l;    // binary-search lower bound
long long r;    // binary-search upper bound
long long mid;  // binary-search midpoint
// Per-element arrays, all 1-indexed.
long long a[MAXN];    // input sequence
long long rk[MAXN];   // rank of a[i] among distinct values sorted descending
long long tmp[MAXN];  // scratch copy of a: sorted descending, then deduplicated

// Point-add / range-sum segment tree. Node `index` has children
// 2*index and 2*index+1; storage is a fixed array of MAXN << 2 nodes.
struct SegmentTree
{
        struct Node
        {
                long long l,r;   // inclusive position range covered by this node
                long long sum;   // sum of the values stored at positions [l, r]
        } Tree[MAXN << 2];

        // Lay out the subtree rooted at `index` over positions [l, r] and
        // reset every sum in it to zero.
        void build(long long index,long long l,long long r)
        {
                Node &node = Tree[index];
                node.l = l;
                node.r = r;
                node.sum = 0;
                if (l != r)
                {
                        const long long half = (l + r) >> 1;
                        build(index * 2,l,half);
                        build(index * 2 + 1,half + 1,r);
                }
        }

        // Add `val` at position `pos`, walking from node `index` down to the
        // matching leaf and updating every node on the way.
        void add(long long index,long long pos,long long val)
        {
                for (;;)
                {
                        Node &node = Tree[index];
                        node.sum += val;
                        if (node.l == node.r) break;
                        const long long half = (node.l + node.r) >> 1;
                        index = (pos <= half) ? index * 2 : index * 2 + 1;
                }
        }

        // Sum of positions [l, r] below node `index`; an empty range (l > r)
        // yields 0. No bounds check is done: [l, r] must lie inside the
        // node's own range.
        long long query(long long index,long long l,long long r)
        {
                if (l > r) return 0;
                const Node &node = Tree[index];
                if (node.l == l && node.r == r) return node.sum;
                const long long half = (node.l + node.r) >> 1;
                if (r <= half) return query(index * 2,l,r);
                if (l > half) return query(index * 2 + 1,l,r);
                return query(index * 2,l,half) + query(index * 2 + 1,half + 1,r);
        }
} T;

int main() 
{
        
        while (scanf("%lld%lld%lld",&n,&x,&y) != EOF)
        {
                for (i = 1; i <= n; i++) 
                {
                        scanf("%lld",&a[i]);
                        tmp[i] = a[i];
                }
                sort(tmp + 1,tmp + n + 1,greater<int>());
                len = 1;
                for (i = 2; i <= n; i++)
                {
                        if (tmp[i] != tmp[i - 1])
                                tmp[++len] = tmp[i];
                }
                for (i = 1; i <= n; i++) 
                {
                        l = 1; r = len;
                        while (l <= r)
                        {
                                mid = (l + r) >> 1;
                                if (a[i] >= tmp[mid]) r = mid - 1;
                                else l = mid + 1;
                        }
                        rk[i] = l;
                }
                T.build(1,1,len); 
                ans = 0;
                for (i = 1; i <= n; i++)
                {
                        ans += T.query(1,1,rk[i] - 1);
                        T.add(1,rk[i],1);
                }
                printf("%lld
",1ll * ans * min(x,y));
        }
        return 0;
    
}
原文地址:https://www.cnblogs.com/evenbao/p/9391850.html