快速幂

目录

目录地址

上一篇

下一篇


递归式快速幂的原理

我们觉得对于整数的整数次幂,用 cmath 的 power 函数太慢了

而且对于一些最终结果取模的式子,用 power 函数可能会涉及到溢出的问题

因此我们考虑如何优化整数的整数次幂,这个问题的求解复杂度

朴素上,求解 (a^n(a,nin N)) 需要 (O(n)) 的时间

我们现在这么考虑:

假设我们已知 (a^{lfloor{nover 2} floor})

则,很快可以得到: (a^n= egin{cases} a^{lfloor{nover 2} floor}cdot a^{lfloor{nover 2} floor},n ext{为偶数} \ \ a^{lfloor{nover 2} floor}cdot a^{lfloor{nover 2} floor}cdot a,n ext{为奇数} end{cases})

或者我们简记为 (a^n=(a^{lfloor{nover 2} floor})^2cdot a^{[2 mid n]})

而求解 (a^{lfloor{nover 2} floor}) 就是一个递归的过程,递归边界 (a^0=1)

因此,我们只需要 (O(log n)) 的时间即可求解

如果我们需要知道的是 (a^n\%m) 也只需要在实现的时候取模即可:

(a^n\%m=(a^{lfloor{nover 2} floor}\%m)^2\%mcdot a^{[2 mid n]}\%m)


递归式快速幂的实现

递归式快速幂

int fpow(int a,int n,int m){
    if(n==0) return 1;
    int ans=fpow(a,n/2,m);
    ans=ans*ans%m;
    if(n%2==1) ans=ans*a%m;
    return ans;
}

当然,喜欢快速幂与三目运算符的可以进一步压行:

int fpow(int a,int n,int m){
    int ans=( n?fpow(a,n>>1,m):1 );
    return ans*ans%m*( (n&1)?a:1 )%m;
}

非递归式快速幂

对于计算 (a^n(a,nin N)) ,我们考虑将 (n) 进行二进制分解

例如 (n=13=1101_{(2)},a^n=a^{13}=a^8cdot a^4cdot a^1)

接下来, (a^1) 是已知的,而根据 (a^k) 能很快地推出 (a^{2k}) 次方

即根据公式: (a^{2k}=a^kcdot a^k(\%m))

因此,我们需要从 (a^1) 推出 (a^{2^k}) 需要 (O(k))

(n) 的最高二进制位是不超过 (log n) 级别的,因此复杂度也为 (O(log n))

实现也并不复杂:

int fpow(int a,int n,int m){
    int ans=1,bas=a;
    while(n!=0){
        if(n%2==1) ans=ans*bas%m;
        bas=bas*bas%m;
        n/=2;
    }
    return ans;
}

或者可以用上位运算,理论上会加速运算:

int fpow(int a,int n,int m){
    int ans=1,bas=a;
    while(n){
        if(n&1) ans=ans*bas%m;
        bas=bas*bas%m;
        n>>=1;
    }
    return ans;
}

矩阵的快速幂

对于后期学到的矩阵,我们可以通过重载运算符的方式,同样可以使得求方阵的快速幂优化到 (O(k^3log n))

其中, (k) 为矩阵的边长

struct matrix{
    int num[MAXN][MAXN];
    matrix(){
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                num[i][j]=(i==j);
    }
    ...
    matrix operator * (const matrix &x){
        matrix y;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++){
                y.num[i][j]=0;
                for(int k=1;k<=n;k++)
                    y.num[i][j]+=num[i][k]*x.num[k][j]%m;
                    if(y.num[i][j]>=m) y.num[i][j]-=m;
            }
        return y;
    }
}
...
matrix fpow(matrix a,int n){
    matirx ans,bas=a;
    while(n){
        if(n&1) ans=ans*bas;
        bas=bas*bas;
        n>>=1;
    }
    return ans;
}
原文地址:https://www.cnblogs.com/JustinRochester/p/12370760.html