P3413 SAC#1

题目链接: https://www.luogu.com.cn/problem/P3413

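  很明显,我们很难直接求出“包含长度大于等于2的回文子串”的数的个数,但是却可以较为容易地求出“不包含任何长度大于等于2的回文子串”的数的个数:任何长度不小于2的回文串,其中心处一定是一个长度为2或3的回文串,所以一个数不含回文子串,当且仅当它的每一位既不等于前一位,也不等于前两位,这恰好可以用只记录最近两位的数位DP来统计。那么我们不如采用正难则反的策略,用总数减去不合法的个数,得到的就是合法的数的个数了;区间[l, r]的答案即为f(r) - f(l-1)。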

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cmath>
  4 #include <string>
  5 #include <cstring>
  6 #include <algorithm>
  7 #include <limits>
  8 #include <vector>
  9 #include <stack>
 10 #include <queue>
 11 #include <set>
 12 #include <map>
 13 #include <bitset>
 14 #include <unordered_map>
 15 #include <unordered_set>
 16 #define lowbit(x) ( x&(-x) )
 17 #define pi 3.141592653589793
 18 #define e 2.718281828459045
 19 #define INF 0x3f3f3f3f
 20 #define HalF (l + r)>>1
 21 #define lsn rt<<1
 22 #define rsn rt<<1|1
 23 #define Lson lsn, l, mid
 24 #define Rson rsn, mid+1, r
 25 #define QL Lson, ql, qr
 26 #define QR Rson, ql, qr
 27 #define myself rt, l, r
 28 #define pii pair<int, int>
 29 #define MP(a, b) make_pair(a, b)
 30 using namespace std;
 31 typedef unsigned long long ull;
 32 typedef unsigned int uit;
 33 typedef long long ll;
 34 const int maxN = 1e3 + 7;
 35 const ll mod = 1e9 + 7;
 36 char l[maxN], r[maxN];
 37 int dig[maxN];
 38 void MOD(ll &x) { x >= mod ? x %= mod : x; }
 39 ll dp[maxN][10][10];
 40 ll dfs(int pos, int x, int lx, bool top, bool zero)
 41 {
 42     if(pos == 1) return 1;
 43     if(!top && (~dp[pos][lx][x])) return dp[pos][lx][x];
 44     ll sum = 0;
 45     int u = top ? dig[pos - 1] : 9;
 46     for(int i = 0; i <= u; i ++)
 47     {
 48         if(i)
 49         {
 50             if(i == x) continue;
 51             if(i == lx) continue;
 52         }
 53         else
 54         {
 55             if(!zero && i == lx) continue;
 56             if(!zero && i == x) continue;
 57         }
 58         sum += dfs(pos - 1, i, x, top && (i == u), zero && (!x));
 59         MOD(sum);
 60     }
 61     if(!top && !zero) dp[pos][lx][x] = sum;
 62     return sum;
 63 }
 64 ll solve(char *s)
 65 {
 66     ll ans = 0;
 67     memset(dig, 0, sizeof(dig));
 68     int len = (int)strlen(s);
 69     for(int i = 0; i < len; i ++) dig[len - i] = s[i] - '0';
 70     ll all = 0;
 71     for(int i = len; i >= 1; i --)
 72     {
 73         all = all * 10 + dig[i];
 74         MOD(all);
 75     }
 76     memset(dp, -1, sizeof(dp));
 77     ans = dfs(1002, 0, 0, true, true);
 78     ans = all - ans + mod; MOD(ans);
 79     return ans;
 80 }
 81 int main()
 82 {
 83     scanf("%s%s", l, r);
 84     bool zero = true;
 85     int len = (int)strlen(l);
 86     for(int i = 0; zero && i < len; i ++) if(l[i] ^ '0') zero = false;
 87     ll x, y;
 88     if(!zero)
 89     {
 90         l[len - 1] -= 1;
 91         int tmp = len - 1;
 92         while(l[tmp] < '0')
 93         {
 94             l[tmp - 1] --;
 95             l[tmp] = '9';
 96             tmp --;
 97         }
 98         if(!tmp)
 99         {
100             while(l[tmp] == '0') tmp ++;
101             for(int i = 0; i + tmp < len; i ++) l[i] = l[i + tmp];
102             l[len - tmp] = '';
103             if(len == tmp)
104             {
105                 l[0] = '0';
106                 l[1] = '';
107             }
108         }
109         x = solve(l);
110     }
111     else x = 0;
112     y = solve(r);
113     printf("%lld
", (y - x + mod) % mod);
114     return 0;
115 }
原文地址:https://www.cnblogs.com/WuliWuliiii/p/14138559.html