矩阵基本运算的实现(standard C++Version)

一年前 自己动手编写的代码,尽量使用了标准C++,其中用到了一些标准模板库STL,但没有用到很高级的部分,不懂STL的朋友也应该可以看懂,建议看一下《Generic Programming and STL》,有候捷译本《泛型编程与STL》。

头文件
  1/***
  2*
  3*   author:XieXiaokui
  5*    purpose:Defines functions for matrix
  6*
  7***/

  8#pragma once
  9
 10#include <iostream>
 11#include <fstream>
 12
 13#include <string>
 14#include <sstream>
 15#include <algorithm>
 16#include <functional>
 17#include <numeric>
 18#include <iterator>
 19#include <cassert>
 20
 21#include "MatrixException.h"
 22
 23using namespace std;
 24
 25class Matrix
 26{
 27public:
 28  
 29    // Constructors
 30    explicit Matrix();
 31    explicit Matrix(int size);
 32    Matrix(int row,int col);
 33    Matrix(const Matrix& m);
 34
 35    // Destructor
 36    ~Matrix();
 37
 38    // Assignment operators
 39    Matrix& operator= (const Matrix& m);
 40
 41    // Value extraction method
 42    int GetRow() const;
 43    int GetCol() const;
 44
 45    // Subscript operator
 46    double operator()(int i,int j)const;    //subscript operator to get individual elements
 47    double& operator()(int i,int j);    //subscript operator to set individual elements
 48
 49    // Unary operators
 50    Matrix operator+() const;    //unary negation operator
 51    Matrix operator-() const;    //unary negation operator
 52
 53    //Binary operators
 54    Matrix operator+(const Matrix& m) const;
 55    Matrix operator-(const Matrix& m) const;
 56    Matrix operator*(const Matrix& m) const;
 57//    Matrix operator*(double d)const;
 58    Matrix operator/(const Matrix& m) const;
 59    Matrix operator/(double d) const;
 60    Matrix operator^(int pow) const;
 61
 62    bool operator== (const Matrix& m) const;    //logical equal-to operator
 63    bool operator!= (const Matrix& m) const;    //logical not-equal-to operator
 64
 65    friend Matrix operator* (double d,const Matrix& m);
 66    friend Matrix operator/ (double d,const Matrix& m);
 67
 68
 69
 70    // Combined assignment - calculation operators
 71    Matrix& operator +=(const Matrix& m) const;
 72    Matrix& operator -=(const Matrix& m) const;
 73    Matrix& operator *=(const Matrix& m) const;
 74    Matrix& operator *=(double d) const;
 75    Matrix& operator /=(const Matrix& m) const;
 76    Matrix& operator /=(double d) const;
 77    Matrix& operator ^=(int pow) const;
 78
 79    // Miscellaneous -methods
 80    void SetZero() ;    //zero matrix:零阵
 81    void SetUnit() ;    //unit matrix:单位阵
 82    void SetSize(int size) ;    //rsizing matrix
 83    void SetSize(int row,int col) ;    //resizing matrix
 84
 85    // Utility methods
 86    Matrix Solve(const Matrix& m)const;    //
 87    Matrix Adjoin() const;    //adjoin matrix:伴随矩阵
 88    double Determinant() const;    //determinant:行列式
 89    double Norm() const;    //norm:模
 90    Matrix Inverse() const;    //inverse:逆阵
 91    Matrix Transpose() const;    //transpose:转置
 92    double Cofactor() const;    //confactor
 93    double Condition() const;    //the condition number of a matrix
 94    int Pivot(int row) const;    // partial pivoting
 95
 96    //primary change
 97    Matrix& Exchange(int i,int j);// 初等变换 对调两行:ri<-->rj
 98    Matrix& Multiple(int index,double k);    //初等变换 第index 行乘以k
 99    Matrix& MultipleAdd(int index,int src,double k);    //初等变换 第src行乘以k加到第index行
100   
101
102    // Type of matrices
103    bool IsSquare() const;    //determine if the matrix is square:方阵
104    bool IsSingular() const;    //determine if the matrix is singular奇异阵
105    bool IsDiagonal() const;    //determine if the matrix is diagonal对角阵
106    bool IsScalar() const;    //determine if the matrix is scalar数量阵
107    bool IsUnit() const;    //determine if the matrix is unit单位阵
108    bool IsZero() const;    //determine if the matrix is zero零矩阵
109    bool IsSymmetric() const;    //determine if the matrix is symmetric对称阵
110    bool IsSkewSymmetric() const;    //determine if the matrix is skew-symmetric斜对称阵
111    bool IsUpperTriangular() const;    //determine if the matrix is upper-triangular上三角阵
112    bool IsLowerTriangular() const;    //determine if the matrix is lower-triangular下三角阵
113
114    // output stream function
115    friend ostream& operator<<(ostream& os,const Matrix& m);
116
117    // input stream function
118    friend istream& operator>>(istream& is,Matrix& m);
119
120    //conert to string
121    string ToString() const;
122
123protected:
124
125    //delete the matrix
126    void Create(int row,int col);
127    void Clear();
128
129private:
130
131    const static double epsilon;
132
133    double** m_data;
134    size_t m_row;
135    size_t m_col;
136
137}
;
实现文件

  1/***
  2*
  3*    
  4*   author:XieXiaokui
  5*    purpose:Defines functions for matrix
  6*
  7***/

  8#include "stdafx.h"
  9
 10#include <iostream>
 11#include <fstream>
 12
 13#include <string>
 14#include <sstream>
 15#include <algorithm>
 16#include <functional>
 17#include <numeric>
 18#include <iterator>
 19#include <cmath>
 20#include <cassert>
 21
 22#include "matrix.h"
 23
 24using namespace std;
 25
 26const double Matrix::epsilon = 1e-7;
 27
 28// constructor
 29
 30Matrix::Matrix():m_row(0),m_col(0),m_data(0)
 31{
 32}

 33
 34
 35// constructor
 36Matrix::Matrix(int size):m_row(size),m_col(size)
 37{
 38    if(size<=0)
 39    {
 40        m_row=0;
 41        m_col=0;
 42        m_data=0;
 43
 44        ostringstream oss;
 45        oss<<"In Matrix::Matrix(int size) size "<<size<<" <=0.Please check it.";
 46        throw oss.str();
 47        
 48    }

 49    
 50    m_data=new double*[size];
 51
 52    for(int i=0;i<size;i++)
 53    {
 54        m_data[i]=new double[size];
 55    }

 56
 57}

 58
 59// constructor
 60Matrix::Matrix(int row,int col):m_row(row),m_col(col)
 61{
 62    if(row<=0)
 63    {
 64        m_row=0;
 65        m_col=0;
 66        m_data=0;
 67
 68        ostringstream oss;
 69        oss<<"In Matrix::Matrix(int row,int col),row "<<row<<" <=0.Please check it.";
 70        throw oss.str();
 71    }

 72    if(col<=0)
 73    {
 74        m_row=0;
 75        m_col=0;
 76        m_data=0;
 77
 78        ostringstream oss;
 79        oss<<" In Matrix::Matrix(int row,int col),col "<<col<<" <=0.Pleasecheck it.";
 80        throw oss.str();
 81    }

 82
 83    m_data=new double*[row];
 84    for(int i=0;i<row;i++)
 85    {
 86        m_data[i]=new double[col];
 87    }

 88}

 89
 90//copy  constructor
 91Matrix::Matrix(const Matrix& m):m_row(m.m_row),m_col(m.m_col)
 92{
 93    m_data = new double*[m_row];
 94    for(int i=0;i<m_row;i++)
 95    {
 96        m_data[i] = new double[m_col];
 97    }

 98
 99    for(int i=0;i<m_row;i++)
100    {
101        copy(m.m_data[i],m.m_data[i]+m_col,m_data[i]);
102    }

103
104}

