// 稀疏矩阵乘积——三元组存储表示 (Sparse matrix multiplication — triple-list storage representation)

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;

/* Capacity limits and status codes shared by the matrix routines. */
const int MAXSIZE = 12500;  // maximum number of stored non-zero triples per matrix
const int MAXRC = 100;      // maximum number of rows per matrix (size of the rpos table)
const int ERROR = -1;       // failure status; note it is non-zero, i.e. truthy in a condition
const int OK = 1;           // success status
const int zero = 0;         // value printed for elements that are not stored

typedef int Status;         // return type of functions reporting OK / ERROR

/* One stored non-zero element of a sparse matrix. */
struct Triple {
    int i;  // 1-based row index
    int j;  // 1-based column index
    int e;  // element value
};

/* Sparse matrix stored as a triple list plus a per-row position table. */
struct RLSMatrix {
    Triple data[MAXSIZE + 1];   // non-zero triples in data[1..tu]; slot 0 is unused
    int rpos[MAXRC + 1];        // rpos[r]: position in data of row r's first non-zero; slot 0 is unused
    int mu;                     // number of rows
    int nu;                     // number of columns
    int tu;                     // number of stored non-zeros
};

Status CreatMatrix(RLSMatrix *matrix);
void PrintMatrix(RLSMatrix *matrix);
Status MultSMatrix(RLSMatrix M, RLSMatrix N, RLSMatrix &Q);

