POJ1990解题报告【树状数组】

题目地址:

  http://poj.org/problem?id=1990

题目概述:

  给出dist[i]跟v[i],求Σ(丨dist[i]-dist[j]丨*max(v[i],v[j]))。

大致思路:

  因为dist不会重复,所以先按照v数组升序排序,这时发现对于排完序后的第i头牛,他前面所有牛的v值都比它小,于是只需要求出它前面所有点到它的距离乘以第i头牛的v值,最后对所有牛求和即可。

  此时只需要能快速求出前面所有点到它的距离,用两个树状数组cnt和sum来维护第i头牛之前d值比它小的牛的数目及它们的d值的和,此时ansi=vi*((di*cnt-sum)+(all_sum-sum-(i-cnt)*di))

  其中的all_sum是牛i之前的所有牛的d值和,cnt中插入是add(di,1),sum中插入是add(di,di)。

代码:

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstdlib>
 4 #include <cmath>
 5 #include <vector>
 6 #include <ctime>
 7 #include <map>
 8 #include <queue>
 9 #include <cstring>
10 #include <algorithm>
11 using namespace std;
12 
13 #define sacnf scanf
14 #define scnaf scanf
15 #define maxn 20010
16 #define maxm 26
17 #define inf 1061109567
18 #define Eps 0.00001
19 const double PI=acos(-1.0);
20 #define mod 7
21 #define MAXNUM 10000
22 void Swap(int &a,int &b) {int t=a;a=b;b=t;}
23 double Abs(double x) {return (x<0)?-x:x;}
24 typedef long long ll;
25 typedef unsigned int uint;
26 
27 struct node
28 {
29     int v,d;
30     bool operator < (const node &a) const
31     {
32         return (v==a.v)?d<a.d:v<a.v;
33     }
34 } a[maxn];
35 
36 int n,m;
37 ll cnt[maxn];
38 ll dist[maxn];
39 
40 int lowbit(int x) {return x&-x;}
41 
42 void add(int x,int val,ll a[])
43 {
44     while(x<=m)
45     {
46         a[x]+=val;
47         x+=lowbit(x);
48     }
49 }
50 
51 ll query(int x,ll a[])
52 {
53     ll ans=0;
54     while(x)
55     {
56         ans+=a[x];
57         x-=lowbit(x);
58     }
59     return ans;
60 }
61 
62 int main()
63 {
64     //freopen("data.in","r",stdin);
65     //freopen("data.out","w",stdout);
66     //clock_t st=clock();
67     while(~scanf("%d",&n))
68     {
69         m=0;
70         for(int i=1;i<=n;i++)
71         {
72             scanf("%d%d",&a[i].v,&a[i].d);
73             m=max(m,a[i].d);
74         }
75         for(int i=1;i<=m;i++) cnt[i]=dist[i]=0;
76         sort(a+1,a+1+n);
77         ll ans=0,s=0;
78         for(int i=1;i<=n;i++)
79         {
80             add(a[i].d,1,cnt);
81             add(a[i].d,a[i].d,dist);
82             s+=a[i].d;
83             ll c=query(a[i].d,cnt);
84             ll d=query(a[i].d,dist);
85             ans+=a[i].v*((a[i].d*c-d)+(s-d-(i-c)*a[i].d));
86         }
87         printf("%lld
",ans);
88     }
89     //clock_t ed=clock();
90     //printf("

Time Used : %.5lf Ms.
",(double)(ed-st)/CLOCKS_PER_SEC);
91     return 0;
92 }
原文地址:https://www.cnblogs.com/CtrlKismet/p/6520316.html