动态规划(2)——算法导论(17)

写在前面

在上一篇博客中,学习了钢条切割问题。这一篇博客再来学习另一个典型的动态规划问题:矩阵乘法链问题

提出问题

我们知道,矩阵的乘法是满足结合律的,即对于矩阵A,B,C 满足(A B) C = A (B C) 。但不同的结合方式会导致最终所作的乘法总次数不同。

例如:对于矩阵 A(规模为10 x 100),B(100 x 5),C(5 x 50),如果按照( ( A B ) C )的结合方式,计算D = ( A B )将需要作10 x 100 x 5 = 5 000次乘法,计算 D x C 要做 10 x 5 x 50 = 2 500次乘法,总共要做7 500 次乘法;若按照 ( A (B C) )的结合方式,同理可算出一共需要计算75 000次乘法!二者相差1个数量级。

由上可见,找出最优的结合方式,使总的乘法数最少能极大的加快矩阵乘法链的计算速度。这便是矩阵乘法链问题:

[ 给定n个矩阵的链A_1A_2...A_n,其中矩阵A_i的规模为p_{i-1}×p_i, ]

[ 求完全括号化方案,使得计算乘积A_1A_2...A_n所需的标量乘法次数最少。 ]

暴力求解

我们最先想到的是,能否采用暴力求解的方式来找出最优结合方式。但事实上,可以证明这不会是一个高效的算法。因为,当n = 1时,只有唯一一种结合方式;当n > 1时,我们可以将总的结合方式的数目看做是两部分结合方式的数目的乘积,即:

[ A_1A_2...A_n = ((A_1A_2...A_k)(A_{k+1}A_{k+2}...A_n)),k = 1,2..n-1 ]

因此,对于长度为n的矩阵乘法链,总的结合方式的数目P(n)可以用如下递归公式表示:

[P(n) = egin {cases} 1 & n = 1\ sum_{k=1}^{n-1}P(k)P(n-k) & n geq 1 end{cases} ]

可以证明该公式的结果为(Omega(2^n))

应用动态规划方法

我们分4个步骤来执行动态规划方法。

1. 最优括号化方案的结构特征

动态规划方法的第一步还是寻找最优子结构,然后利用这种子结构从子问题的最优解中构造出原问题的最优解。在该问题中,此步骤的做法如下。

首先为了方便起见,我们记

[A_{i...j}(i <= j)表示A_iA_{i+1}...A_j乘积的结果矩阵 ]

如果 i < j ,为了对

[A_iA_{i+1}...A_j ]

进行括号化,我们就必须在某个A_k和A_k+1之间将矩阵链划分开,于是问题就变成了

[先求解矩阵A_{i...k}和A_{k+1...j},然后再计算他们的乘积,最终得到结果。 ]

那么我们在求解两个子矩阵链时,应该采用什么样的分割方案,才能使二者相乘后,得到的结果最优呢?答案是,我们应该采取单独求解子矩阵链时的最优方案最为分割方案。原因是,若我们不这么做,我们将此子矩阵链的最优方案作为该子矩阵链的分割方案代入时,一定比其他的分割方案优秀。

因此我们有结论:

对子矩阵链的最优括号化方案,就是对原矩阵链的最优括号化方案。

2. 一个递归求解方案

记m[i, j] ( i <= j)表示计算矩阵乘法链(A_{i...j})所需的标量乘法数目。

当i = j时,矩阵链只包含唯一一个矩阵,因此m[i, i] = 0;

当i < j时,假设最优括号化方案在k (i <= k < j)时取得。不妨设矩阵的A_i的维数是p_i-1 X pi,那么容易计算出:

[m[i, j] = m[i, k] + m[k+1, j] + p_{i-1}p_kp_j ]

于是最优括号化方案可用如下公式描述:

[m[i, j] = egin {cases} 0 & i = j\ min_{ileq k = j}{m[i, k] + m[k+1. j] + p_{i-1} p_k p_{j} } & i < j end{cases} ]

3. 计算最优代价

我们可以很容易地根据上述递归公式写出一个递归算法,来

[计算 A_1A_2...A_n相乘的最小时间代价m[1, n]。 ]

但可以看出,该递归算法的时间是指数时间,并不比暴力搜索的方案好。

但又注意到,在问题规模较大的子问题中又包含了问题规模较小的子问题,即所有子问题在求解时有重叠。我们可以采用同钢条切割问题中的处理办法一样,记录下每个子问题的结果,以便再次求解其时,可以直接得到答案。

