[SDOI2013]淘金 数位DP

做了好久。。。。

大致思路:

求出前k大的方格之和即为答案,

先考虑一维的情况,设f[i]为数位上各个数相乘为i的数的总数,也就是对于数i,有f[i]个数它们各个位相乘为i,

再拓展到二维,根据乘法原理(貌似是这个原理吧),方格(i , j)的金块数就是f[i] * f[j],

所以先数位DP求出f数组,然后贪心取前k大.

具体过程:

首先观察这道题的特殊性质,可以发现,由于是各个位上的数相乘得到贡献的目标,而各个位上的数只有:

1 2 3 4 5 6 7 8 9(如果有0的话金块就飞出去了,所以可以不算)

那么这几个数的质因子只有2,3,5,7,也就是说不为0的f[i]的i值只有2,3,5,7这4个质因子。

然后用计算器算一下就可以得知满足f[i]不为0的i不会很多,因为2^40就基本超数据范围了,

但是由于n过大,如果连f[i]为0的i也要参与运算就太浪费了,

所以考虑先预处理出可能符合要求的i.

以下提供两种方法:

1,dfs:

从0位开始搜索,

每次都递增(包括自己)的枚举下一位,

如果当前乘积超过n,return (超过n就出去了)

如果当前位==n的位数,统计答案并return(这样的话3会被003之类的统计到)

2,根据质因子只有2,3,5,7,暴力枚举。

4个for分别枚举每种质因子的指数,

退出条件类似于dfs。

然后就是DP了。

虽然说已经找到了所有可能状态,但是状态太大了,放不进数组,

怎么办呢?

离散化,每次查询的时候二分一下就好了(当然map也是可以的,不过我不会用23333333)

f[i][j][k]表示dp到第i位,乘积是第j小的那个状态,k=0表示没有超过边界,1表示超过边界。

注意这里的边界是指前i位对比n的前i位是否超过边界(i越大位数越高)

三层for分别枚举i,j,x(下一位是什么)。

每次用当前状态对下一状态做出贡献。

由于是从低位开始枚举的,因此x占的比重更大,

所以根据x来分情况讨论转移方程.

rank表示加上x这位的乘积是第几大

1,若x > s[i+1](s[i]表示n的第i位)

则不管之前怎么样,都超出边界了

所以f[i+1][rank][1] += f[i][j][1] + f[i][j][0];

2,若x < s[i+1]

则不管之前怎么样,都不会超

所以f[i+1][rank][0] += f[i][j][1] + f[i][j][0];

3,x = s[i+1]

那么下一状态超不超完全取决于当前状态。

所以

f[i+1][rank][0] += f[i][j][0];

f[i+1][rank][1] += f[i][j][1];

然后我们发现对一个数i产生贡献的数的长度可能有很多种,

所以DP完之后,枚举产生贡献的数的长度统计入size[i]

其中长度为len(n的长度)时,

对i造成贡献的数,必须没有超过边界,因此只能加上f[len][i][0].

而其他长度都可以加上f[j][i][0] + f[j][i][1]

这里的size数组实际上就是前面思路里面的f数组,

现在我们得到了size数组,所以下一步就只需要统计答案了。

假设i表示行,j表示列,

显然对于指定的第i行来说,

要依次获得最大,只需要将j指针从大到小依次指向贡献大的列即可(此处size代表了贡献,因为行的限制已经满足了,所以看一个方格的贡献只需要看列,也就是size[列]就可以了)

那么如果我们先对size数组进行排序,那么最优答案必定由前k大的数互相搭配得到。

因此我们定义k个行指针,分别指向自己那一行,同时与这个行指针绑在一起的是一个列指针,先指向贡献最大的那一列,

然后用优先队列按size[行指针] * size[列指针]为标准维护大根堆。

每次取出最大的那一个,取k次,每次取的时候弹出指针,统计贡献,并将列指针移向下一个列(大小仅次于or等于当前列的列),然后放回优先队列中。

最后即可得到答案。