105
106Matrix::~Matrix()
107{
108    if(m_data != 0)
109    {
110        for(int i=0;i<m_row;i++)
111        {
112            delete[] m_data[i];
113        }

114        delete[] m_data;
115    }

116}

117
118void Matrix::SetZero()
119{
120    for(int i=0;i<m_row;i++)
121        fill(m_data[i],m_data[i]+m_col,0);
122}

123
124void Matrix::SetUnit()
125{
126    for(int i=0;i<m_row;i++)
127        for(int j=0;j<m_col;j++)
128            m_data[i][j] = (i==j)?1:0;
129}

130
131void Matrix::SetSize(int size)
132{
133    Clear();
134    Create(size,size);
135}

136
137
138void Matrix::SetSize(int row,int col)
139{
140    Clear();
141    Create(row,col);
142}

143
144
145void Matrix::Clear()
146{
147
148    if(m_data != 0)
149    {    
150        for(int i=0;i<m_row;i++)
151        {
152            delete[] m_data[i];
153        }

154        delete[] m_data;
155    }

156
157    m_row=0;
158    m_col=0;
159    m_data=0;
160}

161
162/*
163Matrix& Matrix::Create(int size)
164{
165    if(size<=0)
166    {
167        ostringstream oss;
168        oss<<"In Matrix::Create(int size),size "<<size<<" <=0.Please check it.";
169
170        throw oss.str();
171    }
172
173    if(m_data != 0)
174    {
175        for(int i=0;i<m_row;i++)
176        {
177            delete[] m_data[i];
178        }
179        delete[] m_data;
180    }
181
182    m_row=size;
183    m_col=size;
184
185    m_data=new double*[size];
186    for(int i=0;i<size;i++)
187    {
188        m_data[i]=new double[size];
189    }
190
191    for(int i=0;i<size;i++)
192    {
193        for(int j=0;j<size;j++)
194        {
195            m_data[i][j] = ((i==j)?1:0);
196        }
197    }
198
199    return *this;
200}
201*/