void main()
{
    RLSMatrix *M, *N, *Q;
    if (!(M = (RLSMatrix *)malloc(sizeof(RLSMatrix))))
        exit(ERROR);
    if (!(N = (RLSMatrix *)malloc(sizeof(RLSMatrix))))
        exit(ERROR);
    if (!(Q = (RLSMatrix *)malloc(sizeof(RLSMatrix))))
        exit(ERROR);
    if (CreatMatrix(M)&&CreatMatrix(N))
    {
        printf("
put out M:
");
        PrintMatrix(M);
        printf("
put out N:
");
        PrintMatrix(N);
        if(MultSMatrix(*M, *N, *Q))
        {    
            printf("


 M * N :
");
            PrintMatrix(Q);
        }
        else 
            printf("M.mu and N.nu are not mathing
");
    }
    else 
        printf("input error.
");
}

Status CreatMatrix(RLSMatrix *matrix)
{
    int num = 0, p, q, min, temp;
    int row;
    printf("input the row and col:
");
    scanf("%d%d",&matrix->mu, &matrix->nu);
    if (matrix->mu > MAXRC)
        return ERROR;
    printf("row col val
");
    scanf("%d%d%d", &matrix->data[num + 1].i, &matrix->data[num + 1].j, &matrix->data[num + 1].e);
    while(matrix->data[num + 1].i)
    {
        if (++num > MAXSIZE)
            return ERROR;
        scanf("%d%d%d", &matrix->data[num + 1].i, &matrix->data[num + 1].j, &matrix->data[num + 1].e);
    }
    matrix->tu = num;
    for (p = 1; p <= matrix->tu - 1; ++p)
    {
        min = p;
        for (q = p + 1; q <= matrix->tu; ++q)
        {
            if (matrix->data[min].i > matrix->data[q].i||
                (matrix->data[min].i == matrix->data[q].i && matrix->data[min].j > matrix->data[q].j))
                min = q;
        }
        if (p != min)
        {
            temp = matrix->data[min].i;
            matrix->data[min].i = matrix->data[p].i;
            matrix->data[p].i = temp;
            temp = matrix->data[min].j;
            matrix->data[min].j = matrix->data[p].j;
            matrix->data[p].j = temp;
            temp = matrix->data[min].e;
            matrix->data[min].e = matrix->data[p].e;
            matrix->data[p].e = temp;
        }
    }
    for (row = 1,num = 0; row <= matrix->mu; ++ row)
    {
        matrix->rpos[row] = num + 1;
        while (matrix->data[num +1].i == row)
            ++num;
    }
    return OK;
}

/*
 * Prints *matrix to stdout as a dense grid.
 *
 * The output starts with a header line giving the row count, column count and
 * number of stored non-zeros. It is followed by one line per row, with every
 * element printed in a 4-character field; elements that are not stored print
 * as zero.
 *
 * The function walks matrix->data sequentially while scanning the grid. The
 * triples must therefore be in row-major order with each position stored at
 * most once; otherwise later entries are not printed.
 */
void PrintMatrix(RLSMatrix *matrix)
{
    int num = 0;    // number of stored triples already printed
    printf("\nrow:%d   col:%d   number:%d\n", matrix->mu, matrix->nu, matrix->tu);
    for (int row = 1; row <= matrix->mu; ++row)
    {
        for (int col = 1; col <= matrix->nu; ++col)
        {
            // Print the next stored triple only if it sits exactly at (row, col).
            if (num + 1 <= matrix->tu &&
                matrix->data[num + 1].i == row && matrix->data[num + 1].j == col)
            {
                ++num;
                printf("%4d", matrix->data[num].e);
            }
            else
                printf("%4d", zero);
        }
        printf("\n");
    }
}

/*
 * Computes Q = M * N for two sparse matrices stored as row-major triple lists
 * with row position tables (rpos).
 *
 * The product is built one row at a time. For each row of M, the products of
 * its non-zeros with the matching rows of N are summed in a dense per-column
 * accumulator. The non-zero sums are then appended to Q in column order, so Q
 * comes out in row-major order, and Q.rpos is filled for every row.
 *
 * M and N must carry valid rpos tables (as built by CreatMatrix).
 *
 * Returns OK on success. Returns ERROR if:
 *   - M.nu != N.mu;
 *   - M has more rows than Q.rpos can index;
 *   - N.nu is negative;
 *   - the product has more than MAXSIZE non-zeros.
 * On ERROR, Q may be partially written.
 *
 * NOTE(review): M and N are passed by value, so each call copies two whole
 * RLSMatrix objects; the signature is kept unchanged for existing callers.
 */
Status MultSMatrix(RLSMatrix M, RLSMatrix N, RLSMatrix &Q)
{
    if (M.nu != N.mu || M.mu > MAXRC || N.nu < 0)
        return ERROR;
    Q.mu = M.mu;
    Q.nu = N.nu;
    Q.tu = 0;

    // Accumulator for the current result row, indexed by column 1..N.nu (slot 0 unused).
    std::vector<int> ctemp(N.nu + 1, 0);

    for (int arow = 1; arow <= M.mu; ++arow)
    {
        std::fill(ctemp.begin(), ctemp.end(), 0);
        Q.rpos[arow] = Q.tu + 1;    // position of this row's first non-zero in Q

        // Non-zeros of M's row arow occupy M.data[M.rpos[arow] .. mEnd - 1].
        int mEnd = (arow < M.mu) ? M.rpos[arow + 1] : M.tu + 1;
        for (int p = M.rpos[arow]; p < mEnd; ++p)
        {
            int brow = M.data[p].j;     // row of N paired with this element's column
            // Non-zeros of N's row brow occupy N.data[N.rpos[brow] .. nEnd - 1].
            int nEnd = (brow < N.mu) ? N.rpos[brow + 1] : N.tu + 1;
            for (int q = N.rpos[brow]; q < nEnd; ++q)
                ctemp[N.data[q].j] += M.data[p].e * N.data[q].e;
        }

        // Emit the non-zero sums of result row arow in column order.
        for (int ccol = 1; ccol <= Q.nu; ++ccol)
        {
            if (ctemp[ccol] == 0)
                continue;
            if (Q.tu >= MAXSIZE)
                return ERROR;           // Q.data has no free slot left
            ++Q.tu;
            Q.data[Q.tu].i = arow;
            Q.data[Q.tu].j = ccol;
            Q.data[Q.tu].e = ctemp[ccol];
        }
    }
    return OK;
}
// 原文地址 (original article): https://www.cnblogs.com/gjfhopeful/p/3620890.html