Atcoder 1975 Iroha and Haiku

题意:不是很好解释。大体就是,给定x, y, z (1<= x, z <= 5, 1<= y <= 7),对于一个只由1-10构成的数列,如果存在某个连续子数列,使得可以把这个子数列分成三部分,并且从左到右各部分的元素和为x, y, z, 那么就说这个数列是好的。比如x = 5, y = 7, z = 5, 那么数列1, 2, 3, 3, 4, 5就是好的,因为它从第二个元素到最后一个元素的子数列满足2+3 = 5, 3+4 = 7, 5 = 5。下面就是给你n<=40, x, y, z,让你求出好的数列的个数mod 1e9+7。

观察:比较容易想到反向求解,求不好的数列,然后再用好数列的个数 = 10^n - 不好数列的个数,得到答案。而且对于字符串有一类题目,给定禁止串,求长度为n的串有多少种之类的,和这题的限制很像,我们可以想到定义状态dp(len, tail),表示当前长度为len,尾部的若干各元素是tail的不好的数列的个数(这里的tail是一个数列)。但是tail的表示和合法性判断都不是很好做。

方法:(看了题解,但是是日文的,看的不是很懂,如果有理解不对的地方,或是可以改进的地方,请大家指出)

学到了!用二进制数 1 表示 1, 10 表示 2 ,...,  10000表示5, ..., 1000000000 表示10。即一个数i可以用1加上i-1各0表示,就把这种方法成为特殊表示把。那么可以发现,任意几个相邻的数的特殊表示相连,一定包含了这几个数和的特殊表示。比如 2和3相邻,特殊表示为:10100, 5的特殊表示为10000,被10100包含。所以就可以把tail通过这种特殊表示用二进制数来储存,并且可以快速判断是否合法(是否包含x, y, z)。同时注意到x+y+z <= 17,可以只考虑特殊表示的最后18位,2^18不会内存爆炸。

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 #define mp make_pair
 5 #define pb push_back
 6 
 7 typedef long long ll;
 8 typedef pair<int, int> ii;
 9 typedef pair<ll, ll> l4;
10 
11 
12 const int maxn = 20;
13 int x, y, z;
14 int dp[2][1<<maxn] = {0};
15 int f = 0;
16 const int mod = 1e9+7;
17 inline void add(int &a, int b)
18 {
19   a += b;
20   if (a >= mod)
21     a %= mod;
22 }
23 int n;
24 int mask;
25 bool work[1<<maxn] = {0};
26 int main()
27 {
28   scanf("%d", &n);
29   scanf("%d %d %d", &x, &y, &z);
30   int bit = ((((1<<x)+1)<<y)+1)<<(z-1);
31   mask = (1<<(x+y+z))-1;
32   for (int i = 0; i <= mask; ++i)
33     {
34       work[i] = true;
35       int tmp = i;
36       while (tmp)
37     {
38       if ((tmp&bit)==bit)
39         {
40           work[i] = false;
41           break;
42         }
43       tmp >>= 1;
44     }
45     }
46   f = 0;
47   dp[f][0] = 1;
48   for (int i = 0; i < n; ++i)
49     {
50       f ^= 1;
51       memset(dp[f], 0, (mask+1)*sizeof(int));
52       for (int j = 0; j <= mask; ++j)
53     if (work[j])
54       for (int d = 1; d <= 10; ++d)
55         {
56           int nxt = ((j<<1)|1)<<(d-1);
57           nxt &= mask;
58           if (work[nxt])
59         add(dp[f][nxt], dp[f^1][j]);
60         }
61     }
62   int ans = 0;
63   for (int i = 0; i <= mask; ++i)
64     add(ans, dp[f][i]);
65   ll tot = 1;
66   for (int i = 0; i < n; ++i)
67     tot = tot * 10 % mod;
68   tot -= ans;
69   tot %= mod;
70   if (tot < 0)
71     tot += mod;
72   ans = tot;
73   printf("%d
", ans);
74   
75     
76       
77 }
View Code
原文地址:https://www.cnblogs.com/skyette/p/8082575.html