202
203void Matrix::Create(int row,int col)
204{
205    if(row<=0)
206    {
207        ostringstream oss;
208        oss<<"In Matrix::Create(int row,int col),row "<<row<<" <=0.Please check it.";
209        throw oss.str();
210    }

211    if(col<=0)
212    {
213        ostringstream oss;
214        oss<<"In Matrix::Create(int row,int col),col  "<<col<<" <=0.Please check it.";
215        throw oss.str();
216    }

217
218    if(m_data != 0)
219    {
220        for(int i=0;i<m_row;i++)
221        {
222            delete[] m_data[i];
223        }

224        delete[] m_data;
225    }

226
227    m_row=row;
228    m_col=col;
229    m_data=new double*[row];
230    for(int i=0;i<row;i++)
231    {
232        m_data[i]=new double[col];
233    }

234}

235
236
237int Matrix::GetRow() const
238{
239    return m_row;
240}

241
242int Matrix::GetCol() const
243{
244    return m_col;
245}

246
247//transpose 转置
248Matrix Matrix::Transpose() const
249{
250    Matrix ret(m_col,m_row);
251    for(int i=0;i<m_row;i++)
252    {
253        for(int j=0;j<m_col;j++)
254        {
255            ret.m_data[j][i] = m_data[i][j];
256        }

257    }

258    return ret;
259}

260
261int Matrix::Pivot(int row) const
262{
263    int index=row;
264
265    for(int i=row+1;i<m_row;i++)
266    {
267        if(m_data[i][row] > m_data[index][row])
268            index=i;
269    }

270
271    return index;
272}

273
274
275Matrix&  Matrix::Exchange(int i,int j)    // 初等变换:对调两行ri<-->rj
276{
277    if(i<0 || i>=m_row)
278    {
279        ostringstream oss;
280        oss<<"In void Matrix::Exchange(int i,int j)    ,i "<<i<<" out of bounds.Please check it.";
281        throw oss.str();
282    }

283
284    if(j<0 || j>=m_row)
285    {
286        ostringstream oss;
287        oss<<"In void Matrix::Exchange(int i,int j)    ,j "<<j<<" out of bounds.Please check it.";
288        throw oss.str();
289    }

290
291    for(int k=0;k<m_col;k++)
292    {
293        swap(m_data[i][k],m_data[j][k]);
294    }

295
296    return *this;
297}

