2017网易实习生笔试-矩阵n次方快速求解算法(平方指数法 exponentiation by squaring)

矩阵n次方存在普遍快速求解算法。(特殊矩阵利用线性代数有快速求解法,这里不讨论特殊矩阵,讨论的是普通矩阵的普适算法)。

想明白矩阵n次方的快速求解算法就得先明白数n次方的快速求解算法。

假设,我们要求$x^n$, 那问题可以分解为以下两种情况:

如果n是偶数,  $(x^2)^{(n/2)}$

如果n是奇数, $x * (x^2)^{(n-1)/2}$,  这里n > 0.(n<0, 可以转换为求1/x的-n次方)。

这就可以用分治算法求解。

实际上当我们求$2^8$时,我们没必要将2连乘8次,我们可以先求$x^2$,再将其平方求$x^4$,再平方求$x^8$。

这样就将复杂度从O(n)降到了O(logn)。

矩阵由于满足结合律,所以n次方也可以类似求解。

由于n次方问题一般n都比较大,所以要注意结果溢出问题哦。

下面给出java代码实现。

下面先给出数的n次方递归算法:

  public double exp(double a, int exp){
        if (exp < 0) return exp(1/a, -exp);
        if (exp == 0) return 1;
        if (exp == 1) return a;
        if (exp % 2 == 0)
            return exp(a*a, exp/2);
        else
            return a * exp(a*a, (exp-1)/2);
    }

这不是尾递归,我们可以稍加改变:

  /**
     * public entrance
     * @param exp
     * @return
     */
    public double exp1(double a, int exp){
        return expNative(1,a,exp);
    }

    /**
     * Write this function is to build tail recursion.
     * @param tmp initialize by 1.
     * @param a
     * @param exp
     * @return
     */
    private double expNative(double tmp, double a, int exp){
        if (exp < 0) return expNative(tmp, 1/a, -exp);
        if (exp == 0) return 1;
        if (exp == 1) return tmp * a;
        if (exp % 2 == 0)
            return expNative(tmp, a*a, exp/2);
        else
            return expNative(tmp*a, a*a, (exp-1)/2);

    }

上面是个单支递归,不改迭代速度也不慢,但我们还是可以改为迭代:

    public double expIterate(double a, int exp){
        if (exp < 0){
            exp = -exp;
            a = 1 / a;
        }
        if (exp == 0) return 1;
        double tmp = 1;
        while (exp > 1){ // after every loop, result is a^exp * tmp.
            if (exp % 2 == 1){
                tmp = tmp * a;
                a = a * a;
                exp = (exp - 1) / 2;
            }else {
                a = a * a;
                exp = exp / 2;
            }
        }
        return tmp * a;
    }

最后给出矩阵的n次方迭代算法:

public long[][] matrixExpMul(long[][] result, int exp){
        long[][] tmp= {
                {1,0,0},
                {0,1,0},
                {0,0,1}
        };
        while (exp > 1) {
            if ((exp & 0x1) == 1) {
                tmp = matrixMul(result, tmp);
                result = matrixMul(result, result);
                exp = (exp - 1) / 2;
            }else {
                result = matrixMul(result, result);
                exp = (exp) / 2;
            }
        }
        return matrixMul(result, tmp);
    }

 一种简单的防溢出法就是,在快溢出时跳出循环,进行处理,然后在外面,计算$tmp * result^{exp}$。

原文地址:https://www.cnblogs.com/zqiguoshang/p/6618950.html