稀疏矩阵 part 3

▶ 各种稀疏矩阵数据结构下 y(n,1) = A(n,m) * x(m,1) 的实现,CPU版本

● MAT 乘法

 1 int dotCPU(const MAT *a, const MAT *x, MAT *y)
 2 {
 3     checkNULL(a); checkNULL(x); checkNULL(y);
 4     if (a->col != x->row)
 5     {
 6         printf("dotMATCPU dimension mismatch!
");
 7         return 1;
 8     }
 9     
10     y->row = a->row;
11     y->col = x->col;
12     for (int i = 0; i < a->row; i++)
13     {
14         format sum = 0;
15         for (int j = 0; j < a->col; j++)        
16             sum += a->data[i * a->col + j] * x->data[j];                
17         y->data[i] = sum;        
18     }
19     COUNT_MAT(y);
20     return 0;
21 }

● CSR 乘法

 1 int dotCPU(const CSR *a, const MAT *x, MAT *y)
 2 {
 3     checkNULL(a); checkNULL(x); checkNULL(y);
 4     if (a->col != x->row)
 5     {
 6         printf("dotCSRCPU dimension mismatch!
");
 7         return 1;
 8     }
 9     
10     y->row = a->row;
11     y->col = x->col;
12     for (int i = 0; i < a->row; i++)                            // i 遍历 ptr,j 遍历行内数据,A 中为 0 的元素不参加乘法
13     {
14         format sum = 0;
15         for (int j = a->ptr[i]; j < a->ptr[i + 1]; j++)
16             sum += a->data[j] * x->data[a->index[j]];
17         y->data[i] = sum;
18     }
19     COUNT_MAT(y);
20     return 0;
21 }

● ELL 乘法

 1 int dotCPU(const ELL *a, const MAT *x, MAT *y)      // CPU ELL乘法
 2 {
 3     checkNULL(a); checkNULL(x); checkNULL(y);
 4     if (a->colOrigin != x->row)
 5     {
 6         printf("dotELLCPU dimension mismatch!
");
 7         return 1;
 8     }
 9 
10     y->row = a->col;
11     y->col = x->col;
12     for (int i = 0; i<a->col; i++)
13     {
14         format sum = 0;
15         for (int j = 0; j < a->row; j++)
16         {
17             int temp = a->index[j * a->col + i];
18             if (temp < 0)                                   // 跳过无效元素
19                 continue;
20             sum += a->data[j * a->col + i] * x->data[temp];
21         }
22         y->data[i] = sum;
23     }
24     COUNT_MAT(y);
25     return 0;
26 }

● COO 乘法

 1 int dotCPU(const COO *a, const MAT *x, MAT *y)
 2 {
 3     checkNULL(a); checkNULL(x); checkNULL(y);
 4     if (a->col != x->row)
 5     {
 6         printf("dotCOOCPU null!
");
 7         return 1;
 8     }
 9 
10     y->row = a->row;
11     y->col = x->col;
12     for (int i = 0; i<a->count; i++)
13         y->data[a->rowIndex[i]] += a->data[i] * x->data[a->colIndex[i]];
14     COUNT_MAT(y);
15     return 0;
16 }

● DIA 乘法

 1 int dotCPU(const DIA *a, const MAT *x, MAT *y)
 2 {
 3     checkNULL(a); checkNULL(x); checkNULL(y);
 4     if (a->colOrigin != x->row)
 5     {
 6         printf("dotDIACPU null!
");
 7         return 1;
 8     }    
 9     y->row = a->row;
10     y->col = x->col;
11     int * inverseIndex = (int *)malloc(sizeof(int) * a->col);
12     for (int i = 0, j = 0; i < a->row + a->col - 1; i++)
13     {
14         if (a->index[i] == 1)
15         {
16             inverseIndex[j] = i;
17             j++;
18         }
19     }
20     for (int i = 0; i < a->row; i++)
21     {
22         format sum = 0;
23         for (int j = 0; j < a->col; j++)
24         {
25             if (i < a->row - 1 - inverseIndex[j] || i > inverseIndex[a->col - 1] - inverseIndex[j])
26                 continue;
27             sum += a->data[i * a->col + j] * x->data[i + inverseIndex[j] - a->row + 1];
28         }
29         y->data[i] = sum;
30     }
31     COUNT_MAT(y);
32     free(inverseIndex);
33     return 0;
34 }
原文地址:https://www.cnblogs.com/cuancuancuanhao/p/10428493.html