【HDOJ】5632 Rikka with Array

1. 题目描述
$A[i]$表示$i$的二进制表示中各位数字之和（即二进制中1的个数）。求满足$1 \le i < j \le n$且$A[i]>A[j]$的$(i,j)$的总对数（答案对$998244353$取模）。

2. 基本思路
$n \le 10^{300}$。$n$这么大，显然只能用数位DP来做：先把十进制的$n$转换成二进制表示，然后在二进制位上进行DP。
$dp[i][j][k]$表示已经确定了前$i$个二进制位、两者$A$值之差（$A[x]-A[y]$）为$j$、状态为$k$的方案数。
不妨令$n$的二进制长度为$|n| = l$，则$j \in [-l, l]$，因此需要加上偏移量$l$，将$j$映射到$[0, 2l]$上。
再考虑$k$有多少种情况。不妨令$(x,y)$（$x<y$）表示一对可行解，$Pref(\cdot)$表示当前已确定的二进制前缀：
(0) $Pref(x) < Pref(y), Pref(y) < Pref(n)$;
(1) $Pref(x) < Pref(y), Pref(y) == Pref(n)$;
(2) $Pref(x) == Pref(y), Pref(y) < Pref(n)$;
(3) $Pref(x) == Pref(y), Pref(y) == Pref(n)$;
上面4种情况分别对应$k \in [0, 3]$，剩下的就是状态转移，还是挺简单的。总对数就是
$$\sum_{j=l+1}^{2l}\left(dp[l][j][0]+dp[l][j][1]\right)$$
可以使用滚动数组优化,其实也可以不使用。

