[HAOI2015]数字串拆分

题目描述

你有一个长度为n的数字串。定义f(S)为将S拆分成若干个1~m的数的和的方案数,比如m=2时,f(4)=5,分别为4=1+1+1+1你可以将这个数字串分割成若干个数字(允许前导0),将他们加起来,求f,并求和。比如g(123)=f(1+2+3)+f(1+23)+f(12+3)+f(123)。已知字符串和m后求答案对998244353(717223+1,一个质数)取模后的值。

输入输出格式

输入格式:

第一行输入一个字符串,第二行输入m

输出格式:

仅输出一个数表示答案

输入输出样例

输入样例#1:

123
3

输出样例#1:

394608467

说明

对于100%的数据,字符串长度不超过500,m<=5


题解

矩乘(这题似乎叫什么十进制快速幂)

首先(f[])数组很好求,(f[i]=sum_{j=1}^{j<=m}{f[i-j]})

然后我们可以很简单的写出转移矩阵(以m=3为例)

$egin{Bmatrix}
1 & 1 & 0

1 & 0 & 1

1 & 0 & 0
end{Bmatrix} $

然后我们可以快速处理出(Num[i][j].f[][])表示矩阵转移(i * 10^j)次后的答案

这样我们再考虑对原字符串做dp

可以发现一个问题

这个状态不好设

我一开始想当然的设的是(g[i])表示到第i位的方案数

然后(g[i] = sum_{j=0}^{j<i}{g[j]+f(j+1,i)})

但是这个这个(g[])不满足结合律

所以需要在f(j+1,i)前面乘上到g[j]有多少种情况的系数

所以不能这么转移

只能用矩阵(g[i])表示到i时的答案矩阵为g[i]

(g[i] = sum_{j=0}^{j<i}{g[j]*f(j+1,i)})

这样就可以不重不漏的统计所有的答案了

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define int long long
const int M = 505 ;
const int N = 6 ;
const int W = 10 ;
const int mod = 998244353 ;
using namespace std ;
char s[M] ;
int n , m , t[W] , f[M] , Ans ;

struct Matrix {
    int f[N][N] ;
    Matrix () { memset(f , 0 , sizeof(f)) ; }
    void Start() { for(int i = 1 ; i <= m ; i ++) f[i][i] = 1 ; }
    friend Matrix operator * (Matrix a , Matrix b) {
        Matrix temp ;
        for(int i = 1 ; i <= m ; i ++)
            for(int j = 1 ; j <= m ; j ++)
                for(int k = 1 ; k <= m ; k ++)
                    temp.f[i][j] = (temp.f[i][j] + a.f[i][k] * b.f[k][j]) % mod ;
        return temp ;
    }
    friend Matrix operator + (Matrix a , Matrix b) {
        Matrix temp ;
        for(int i = 1 ; i <= m ; i ++)
            for(int j = 1 ; j <= m ; j ++)
                temp.f[i][j] = (temp.f[i][j] + a.f[i][j] + b.f[i][j]) % mod ;
        return temp ;
    }
} Num[W][M] , g[M] , val[M][M] ;
Matrix Fpw(Matrix Base , int k) {
    Matrix temp ; temp.Start() ;
    while(k) {
        if(k & 1) temp = temp * Base ; 
        Base = Base * Base ; k >>= 1 ;
    }
    return temp ;
}
void PreSolve() {
	for(int i = n ; i >= 1 ; i --) {
		val[i][i] = Num[s[i] - '0'][0] ;
		for(int j = i - 1 ; j >= 1 ; j --)
			val[j][i] = val[j + 1][i] * Num[s[j] - '0'][i - j] ;
	}
}
# undef int
int main() {
# define int long long
    scanf("%s%lld",s + 1,&m) ; n = strlen(s + 1) ; t[0] = 1 ;
    for(int i = 1 ; i <= 9 ; i ++) 
        for(int j = 0 ; j <= m && i - j >= 0 ; j ++)  
            t[i] = (t[i] + t[i - j]) % mod ;
    Num[0][0].Start() ;
    for(int i = 1 ; i <= m ; i ++) Num[1][0].f[i][1] = 1 ;
    for(int i = 2 ; i <= m ; i ++) Num[1][0].f[i - 1][i] = 1 ;
    for(int i = 2 ; i <= 9 ; i ++)
        Num[i][0] = Fpw(Num[1][0] , i) ;
    for(int i = 0 ; i <= 9 ; i ++)
        for(int j = 1 ; j <= n ; j ++)
            Num[i][j] = Fpw(Num[i][j - 1] , 10) ;
    PreSolve() ;
    g[0].Start() ;
    for(int i = 1 ; i <= n ; i ++) {
        for(int j = 0 ; j < i ; j ++) {
        	Matrix temp = g[j] * val[j + 1][i] ;
        	g[i] = g[i] + temp ;
        }
    }
    printf("%lld
",g[n].f[1][1]) ;
    return 0 ;
}
原文地址:https://www.cnblogs.com/beretty/p/10082590.html