298
299
300Matrix&  Matrix::Multiple(int index,double mul)    //初等变换 第index 行乘以mul
301{
302    if(index <0 || index >= m_row)
303    {
304        ostringstream oss;
305        oss<<"In void Matrix::Multiple(int index,double k) index "<<index<<" out of bounds.Please check it.";
306        throw oss.str();
307    }

308
309    transform(m_data[index],m_data[index]+m_col,m_data[index],bind2nd(multiplies<double>(),mul));
310    return *this;
311}

312
313
314Matrix& Matrix::MultipleAdd(int index,int src,double mul)    //初等变换 第src行乘以mul加到第index行
315{
316    if(index <0 || index >= m_row)
317    {
318        ostringstream oss;
319        oss<<"In void MultipleAdd(int index,int src,double k) index "<<index<<" out of bounds.Please check it.";
320        throw oss.str();
321    }

322
323    if(src < 0 || src >= m_row)
324    {
325        ostringstream oss;
326        oss<<"In void MultipleAdd(int index,int src,double k) src "<<src<<" out of bounds.Please check it.";
327        throw oss.str();
328    }

329        
330    for(int j=0;j<m_col;j++)
331    {
332        m_data[index][j] += m_data[src][j] * mul;
333    }

334
335    return *this;
336
337}

338
339        
340//inverse 逆阵:使用矩阵的初等变换,列主元素消去法
341Matrix Matrix::Inverse() const
342{
343    if(m_row != m_col)    //非方阵
344    {
345        ostringstream oss;
346        oss<<"In Matrix Matrix::Invert() const. m_row "<<m_row<<" != m_col "<<m_col<<" Please check it";
347        throw oss.str();
348    }

349
350    Matrix tmp(*this);
351    Matrix ret(m_row);    //单位阵
352    ret.SetUnit();
353
354    int maxIndex;
355    double dMul;
356
357    for(int i=0;i<m_row;i++)
358    {
359
360        maxIndex = tmp.Pivot(i);
361    
362        if(tmp.m_data[maxIndex][i]==0)
363        {
364            ostringstream oss;
365            oss<<"In Matrix Matrix::Invert() const 行列式的值等于0.Please check it";
366            throw oss.str();
367        }

368
369        if(maxIndex != i)    //下三角阵中此列的最大值不在当前行,交换
370        {
371            tmp.Exchange(i,maxIndex);
372            ret.Exchange(i,maxIndex);
373
374        }

375
376        ret.Multiple(i,1/tmp.m_data[i][i]);
377        tmp.Multiple(i,1/tmp.m_data[i][i]);
378
379
380        for(int j=i+1;j<m_row;j++)
381        {
382            dMul = -tmp.m_data[j][i];
383            tmp.MultipleAdd(j,i,dMul);
384            ret.MultipleAdd(j,i,dMul);
385    
386        }

387
388    }
//end for
389
390    for(int i=m_row-1;i>0;i--)
391    {
392        for(int j=i-1;j>=0;j--)
393        {
394            dMul = -tmp.m_data[j][i];
395            tmp.MultipleAdd(j,i,dMul);
396            ret.MultipleAdd(j,i,dMul);
397        }

398    }
//end for
399        
400    return ret;
401}

402
403    
404
405// assignment operator 賦值运算符
406Matrix& Matrix::operator= (const Matrix& m)
407{
408    if(m_data != 0)
409    {
410        for(int i=0;i<m_row;i++)
411        {
412            delete[] m_data[i];
413        }

414        delete[] m_data;
415    }

416
417    m_row = m.m_row;
418    m_col = m.m_col;
419
420    m_data = new double*[m_row];
421    for(int i=0;i< m_row;i++)
422    {
423        m_data[i] = new double[m_col];
424    }

425
426    for(int i=0;i<m_row;i++)
427    {
428        copy(m.m_data[i],m.m_data[i]+m_col,m_data[i]);
429    }

430    return *this;
431}