注意此题需要LL,不开会wa。

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define R register int
  4 #define mod 1000000007
  5 #define LL long long
  6 //#define int LL
  7 #define AC 401000
  8 int len;
  9 int s[15];
 10 LL n,k,tmp,tot,cnt,ans,t;//error!!!tmp是用来放n的,所以也要LL啊
 11 LL f[15][AC][2],num[AC],may[AC],size[AC];
 12 struct node{
 13     LL v;
 14     int x,y;
 15 };
 16 
 17 bool operator < (node a,node b)
 18 {
 19     return a.v < b.v;
 20 }
 21 //1230784859 39640
 22 priority_queue <node> q;
 23 /*观察到实质上是要找到前k大金块最多的方格,
 24 然后利用一点组合的知识可以得到:
 25 如果用f[i]表示个位数乘积是i的数的个数,
 26 那么块(i,j)的方块数就是f[i] * f[j],
 27 于是问题变成了求出所有f[i],并找到任意不重复f[i] * f[j]的最大值,
 28 (格子不够就不取了)
 29 然后因为金块移动的方式是向个位数乘积移动,而观察这些可能的因数,
 30 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9,
 31 由这几个数相乘得到的数,它们的质因子只有2 , 3 , 5 , 7。
 32 并且观察到2 ^ 40就已经达到了1e12,所以总的有金块的格子不会太多,
 33 所以应该可以利用这个来计算?
 34 每个数都用2 ^ a * 3 ^ b * 5 ^ c * 7 ^ d来表示(这样大概就存的下了?)
 35 所以就是要对每个可能的数求凑它有多少种方案
 36 f[i][j][k][l]四维DP,分别表示各个质因数的次方
 37 实际上f只要开大概f[40][29][20][15]就可以了
 38 不过因为总状态不多,所以可以先找到所有状态?
 39 然后就按照数位DP的套路来,一位代表到了哪一位,一位是有没有在边界上,
 40 这里还加一位表示是哪个数。
 41 f[i][j][k]表示到了第i位,乘积是第j小的那个数(相当于离散化),k表示是否超过了n(也就是边界)
 42 */
 43 
 44 bool cmp(int a,int b)
 45 {
 46     return a > b;
 47 }
 48 
 49 void dfs(int x,int length,LL mul)//当前位,当前长度,当前乘积,因为要保证小于n,error!!!mul要开LL啊
 50 {//其实也不用在意这里搜到的状态是否严格满足大小关系,如果不满足下面反正统计不到的
 51     if(!mul || mul > n) return ;//所以直接从小到大枚举防止重复组合即可,貌似手算可以发现,满足条件的mul一定不会大于n?
 52     if(length == len)//为了防止过多无用搜索,只搜到这一位就行了
 53     {//事实上就是因为这里搜到的状态并不严格满足大小关系,所以这样搜才不会遗漏,比如03这里直接乘会变成0
 54     //然而貌似并不会,这就是从第0位开始的原因了吧
 55         num[++cnt]=mul;//所以搜到03的假等价条件为搜到13,虽然实际上n不一定大于这个,但是为了保证正确性
 56         return;//多搜一点也没关系
 57     }
 58     for(R i=x;i<=9;i++)
 59         dfs(i,length+1,mul * i);
 60 }//其实也可以按照原来的思路,直接枚举每个质因子的指数
 61 //其实从低位开始枚举也是有用的,因为这样是为了能搜到个位数
 62 void pre()
 63 {
 64     scanf("%lld%lld",&n,&k);
 65     tmp=n;
 66     while(tmp)
 67     {
 68         s[++len]=tmp % 10;
 69         tmp /= 10;
 70     }
 71     dfs(0,0,1);
 72     sort(num+1,num+cnt+1);    
 73     for(R i=1;i<=cnt;i++)
 74         if(num[i] != num[i+1]) num[++tot]=num[i];//去重
 75 }
 76 
 77 inline int half(LL x)//二分这个数的下标error!!!x也要LL啊
 78 {
 79     int l=1,r=tot,mid;//是从去重后的开始二分。。。。
 80     while(l < r)
 81     {
 82         mid=(l + r) >> 1;
 83         if(x == num[mid]) return mid;//error明明是和num[mid]比较。。。。
 84         if(x > num[mid]) l=mid + 1;
 85         else r=mid;
 86     }
 87     return l;
 88 }
 89     
 90 void work()
 91 {
 92     f[0][1][0]=1;//原因类似dfs,1的话就可以任意搭配了
 93     for(R i=0;i<len;i++)
 94         for(R j=1;j<=tot;j++)
 95         {//因为如果各个位上出现了0,那金块就出去了,所以是不能出现任何0的,但这样就不太好一次统计完了,
 96             for(R x=1;x<=9;x++)
 97             {
 98                 int rank=half(num[j] * x);//所以干脆统计到第i位就只统计第i位的,不考虑之前的方案,最后再重新计入答案
 99                 if(x > s[i+1]) f[i+1][rank][1] += f[i][j][1] + f[i][j][0];//如果是大于的话,由于是在高位,所以不管前面怎么样都是大于了
100                 else if(x < s[i+1]) f[i+1][rank][0] += f[i][j][1] + f[i][j][0];//同理,此处不管前面怎么样都是小于
101                 else//如果是等于的话大小关系就只受前面影响了
102                 {
103                     f[i+1][rank][1] += f[i][j][1];
104                     f[i+1][rank][0] += f[i][j][0];
105                 }
106             }
107         }
108     for(R i=1;i<=tot;i++)//枚举数字
109     {
110         for(R j=1;j<len;j++)//枚举长度,长度没那么多的都可以超过n(在有的位置上),
111             size[i]+=f[j][i][0] + f[j][i][1];
112         size[i]+=f[len][i][0];
113     }
114     //for(R i=1;i<=tot;i++) printf("%lld %lld
",num[i],size[i]);
115     sort(size + 1,size + tot + 1,cmp);
116 //    for(R i=1;i<=tot;i++) printf("%lld %lld
",num[i],size[i]);
117 }
118 /*用k个指针(其实不需要这么多?)
119 一开始都指向自己,指针i表示横坐标为i,指向j表示纵坐标为j,
120 因为同一个i,肯定是从大的纵坐标开始取,所以一开始i指向自己,然后如果取了这个,
121 那就指向下一个纵坐标*/
122 void getans()
123 {
124     node x;int en=min(k,tot);
125     for(R i=1;i<=en;i++)
126         q.push((node) {size[i] * size[1],i,1});//指针先指向1,,,这里要改完啊
127     for(R i=1;i<=k;i++)
128     {
129         if(q.empty()) break; 
130         x=q.top();
131         q.pop();
132         ans=(ans + x.v) % mod;
133         x.y++;//移动指针
134         x.v=size[x.x] * size[x.y];//更新金块数量
135         q.push(x);//再放回去
136     }
137     printf("%lld
",ans);
138 //    printf("time used %lf
",(double)clock()/CLOCKS_PER_SEC);
139 }
140 
141 int main()
142 {
143 //    freopen("in.in","r",stdin);
144     pre();
145     work();
146     getans();
147 //    fclose(stdin);
148     return 0;
149 }


由于本人有打注释的习惯,所以看上去可能有点乱,以下为无注释版本:

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define R register int
  4 #define mod 1000000007
  5 #define LL long long
  6 #define AC 401000
  7 int len;
  8 int s[15];
  9 LL n,k,tmp,tot,cnt,ans,t;
 10 LL f[15][AC][2],num[AC],may[AC],size[AC];
 11 struct node{
 12     LL v;
 13     int x,y;
 14 };
 15 
 16 bool operator < (node a,node b)
 17 {
 18     return a.v < b.v;
 19 }
 20 
 21 priority_queue <node> q;
 22 
 23 bool cmp(int a,int b)
 24 {
 25     return a > b;
 26 }
 27 
 28 void dfs(int x,int length,LL mul)
 29 {
 30     if(!mul || mul > n) return ;
 31     if(length == len)
 32     {
 33         num[++cnt]=mul;
 34         return;
 35     }
 36     for(R i=x;i<=9;i++)
 37         dfs(i,length+1,mul * i);
 38 }
 39 
 40 void pre()
 41 {
 42     scanf("%lld%lld",&n,&k);
 43     tmp=n;
 44     while(tmp)
 45     {
 46         s[++len]=tmp % 10;
 47         tmp /= 10;
 48     }
 49     dfs(0,0,1);
 50     sort(num+1,num+cnt+1);    
 51     for(R i=1;i<=cnt;i++)
 52         if(num[i] != num[i+1]) num[++tot]=num[i];
 53 }
 54 
 55 inline int half(LL x)
 56 {
 57     int l=1,r=tot,mid;
 58     while(l < r)
 59     {
 60         mid=(l + r) >> 1;
 61         if(x == num[mid]) return mid;
 62         if(x > num[mid]) l=mid + 1;
 63         else r=mid;
 64     }
 65     return l;
 66 }
 67     
 68 void work()
 69 {
 70     f[0][1][0]=1;
 71     for(R i=0;i<len;i++)
 72         for(R j=1;j<=tot;j++)
 73         {
 74             for(R x=1;x<=9;x++)
 75             {
 76                 int rank=half(num[j] * x);
 77                 if(x > s[i+1]) f[i+1][rank][1] += f[i][j][1] + f[i][j][0];
 78                 else if(x < s[i+1]) f[i+1][rank][0] += f[i][j][1] + f[i][j][0];
 79                 else
 80                 {
 81                     f[i+1][rank][1] += f[i][j][1];
 82                     f[i+1][rank][0] += f[i][j][0];
 83                 }
 84             }
 85         }
 86     for(R i=1;i<=tot;i++)
 87     {
 88         for(R j=1;j<len;j++)
 89             size[i]+=f[j][i][0] + f[j][i][1];
 90         size[i]+=f[len][i][0];
 91     }
 92     sort(size + 1,size + tot + 1,cmp);
 93 }
 94 
 95 void getans()
 96 {
 97     node x;int en=min(k,tot);
 98     for(R i=1;i<=en;i++)
 99         q.push((node) {size[i] * size[1],i,1});
100     for(R i=1;i<=k;i++)
101     {
102         if(q.empty()) break; 
103         x=q.top();
104         q.pop();
105         ans=(ans + x.v) % mod;
106         x.y++;
107         x.v=size[x.x] * size[x.y];
108         q.push(x);
109     }
110     printf("%lld
",ans);
111 }
112 
113 int main()
114 {
115     pre();
116     work();
117     getans();
118     return 0;
119 }
View Code
原文地址:https://www.cnblogs.com/ww3113306/p/9048873.html