分治法-大整数乘法和Strassen矩阵乘法

4.5.1 大整数乘法

对于100位甚至更多位的十进制之间的乘法运算还是比较复杂的。我们使用经典的笔算算法来对两个n位整数相乘,第一个数中的n个数字都要被第二个数中的n个数字相乘,这样就需要做n2次相乘,而使用分治技术,我们就能设计出乘法次数少于n2次的算法。

先来看下这个简单公式:

image,则

                                                    image

我们实际上要处理的就是中间的image这一部分,就是将这两次乘法转为一次乘法,具体实现可由下面这个公式得到:

image

我们令image

所以image,原式为:

image

额,这个算法还是有点复杂,代码不知道该怎么写。

4.5.2 Strassen矩阵乘法

V.Strassen在1969年发表了这个算法,它的成功依赖于这个发现:计算两个2阶方阵A和B的积C只需要进行7次乘法运算,而不是蛮力算法所需要的8次。公式参照如下:

image

其中,

image

因此,对于两个2阶方阵相乘时,Strassen算法执行了7次乘法和18次加减法,而蛮力法需要执行8次乘法和4次加法。虽然只是减少了一次乘法,但当矩阵的阶趋于无穷大时,算法卓越的效率就渐渐表现出来了。

代码实现这个算法对我来说感觉还是有点复杂:-),毕竟考虑的因素有很多,因为进行乘法运算的矩阵并不都是2n阶的,而且矩阵之间是无法进行乘法运算的,总之,思路感觉有点多啊。以下代码是我排除了各种不定因素,且进行乘法运算的矩阵都是2n阶的方阵(好像是有点low哦,不过不管啦)。

代码实现:

/**
     * Strassen算法进行矩阵相乘
     * @author xiaofeig
     * @since 2015.9.19 
     * @param marix1 要进行相乘的矩阵1
     * @param marix2 要进行相乘的矩阵2
     * @return 返回相乘的结果
     * */
    public static int[][] strassenMultiplyMatrix(int[][] marix1, int[][] marix2){
        if(marix1.length==1){
            return new int[][]{{marix1[0][0]*marix2[0][0]}};
        }
        
        int xLen=marix1[0].length;
        int yLen=marix1.length;
        int[][] a00=copyArrayOfRange(marix1, 0, 0, yLen/2, xLen/2);
        int[][] a01=copyArrayOfRange(marix1, 0, xLen/2, yLen/2, xLen);
        int[][] a10=copyArrayOfRange(marix1, yLen/2, 0, yLen, xLen/2);
        int[][] a11=copyArrayOfRange(marix1, yLen/2, xLen/2, yLen, xLen);
        
        xLen=marix2[0].length;
        yLen=marix2.length;
        int[][] b00=copyArrayOfRange(marix2, 0, 0, yLen/2, xLen/2);
        int[][] b01=copyArrayOfRange(marix2, 0, xLen/2, yLen/2, xLen);
        int[][] b10=copyArrayOfRange(marix2, yLen/2, 0, yLen, xLen/2);
        int[][] b11=copyArrayOfRange(marix2, yLen/2, xLen/2, yLen, xLen);
        
        int[][] m1=strassenMultiplyMatrix(plusMarix(a00, a11), plusMarix(b00, b11));
        int[][] m2=strassenMultiplyMatrix(plusMarix(a10, a11), b00);
        int[][] m3=strassenMultiplyMatrix(a00, minusMarix(b01, b11));
        int[][] m4=strassenMultiplyMatrix(a11, minusMarix(b10, b00));
        int[][] m5=strassenMultiplyMatrix(plusMarix(a00, a01), b11);
        int[][] m6=strassenMultiplyMatrix(minusMarix(a10, a00), plusMarix(b00, b01));
        int[][] m7=strassenMultiplyMatrix(minusMarix(a01, a11), plusMarix(b10, b11));
        
        int[][] newMarix1=plusMarix(minusMarix(plusMarix(m1, m4), m5), m7);
        int[][] newMarix2=plusMarix(m3, m5);
        int[][] newMarix3=plusMarix(m2, m4);
        int[][] newMarix4=plusMarix(minusMarix(plusMarix(m1, m3), m2), m6);
        return mergeMarix(newMarix1, newMarix2, newMarix3, newMarix4);
        
    }
    /**
     * 复制指定矩阵的某范围内的数据到以新的数组
     * @author xiaofeig
     * @since 2015.9.19 
     * @param array 目标数组
     * @param i,j 左上角元素下标(包含)
     * @param m,n 右下角元素下标(不包含)
     * @return 返回指定数组某范围的新数组
     * */
    public static int[][] copyArrayOfRange(int[][] array,int i,int j,int m,int n){
        int[][] result=new int[m-i][n-j];
        int index=0;
        while(i<m){
            result[index]=Arrays.copyOfRange(array[i], j, n);
            index++;
            i++;
        }
        return result;
    }
    /**
     * 进行矩阵之间的加法运算
     * @author xiaofeig
     * @since 2015.9.19 
     * @param marix1 加数矩阵1
     * @param marix2 加数矩阵2
     * @return 返回结果矩阵
     * */
    public static int[][] plusMarix(int[][] marix1,int[][] marix2){
        int[][] result=new int[marix1.length][marix1[0].length];
        for(int i=0;i<marix1.length;i++){
            for(int j=0;j<marix1[0].length;j++){
                result[i][j]=marix1[i][j]+marix2[i][j];
            }
        }
        return result;
    }
    /**
     * 进行矩阵之间的减法运算
     * @author xiaofeig
     * @since 2015.9.19 
     * @param marix1 减数矩阵
     * @param marix2 被减数矩阵
     * @return 返回结果矩阵
     * */
    public static int[][] minusMarix(int[][] marix1,int[][] marix2){
        int[][] result=new int[marix1.length][marix1[0].length];
        for(int i=0;i<marix1.length;i++){
            for(int j=0;j<marix1[0].length;j++){
                result[i][j]=marix1[i][j]-marix2[i][j];
            }
        }
        return result;
    }
    /**
     * 将四个矩阵合并为一个矩阵
     * @param marix1 数组1
     * @param marix2 数组2
     * @param marix3 数组3
     * @param marix4 数组4
     * @return 返回合并之后的新矩阵
     * */
    public static int[][] mergeMarix(int[][] marix1,int[][] marix2,int[][] marix3,int[][] marix4){
        int m=marix1.length,n=marix1[0].length;
        int[][] marix=new int[m*2][n*2];
        for(int i=0;i<marix.length;i++){
            for(int j=0;j<marix[i].length;j++){
                if(i<m){
                    if(j<n){
                        marix[i][j]=marix1[i][j];
                    }else{
                        marix[i][j]=marix2[i][j-n];
                    }
                }else{
                    if(j<n){
                        marix[i][j]=marix3[i-m][j];
                    }else{
                        marix[i][j]=marix4[i-m][j-n];
                    }
                }
            }
        }
        return marix;
    }

算法分析:

上面的代码我用了两个23阶的矩阵测试过,结果是正确的,其它阶数的矩阵我没测试,估计会有很多错误。

估计一下算法的渐进效率,M(n)表示Strassen算法在计算两个n阶方阵时执行的乘法次数(n为2的乘方),它满足下面的递推关系式:

当n>1时,M(n)=7M(n/2),M(1)=1

因为n=2k

image

因为k=log2n,

image

它比蛮力法需要的n3次乘法运算要少。

宁可孤独,也不违心。宁可抱憾,也不将就。
原文地址:https://www.cnblogs.com/fei-er-blog/p/4821818.html