3. 代码

  1 /* 5632 */
  2 #include <iostream>
  3 #include <sstream>
  4 #include <string>
  5 #include <map>
  6 #include <queue>
  7 #include <set>
  8 #include <stack>
  9 #include <vector>
 10 #include <deque>
 11 #include <bitset>
 12 #include <algorithm>
 13 #include <cstdio>
 14 #include <cmath>
 15 #include <ctime>
 16 #include <cstring>
 17 #include <climits>
 18 #include <cctype>
 19 #include <cassert>
 20 #include <functional>
 21 #include <iterator>
 22 #include <iomanip>
 23 using namespace std;
 24 //#pragma comment(linker,"/STACK:102400000,1024000")
 25 
 26 #define sti                set<int>
 27 #define stpii            set<pair<int, int> >
 28 #define mpii            map<int,int>
 29 #define vi                vector<int>
 30 #define pii                pair<int,int>
 31 #define vpii            vector<pair<int,int> >
 32 #define rep(i, a, n)     for (int i=a;i<n;++i)
 33 #define per(i, a, n)     for (int i=n-1;i>=a;--i)
 34 #define clr                clear
 35 #define pb                 push_back
 36 #define mp                 make_pair
 37 #define fir                first
 38 #define sec                second
 39 #define all(x)             (x).begin(),(x).end()
 40 #define SZ(x)             ((int)(x).size())
 41 #define lson            l, mid, rt<<1
 42 #define rson            mid+1, r, rt<<1|1
 43 
 44 const int mod = 998244353;
 45 const int maxl = 305;
 46 const int maxn = 1205;
 47 char ss[maxl];
 48 int a[maxn];
 49 int dp[2][maxn<<1][4];
 50 
 51 void solve() {
 52     int l = 0, tmp;
 53     int len = strlen(ss);
 54     
 55     rep(i, 0, len)
 56         ss[i] -= '0';
 57     
 58     int b = 0;
 59     
 60     while (b<len && ss[b]==0)
 61         ++b;
 62     if (b >= len) {
 63         puts("0");
 64         return ;
 65     }
 66     
 67     while (1) {
 68         a[l++] = ss[len-1] & 1;
 69         tmp = 0;
 70         rep(i, b, len) {
 71             if (ss[i] & 1) {
 72                 ss[i] = (tmp+ss[i])>>1;
 73                 tmp = 10;
 74             } else {
 75                 ss[i] = (tmp+ss[i])>>1;
 76                 tmp = 0;
 77             }
 78         }
 79         while (b<len && ss[b]==0)
 80             ++b;
 81         if (b >= len)
 82             break;
 83     }
 84     
 85     reverse(a, a+l);
 86     
 87     int l2 = l + l;
 88     int p = 0, q = 1;
 89     
 90     memset(dp, 0, sizeof(dp));
 91     
 92     rep(ii, 0, a[0]+1) {
 93         rep(jj, 0, a[0]+1) {
 94             if (ii > jj)
 95                 continue;
 96             
 97             int nj = ii - jj + l;
 98             int nk = (ii==jj) ? (jj==a[0])|2 : (jj==a[0]);
 99             ++dp[p][nj][nk];
100         }
101     }
102     
103     rep(i, 1, l) {
104         rep(j, 0, l2+1) {
105             // i < j
106             rep(k, 0, 2) {
107                 if (!dp[p][j][k])
108                     continue;
109                 
110                 int mn1, mn2, nj, nk;
111                 
112                 mn1 = 1;
113                 mn2 = (k&1) ? a[i]:1;
114                 
115                 rep(ii, 0, mn1+1) {
116                     rep(jj, 0, mn2+1) {
117                         nj = j + ii - jj;
118                         nk = (k==1) && (jj==a[i]);
119                         if (nj >= 0)
120                             dp[q][nj][nk] = (dp[q][nj][nk] + dp[p][j][k]) % mod;
121                     }
122                 }
123             }
124             // i = j
125             rep(k, 2, 4) {
126                 if (!dp[p][j][k])
127                     continue;
128                 
129                 int mn, nj, nk;
130                 
131                 mn = (k&1) ? a[i]:1;
132                 rep(ii, 0, mn+1) {
133                     rep(jj, 0, mn+1) {
134                         if (ii > jj)
135                             continue;
136                         
137                         nj = j + (ii==1) - (jj==1);
138                         if (k == 2) {
139                             nk = (ii<jj) ? 0:2;
140                         } else {
141                             nk = (ii<jj) ? (jj==a[i]) : (jj==a[i])|2;
142                         }
143                         if (nj >= 0)
144                             dp[q][nj][nk] = (dp[q][nj][nk] + dp[p][j][k]) % mod;
145                     }
146                 }
147             }
148         }
149         p ^= 1;
150         q ^= 1;
151         memset(dp[q], 0, sizeof(dp[q]));
152     }
153     
154     int ans = 0;
155     
156     rep(j, l+1, l2+1)
157         rep(k, 0, 2)
158             ans = (ans + dp[p][j][k]) % mod;
159             
160     printf("%d
", ans);
161 }
162 
163 int main() {
164     ios::sync_with_stdio(false);
165     #ifndef ONLINE_JUDGE
166         freopen("data.in", "r", stdin);
167         freopen("data.out", "w", stdout);
168     #endif
169     
170     int t;
171     
172     scanf("%d", &t);
173     while (t--) {
174         scanf("%s", ss);
175         solve();
176     }
177     
178     #ifndef ONLINE_JUDGE
179         printf("time = %d.
", (int)clock());
180     #endif
181     
182     return 0;
183 }

4. 数据生成器

 1 import sys
 2 import string
 3 from random import randint, shuffle
 4 
 5     
 6 def GenData(fileName):
 7     with open(fileName, "w") as fout:
 8         t = 10
 9         fout.write("%d
" % (t))
10         ld = string.digits
11         for tt in xrange(t):
12             length = randint(200, 300)
13             L = [0] * length
14             for i in xrange(length):
15                 L[i] = randint(0, 9)
16             L[0] = randint(1, 9)
17             fout.write("".join(map(str, L)) + "
")
18             
19             
def MovData(srcFileName, desFileName):
    """Copy the full text of ``srcFileName`` into ``desFileName``, replacing it."""
    with open(srcFileName, "r") as src:
        content = src.read()
    with open(desFileName, "w") as dst:
        dst.write(content)
25 
26         
def CompData(srcFileName=r"F:\Qt_prj\hdoj\data.out",
             desFileName=r"F:\workspace\cpp_hdoj\data.out"):
    """Compare two solution output files line by line.

    Prints ``comp``, then ``"<i>: wrong"`` for every 0-based line index ``i``
    whose integer answers differ. Only the first ``min(len) - 1`` lines are
    compared. NOTE(review): presumably this skips the trailing
    ``time = ...`` line that the C++ solution prints on local runs; confirm
    both outputs end with it.

    Args:
        srcFileName: first output file (defaults to the original author's path).
        desFileName: second output file (defaults to the original author's path).
    """
    print("comp")
    with open(srcFileName, "r") as fin:
        srcLines = fin.readlines()
    with open(desFileName, "r") as fin:
        desLines = fin.readlines()
    n = min(len(srcLines), len(desLines)) - 1
    for i in range(n):
        ans2 = int(desLines[i])
        ans1 = int(srcLines[i])
        # Answers are residues mod 998244353, so only equality is meaningful;
        # the previous ">" silently accepted every case where ans1 < ans2.
        if ans1 != ans2:
            print("%d: wrong" % i)
43 
44             
if __name__ == "__main__":
    # Generate one input file and mirror it into the second project directory,
    # so both solutions read identical data. CompData() is not invoked here.
    # Raw strings keep the Windows path separators intact.
    srcFileName = r"F:\Qt_prj\hdoj\data.in"
    desFileName = r"F:\workspace\cpp_hdoj\data.in"
    GenData(srcFileName)
    MovData(srcFileName, desFileName)
50     
原文地址:https://www.cnblogs.com/bombe1013/p/5274336.html