事实上,每对满足1 <= j <= j <= n 的i和j对应一个唯一的子问题,共有((_2^n)+n=Θ(n^2))个。

像这种子问题重叠的性质是应用动态规划的另一个标识。(第一种标识是最优子结构)

下面给出带备忘的自底向上方法的Java实现:

/**
 * 
 * @param p
 *            p[i](0 <= i < p.lenth) 表示第(i + 1)个矩阵的行数,因此第(i + 1)个矩阵的规模为p[i]
 *            × p[i+1]
 * @return
 */
public static int[][] matrixChainOrder(int[] p) {
	int n = p.length - 1;// n为待求矩阵链的总长度,待求矩阵链为A0A1...A(n-1)
	int[][] record = new int[n][n];// record[i][j]表示AiA(i+1)...Aj最优括号化方案的结果
	// l为子矩阵链的长度,l = 2 to n(长度为1只包含一个矩阵,不需要作乘积,因此不考虑)
	for (int l = 2; l <= n; l++) {
		// i = 0 to n - l,表示起始矩阵的下标
		for (int i = 0; i <= n - l; i++) {
			int j = i + l - 1; // j 表示结束矩阵的下标
			record[i][j] = Integer.MAX_VALUE;
			// k = i to j-1,表示分割点
			for (int k = i; k < j; k++) {
				int q = record[i][k] + record[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
				if (q < record[i][j]) {
					record[i][j] = q;
				}
			}
		}
	}
	return record;
}

可以看出,上述算法的时间复杂度为(O(n^3)),并且还需要(O(n^2))的内存空间来保存record数组。但这比暴力求解的指数时间复杂度高效的多。

下面做一个测试,对于一个长度为6的矩阵链,求其最优括号化方案所需要作的标量乘法次数。其中每个矩阵的规模如下:

[A_1 : 30 × 35,A_2 : 35× 15, A_3 : 15 × 5,A_4 : 5× 10, A_5 : 10 × 20,A_6 : 20× 25 ]

此时,输入参数int p[] = {30, 35, 15, 5, 10, 20, 25},代入到上面的matrixChainOrder()方法中,求得结果为record[0][p.length - 2] = 15125

4. 构造最优解

上述matrixChainOrder()方法只能求出(只记录了)各子链问题的最优方案需要进行的标量乘法的数目,而未记录其最优方案的分割方法,即k值。因此,我们可以改进一下,把k值也保存下来,并且把最终的最优括号化方案“友好的”打印出来。下面是改进后的实现代码:

/**
 * 
 * @param p
 *            p[i](0 <= i < p.lenth) 表示第(i + 1)个矩阵的行数,因此第(i + 1)个矩阵的规模为p[i]
 *            × p[i+1]
 * @return
 */
public static int[][] matrixChainOrder(int[] p) {
	int n = p.length - 1;// n为待求矩阵链的总长度,待求矩阵链为A0A1...A(n-1)
	int[][] record = new int[n][n]; // record[i][j]表示AiA(i+1)...Aj最优括号化方案的结果
	int[][] cut = new int[n][n]; // cut[i][j]表示AiA(i+1)...Aj最优分割点
	// l为子矩阵链的长度,l = 2 to n(长度为1只包含一个矩阵,不需要作乘积,因此不考虑)
	for (int l = 2; l <= n; l++) {
		// i = 0 to n - l,表示起始矩阵的下标
		for (int i = 0; i <= n - l; i++) {
			int j = i + l - 1; // j 表示结束矩阵的下标
			record[i][j] = Integer.MAX_VALUE;
			// k = i to j-1,表示分割点
			for (int k = i; k < j; k++) {
				int q = record[i][k] + record[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
				if (q < record[i][j]) {
					record[i][j] = q;
					cut[i][j] = k;
				}
			}
		}
	}
	// 打印 最终括号化方案
	print(cut, 0, n - 1);
	return record;
}

// 打印 括号化方案
public static void print(int[][] cut, int i, int j) {
	if (i == j) {
		System.out.print("A" + i);
		return;
	}
	System.out.print("(");
	print(cut, i, cut[i][j]);
	print(cut, cut[i][j] + 1, j);
	System.out.print(")");
}
原文地址:https://www.cnblogs.com/dongkuo/p/5442889.html