数据挖掘 之 关联规则求解算法Apriori的实现

关联规则求解算法Apriori的实现

code + 报告 见:https://github.com/JianmingS/Apriori

  1 // by Shi Jianming
  2 /*
  3 数据挖掘:关联规则求解算法Apriori的实现
  4 */
  5 
  6 #define _CRT_SECURE_NO_WARNINGS
  7 #define HOME
  8 
  9 #include <iostream>
 10 #include <cstdio>
 11 #include <vector>
 12 #include <string>
 13 #include <cmath>
 14 #include <map>
 15 #include <locale>
 16 using namespace std;
 17 const double eps = 1e-8;
 18 const int MaxColNum = 100;
 19 
 20 int rowNum, columnNum; // 行数,列数
 21 double supportMin, confidenceMin; // 最小支持度, 最小置信度
 22 int supporNum; // 最小支持数
 23 int total;
 24 int Case;
 25 
 26 vector<vector<int> > dataBase; // 保存原始数据集
 27 vector<string> columnName; // 保存每一列的栏目名
 28 
 29 // 数据集
 30 struct itemset
 31 {
 32     vector<int> item; // 事务(包含0个或多个项)
 33     int cnt; // 事务出现次数
 34     int id; // 事务唯一标识
 35     itemset()
 36     {
 37         cnt = 0;
 38         id = -1;
 39     }
 40 };
 41 
 42 vector<itemset> preL; // 频繁(k-1)-项集
 43 vector<itemset> C; // 候选(k)-项集
 44 vector<itemset> L; // 频繁(k)-项集
 45 
 46 map<int, itemset> forL; // 为构造频繁(k)-项集
 47 
 48 int C1[MaxColNum]; // 记录C1
 49 
 50 
 51 /****************************************************/
 52 /*
 53 Hash树:
 54 Hash函数: h(p) = p mod k
 55 时间复杂度:O(k)
 56 */
 57 
 58 struct hashTrie
 59 {
 60     hashTrie *next[MaxColNum]; // Hash树后继节点
 61     vector<itemset> C; // 候选(k)-项集
 62     hashTrie()
 63     {
 64         fill(next, next + MaxColNum, nullptr);
 65     }
 66 };
 67 // 创建Hash树
 68 void CrehashTrie(hashTrie *root, vector<int> branch)
 69 {
 70     hashTrie *p = root;
 71     for (auto i = 0; i < branch.size(); ++i)
 72     {
 73         int pos = branch[i] % branch.size();
 74         if (nullptr == p->next[pos])
 75         {
 76             p->next[pos] = new hashTrie;
 77         }
 78         p = p->next[pos];
 79     }
 80     itemset itsetTmp;
 81     itsetTmp.item = branch;
 82     itsetTmp.id = (total++);
 83     p->C.push_back(itsetTmp);
 84 }
 85 // 查找branch的值,判断是否可以在Hash树中匹配成功,并记录在Hash树中匹配成功的次数,保存频繁集
 86 bool FindhashTrie(hashTrie *root, vector<int> branch)
 87 {
 88     hashTrie *p = root;
 89     for (auto i = 0; i < branch.size(); ++i)
 90     {
 91         int pos = branch[i] % branch.size();
 92         if (nullptr == p->next[pos])
 93         {
 94             return false;
 95         }
 96         p = p->next[pos];
 97     }
 98     for (auto &tmp : p->C)
 99     {
100         auto i = 0;
101         for (; i != tmp.item.size(); ++i)
102         {
103             if (tmp.item[i] != branch[i])
104             {
105                 break;
106             }
107         }
108         if (i == tmp.item.size())
109         {
110 
111             ++(tmp.cnt);
112             if (tmp.cnt >= (supporNum))
113             {
114                 if (forL.find(tmp.id) != forL.end())
115                 {
116                     ++(forL[tmp.id].cnt);
117                 }else
118                 {
119                     forL.insert({tmp.id, tmp});
120                 }
121             }
122             return true;
123         }
124     }
125     return false;
126 }
127 // 销毁Hash树
128 void DelhashTrie(hashTrie *T, int len)
129 {
130     for (int i = 0; i < len; ++i)
131     {
132         if (T->next[i] != nullptr)
133         {
134             DelhashTrie(T->next[i], len);
135         }
136     }
137     if (!T->C.empty())
138     {
139         T->C.clear();
140     }
141     delete[] T->next;
142     total = 0;
143 }
144 
145 /****************************************************/
146 
147 
148 
149 /****************************************************/
150 /*
151 从集合{0,1,2,3..,(N-1)} 中找出所有大小为k的子集, 并按照字典序排序
152 */
153 vector<vector<int>> combine;
154 int arr[MaxColNum];
155 int visit[MaxColNum];
156 int combineN, combineK;
157 // 起始:dfs(0, 0)
158 void dfs(int d, int pos)
159 {
160     if (d == combineK)
161     {
162         vector<int> tmp;
163         for (int i = 0; i < combineK; ++i)
164         {
165             tmp.push_back(arr[i]);
166         }
167         combine.push_back(tmp);
168         return;
169     }
170     for (int i = pos; i < combineN; ++i)
171     {
172         if (!visit[i])
173         {
174             visit[i] = true;
175             arr[d] = i;
176             dfs(d + 1, i + 1);
177             visit[i] = false;
178         }
179     }
180 }
181 /****************************************************/
182 
183 // 读取原始数据集
184 void Input()
185 {
186     cin >> rowNum >> columnNum;
187     supporNum = ceil(supportMin*(rowNum - 1));
188     string rowFirst;
189     for (auto i = 0; i < rowNum; ++i)
190     {
191         cin >> rowFirst;
192         vector<int> dataRow;
193         int valueTmp;
194         // 去掉输入数据的第一列
195         for (auto j = 0; j < (columnNum - 1); ++j)
196         {
197             if (i != 0)
198             {
199                 cin >> valueTmp;
200                 if (valueTmp) {
201                     ++C1[j];
202                     dataRow.push_back(j);
203                 }
204             }
205             else
206             {
207                 string colNameTmp;
208                 cin >> colNameTmp; 
209                 columnName.push_back(colNameTmp);
210             }
211         }
212         if (i != 0) dataBase.push_back(dataRow);
213     }
214 }
215 
216 // 获取频繁1-项集
217 void Ini()
218 {
219     for (auto i = 0; i < (columnNum - 1); ++i)
220     {
221         if (C1[i] >= supporNum)
222         {
223             itemset itemsetTmp;
224             itemsetTmp.item.push_back(i);
225             itemsetTmp.cnt = C1[i];
226             preL.push_back(itemsetTmp);
227         }
228     }
229 }
230 
231 
232 // 枚举所有事务(即原始数据)包含的k-项集,计算支持度
233 void getItemsK(hashTrie *root, int k)
234 {
235     vector<int> branch;
236 //    int bbb = 0;
237     for (auto tmp : dataBase)
238     {
239 //        cout << bbb++ << " : " << endl;
240         if (tmp.size() < k) continue;
241 
242         combineN = tmp.size();
243         combineK = k;
244         dfs(0, 0);
245 
246         for (int i = 0; i < combine.size(); ++i)
247         {
248             for (int j = 0; j < combine[i].size(); ++j)
249             {
250                 branch.push_back(tmp[combine[i][j]]);
251             }
252             /***********************/
253             /*
254             匹配候选k-项集,计算支持数
255             */
256             FindhashTrie(root, branch);
257 //            if (FindhashTrie(root, branch))
258 //            {
259 //                for (auto aaa = 0; aaa < branch.size(); ++aaa)
260 //                {
261 //                    cout << branch[aaa] << " ";
262 //                }
263 //                cout << endl;
264 //            }
265 //            /***********************/
266             branch.clear();
267         }
268         combine.clear();
269 //        cout << endl;
270     }
271     
272 }
273 
274 // 判断生成的候选(k)-项集的某个(k-1)-项子集是否为频繁项集
275 bool isInfrequentSubset(itemset c)
276 {
277     hashTrie *root = new hashTrie;
278     int k = c.item.size() - 1;
279     for (auto tmp : preL)
280     {
281         CrehashTrie(root, tmp.item);
282     }
283     vector<int> branch;
284 
285     combineN = c.item.size();
286     combineK = k;
287     dfs(0, 0);
288 
289     for (int i = 0; i < combine.size(); ++i)
290     {
291         for (int j = 0; j < combine[i].size(); ++j)
292         {
293             branch.push_back(c.item[combine[i][j]]);
294         }
295 
296         /***********************/
297         /*
298         判断生成的((k-1)-项子集是否为频繁的。
299         */
300         if (!FindhashTrie(root, branch))
301         {
302             combine.clear();
303             DelhashTrie(root, k);
304             return true;
305         }
306         /***********************/
307         branch.clear();
308     }
309     combine.clear();
310     DelhashTrie(root, k);
311     return false;
312 }
313 
314 // 产生候选(k)-项集
315 void apriori_gen(int k)
316 {
317     for (auto L1 = 0; L1 < preL.size(); ++L1)
318     {
319         for (auto L2 = L1 + 1; L2 < preL.size(); ++L2)
320         {
321             auto judge = true;
322             for (auto i = 0; i < (k - 1); ++i)
323             {
324                 if (preL[L1].item[i] != preL[L2].item[i])
325                 {
326                     judge = false;
327                 }
328             }
329             if (!judge) continue;
330             itemset itemsetTmp;
331             for (auto i = 0; i < (k - 1); ++i)
332             {
333                 itemsetTmp.item.push_back(preL[L1].item[i]);
334             }
335             itemsetTmp.item.push_back(preL[L1].item[k - 1]);
336             itemsetTmp.item.push_back(preL[L2].item[k - 1]);
337             if (isInfrequentSubset(itemsetTmp)) {
338                 continue;
339             }
340             C.push_back(itemsetTmp);
341         }
342     }
343 }
344 // Apriori算法实现,并输出关联规则集
345 void Apriori()
346 {
347     for (auto k = 2; !preL.empty(); ++k)
348     {
349         hashTrie *root = new hashTrie;
350         apriori_gen(k - 1); // 求出候选(k)-项集;
351         for (auto i = 0; i < C.size(); ++i)
352         {
353             CrehashTrie(root, C[i].item);
354         }
355         C.clear();
356         getItemsK(root, k);
357         DelhashTrie(root, k);
358         for (auto tmp : forL)
359         {
360             L.push_back(tmp.second);
361         }
362         forL.clear();
363         if (L.empty())
364         {
365             break;
366         }
367         for (auto fromTmp : L)
368         {
369             for (auto toTmp : preL)
370             {
371                 auto i = 0;
372                 for (; i < toTmp.item.size(); ++i)
373                 {
374                     if (toTmp.item[i] != fromTmp.item[i])
375                     {
376                         break;
377                     }
378                 }
379                 if (i == toTmp.item.size())
380                 {
381 //                    double aaa = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt);
382 //                    double bbb = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt) - confidenceMin;
383                     if ((1.0*fromTmp.cnt)/(1.0*toTmp.cnt) - confidenceMin >= 0.0)
384                     {
385                         cout << "Case " << Case++ << " : " << endl;
386                         for (auto j = 0; j < toTmp.item.size(); ++j)
387                         {
388                             cout << columnName[toTmp.item[j]];
389                             if (j != toTmp.item.size() - 1)
390                             {
391                                 cout << ",";
392                             }
393                         }
394                         cout << " => " << columnName[fromTmp.item[toTmp.item.size()]] << endl;
395                     }
396                 }
397             }
398         }
399         preL.clear();
400         preL = L;
401         L.clear();
402     }
403 }
404 
405 int main()
406 {
407 #ifdef HOME
408     freopen("in", "r", stdin);
409     freopen("out", "w", stdout);
410 #endif
411     cin >> supportMin >> confidenceMin;
412     Case = 0;
413     total = 0;
414     Input();
415     Ini();
416     Apriori();
417 
418 #ifdef HOME
419     cerr << "Time elapsed: " << clock() / CLOCKS_PER_SEC << " ms" << endl;
420 #endif
421     return 0;
422 }
原文地址:https://www.cnblogs.com/shijianming/p/4992610.html