432
433
434//binary addition 矩阵加
435Matrix Matrix::operator+ (const Matrix& m) const
436{
437    if(m_row != m.m_row)
438    {
439        ostringstream oss;
440        oss<<"In Matrix::operator+(const Matrix& m) this->m_row "<<m_row<<" != m.m_row "<<m.m_row<<" Please check it.";
441        throw oss.str();
442    }

443    if(m_col != m.m_col)
444    {
445        ostringstream oss;
446        oss<<" Matrix::operator+ (const Matrix& m) this->m_col "<<m_col<<" != m.m_col "<<m.m_col<<" Please check it.";
447        throw oss.str();
448    }

449
450    Matrix ret(m_row,m_col);
451    for(int i=0;i<m_row;i++)
452    {
453        transform(m_data[i],m_data[i]+m_col,m.m_data[i],ret.m_data[i],plus<double>());
454    }

455    return ret;
456}

457
458//unary addition 求正
459Matrix Matrix::operator+() const
460{
461    return *this;
462}

463
464//binary subtraction 矩阵减
465Matrix Matrix::operator- (const Matrix& m) const
466{
467    if(m_row != m.m_row)
468    {
469        ostringstream oss;
470        oss<<"In Matrix::operator-(const Matrix& m) this->m_row "<<m_row<<" != m.m_row "<<m.m_row<<" Please check it.";
471        throw oss.str();
472    }

473    if(m_col != m.m_col)
474    {
475        ostringstream oss;
476        oss<<"In Matrix::operator- (const Matrix& m) this->m_col "<<m_col<<" != m.m_col "<<m.m_col<<" Please check it.";
477        throw oss.str();
478    }

479
480    Matrix ret(m_row,m_col);
481    for(int i=0;i<m_row;i++)
482    {
483        transform(m_data[i],m_data[i]+m_col,m.m_data[i],ret.m_data[i],minus<double>());
484    }

485    return ret;
486}

487
488//unary substraction 求负
489Matrix Matrix::operator-() const
490{
491    Matrix ret(*this);
492
493    for(int i=0;i<m_row;i++)
494        for(int j=0;j<m_col;j++)
495            ret.m_data[i][j] *= -1;
496
497    return ret;
498
499}

500
501
502//binary multiple 矩阵乘
503Matrix Matrix::operator*(const Matrix& m) const
504{
505    if(m_col != m.m_row)
506    {
507        ostringstream oss;
508        oss<<"In Matrix::operator*(const Matrix& m) this->m_col "<<m_col<<" != m.m_row "<<m.m_row<<" Please check it.";
509        throw oss.str();
510    }

511
512    Matrix ret(m_row,m.m_col);
513    Matrix tmp=m.Transpose();
514
515    for(int i=0;i<m_row;i++)
516    {
517        for(int j=0;j<m.m_col;j++)
518        {
519            ret.m_data[i][j]=inner_product(m_data[i],m_data[i]+m_col,tmp.m_data[j],0.0);
520        }

521    }

522
523    return ret;;
524}

525
526//friend scalar multiple 数乘
527Matrix operator* (double d,const Matrix& m)
528{
529    Matrix ret(m);
530
531    for(int i=0;i<ret.m_row;i++)
532        for(int j=0;j<ret.m_col;j++)
533            ret.m_data[i][j] *= d;
534
535    return ret;
536
537}

538
539//binary matrix division equivalent to multiple inverse矩阵除,等价于乘以逆阵
540Matrix Matrix::operator/(const Matrix& m) const
541{
542    return *this * m.Inverse();
543}

544
545//binary scalar division equivalent to multiple reciprocal数除,等价于乘以此数的倒数
546Matrix Matrix::operator/(double d) const
547{
548    return 1/* (*this);
549}

550
551//friend division
552Matrix operator/(double d,const Matrix& m)
553{
554    return d * m.Inverse();
555}

556
557
558// subscript operator to get individual elements 下标运算符
559double Matrix::operator()(int i,int j) const
560{
561    if(i<0 || i>= m_row)
562    {
563        ostringstream oss;
564        oss<<"In double Matrix::operator()(int i,int j) const.i "<<i<<" out of bounds.Please check it";
565        throw oss.str();
566    }

567    if(j<0 || j>= m_col)
568    {
569        ostringstream oss;
570        oss<<"In double Matrix::operator()(int i,int j) const.j "<<j<<" out of bounds.Please check it";
571        throw oss.str();
572    }

573
574    return m_data[i][j];
575}

576
577// subscript operator to set individual elements 下标运算符
578double& Matrix::operator ()(int i,int j)
579{
580    if(i<0 || i>= m_row)
581    {
582        ostringstream oss;
583        oss<<"In double Matrix::operator()(int i,int j) const.i "<<i<<" out of bounds.Please check it";
584        throw oss.str();
585    }

586    if(j<0 || j>= m_col)
587    {
588        ostringstream oss;
589        oss<<"In double Matrix::operator()(int i,int j) const.j "<<j<<" out of bounds.Please check it";
590        throw oss.str();
591    }

592
593    return m_data[i][j];
594}

595
596//to string 化为标准串
597string Matrix::ToString() const
598{
599    ostringstream oss;
600    for(int i=0;i<m_row;i++)
601    {
602        copy(m_data[i],m_data[i]+m_col,ostream_iterator<double>(oss," "));
603        oss<<"\n";
604    }

605
606    return oss.str();
607}

608
609// outputt stream function 输出
610
611ostream& operator<<(ostream& os,const Matrix& m)
612{
613    for(int i=0;i<m.m_row;i++)
614    {
615        copy(m.m_data[i],m.m_data[i]+m.m_col,ostream_iterator<double>(os,"  "));
616        os<<'\n';
617    }

618    return os;
619}

620
621// input  stream function 输入
622istream& operator>>(istream& is,Matrix& m)
623{
624    for(int i=0;i<m.m_row;i++)
625    {
626        for(int j=0;j<m.m_col;j++)
627        {
628            is>>m.m_data[i][j];
629        }

630    }

631    return is;
632}

633
634
635// reallocation method
636
637// public method for resizing matrix
638
639// logical equal-to operator
640
641// logical no-equal-to operator
642
643// combined addition and assignment operator
644
645// combined subtraction and assignment operator
646
647// combined scalar multiplication and assignment operator
648
649
650// combined matrix multiplication and assignment operator
651
652// combined scalar division and assignment operator
653
654// combined power and assignment operator
655
656// unary negation operator
657
658// binary addition operator
659
660// binary subtraction operator
661
662// binary scalar multiplication operator
663
664// binary scalar multiplication operator
665
666// binary matrix multiplication operator
667
668// binary scalar division operator
669
670
671// binary scalar division operator
672
673// binary matrix division operator
674
675// binary power operator
676
677// unary transpose operator
678
679// unary Inverse operator
680
681// Inverse function
682
683// solve simultaneous equation
684
685// set zero to all elements of this matrix
686
687// set this matrix to unity
688
689// private partial pivoting method
690
691// calculate the determinant of a matrix
692
693// calculate the norm of a matrix
694
695// calculate the condition number of a matrix
696
697// calculate the cofactor of a matrix for a given element
698
699// calculate adjoin of a matrix
700
701// Determine if the matrix is singular
702
703// Determine if the matrix is diagonal
704
705// Determine if the matrix is scalar
706
707// Determine if the matrix is a unit matrix
708
709// Determine if this is a null matrix
710
711// Determine if the matrix is symmetric
712
713// Determine if the matrix is skew-symmetric
714// Determine if the matrix is upper triangular
715
716// Determine if the matrix is lower triangular
实现了很多基本的操作,需要的话还可以自己扩充。如果实现那些注释掉的代码,功能就比较强了。

另有标准C语言版本。
原文地址:https://www.cnblogs.com/xiexiaokui/p/158836.html