Apriori算法

Apriori算法

给出项集支持度的定义:数据集中包含该项集的数据的比例

置信度(A ightarrow B)的定义:(P(Bmid A) = frac{P(AB)}{P(A)}),其中(P(x))为项集(x)的支持度

算法流程:

先算出所有满足最小支持度的频繁项集,这个可以迭代来算,具体做法就是:

  1. 先找出所有的大小为(1)的候选项集列表(C_1)(C_1)即所有不同元素构成的列表)
  2. 在候选项集列表(C_k)中找出满足最小支持度的项集,加入到元素数量为(k)频繁项集列表(L_k)
  3. 根据(L_k)合并出下一个候选项集列表(C_{k+1})
  4. 重复步骤(2)

合并出下一个候选项集列表操作的算法原理就是:如果某个项集不是频繁项集,那么这个项集的所有超集都不是频繁项集。这样做剪枝可以去掉很多没用的状态

然后根据这些频繁项集来计算满足最小置信度的规则

这里的做法是:
先枚举每一个频繁项集
对于一个频繁项集会有很多可能满足最小置信度的规则,具体来说,如果一个频繁项集存在(n)个元素,那么存在(2^n-2)个规则(所有的组合再去掉全集和空集)
但是可以发现如果一个规则(A ightarrow B)没有达到最小置信度,那么规则((A-sub) ightarrow (B|sub))也达不到最小置信度(其中(sub)(A)的子集,(|)表示集合并,(-)表示集合差),也就是(A)的所有子集都不满足条件了
那么就可以和找频繁项集一样剪枝掉很多状态了
对于每一个频繁项集,先把所有单个的元素作为规则后件集合形成一个规则后件集合(list),然后计算每个规则是否满足条件
把满足条件的前件集合放到(list)中,和找频繁项集中的操作(3)一样找出合法的前件集合,再通过集合相减找出合法的后件集合,递归去做上一个操作就好了

下面是代码

view code


#coding:utf-8


# generate data
def genData():
    return [['牛奶','啤酒','尿布'],
    ['牛奶','面包','黄油'],
    ['牛奶','尿布','饼干'],
    ['面包','黄油','饼干'],
    ['啤酒','尿布','饼干'],
    ['牛奶','尿布','面包','黄油'],
    ['尿布','面包','黄油'],
    ['啤酒','尿布'],
    ['牛奶','尿布','面包','黄油'],
    ['啤酒','饼干'] ]

def loadDataSet():
    return [[1, 3, 4], [2, 3, 5], [1, 2, 3, 5], [2, 5]]


# 传入参数:数据集
# 返回值:候选项集C1
def genC1(datalist)->[frozenset]:
    goodsset = set()
    for items in datalist:
        for goods in items:
            goodsset.add(goods)
    C = list()
    for goods in goodsset:
        C.append(frozenset([goods]))
    return C



# 传入参数:频繁项集list
# 返回:下一个候选项集->list(frozenset)
def mergeToNext(preL):
    if len(preL) == 0:
        return []
    Ck = list()
    k = len(preL[0])
    for i in range(len(preL)):
        for j in  range(i+1,len(preL)):
            A = sorted([x for x in preL[i]])[:k-1]
            B = sorted([x for x in preL[j]])[:k-1]
            if A == B:
                Ck.append(preL[i] | preL[j])
    return Ck

# 传入参数:数据集,候选项集,最小支持度
# 返回值:频繁项集->list(frozenset),频繁项集支持度->dict
def genfreq(dataset, preC, minsupport):
    objfreq = dict()
    L = list()
    for item in preC:
        __appcnt = 0
        for data in dataset:
            if (item&data) == item:
                __appcnt += 1
        if __appcnt / len(dataset) >= minsupport:
            L.append(item)
            objfreq[item] = __appcnt / len(dataset)
    return L, objfreq



# 传入参数:频繁项,规则后集,支持度集合,规则集合,最小置信度
# 无返回值
def GetRules(freqset, R, suppotdata, rulelist, minconf):
    if len(R)==0 or len(R[0])==len(freqset):
        return
    legalconseq = list()
    for ret in R:
        # P(A|B) = P(AB) / P(B)
        conseq = freqset - ret
        conf = supportdata[freqset] / supportdata[conseq]
        if conf >= minconf:
            rulelist.append([conseq,ret,conf])
            legalconseq.append(conseq)
    nextconseqlist = mergeToNext(legalconseq)
    nextR = list()
    for conseq in nextconseqlist:
        nextR.append(freqset-conseq)
    if len(nextR)==0 or len(nextR[0])==0:
        return
    GetRules(freqset,nextR,supportdata,rulelist,minconf)

# 传入参数:各长度频繁项集,频繁项集支持度,最小置信度
# 返回值:规则列表以及置信度
def genRules(Llist, supportdata, minconf = .5):
    rulelist = list()
    for i in range(1,len(Llist)):
        L = Llist[i]
        if len(L) == 0:
            break
        for freqset in L:
            R = [frozenset([x]) for x in freqset]
            GetRules(freqset,R,supportdata,rulelist,minconf)
    return rulelist

# 传入参数:数据集,最小支持度
# 返回值:各长度频繁项集->list(list(frozenset)),频繁项集支持度->dist
def apriori(datalist, minsupport = .5):
    # C1 -> L1 ---merge---> C2
    dataset = list(map(frozenset,[x for x in datalist]))
    supportdata = dict()
    Llist = list()
    C = genC1(dataset)
    while len(C) != 0:
        L, tmpfreq = genfreq(dataset,C,minsupport)
        Llist.append(L)
        supportdata.update(tmpfreq)
        C = mergeToNext(Llist[-1])
    return Llist, supportdata


if __name__ == "__main__":
    # datalist = genData()
    datalist = loadDataSet()
    Llist, supportdata = apriori(datalist)
    rulelist = genRules(Llist,supportdata)
    # for L in Llist:
    #     for p in L:
    #         print(p,supportdata[p])
    for rule in rulelist:
        print(rule[0],'->',rule[1],'conf = ',rule[2])

处理Online_Retail数据的代码
只处理(France)部分

view code
import pandas as pd

# !/usr/bin/python
# coding:utf-8
# author: kiko


# 传入参数:数据集
# 返回值:候选项集C1
def genC1(datalist)->[frozenset]:
    goodsset = set()
    for items in datalist:
        for goods in items:
            goodsset.add(goods)
    C = list()
    for goods in goodsset:
        C.append(frozenset([goods]))
    return C



# 传入参数:频繁项集list
# 返回:下一个候选项集->list(frozenset)
def mergeToNext(preL):
    if len(preL) == 0:
        return []
    Ck = list()
    k = len(preL[0])
    for i in range(len(preL)):
        for j in  range(i+1,len(preL)):
            A = sorted([x for x in preL[i]])[:k-1]
            B = sorted([x for x in preL[j]])[:k-1]
            if A == B:
                Ck.append(preL[i] | preL[j])
    return Ck

# 传入参数:数据集,候选项集,最小支持度
# 返回值:频繁项集->list(frozenset),频繁项集支持度->dict
def genfreq(dataset, preC, minsupport):
    objfreq = dict()
    L = list()
    for item in preC:
        __appcnt = 0
        for data in dataset:
            if (item&data) == item:
                __appcnt += 1
        if __appcnt / len(dataset) >= minsupport:
            L.append(item)
            objfreq[item] = __appcnt / len(dataset)
    return L, objfreq



# 传入参数:频繁项,规则后集,支持度集合,规则集合,最小置信度
# 无返回值
def GetRules(freqset, R, suppotdata, rulelist, minconf):
    if len(R)==0 or len(R[0])==len(freqset):
        return
    legalconseq = list()
    for ret in R:
        # P(A|B) = P(AB) / P(B)
        conseq = freqset - ret
        conf = supportdata[freqset] / supportdata[conseq]
        if conf >= minconf:
            rulelist.append([conseq,ret,conf])
            legalconseq.append(conseq)
    nextconseqlist = mergeToNext(legalconseq)
    nextR = list()
    for conseq in nextconseqlist:
        nextR.append(freqset-conseq)
    if len(nextR)==0 or len(nextR[0])==0:
        return
    GetRules(freqset,nextR,supportdata,rulelist,minconf)

# 传入参数:各长度频繁项集,频繁项集支持度,最小置信度
# 返回值:规则列表以及置信度
def genRules(Llist, supportdata, minconf = .5):
    rulelist = list()
    for i in range(1,len(Llist)):
        L = Llist[i]
        if len(L) == 0:
            break
        for freqset in L:
            R = [frozenset([x]) for x in freqset]
            GetRules(freqset,R,supportdata,rulelist,minconf)
    return rulelist

# 传入参数:数据集,最小支持度
# 返回值:各长度频繁项集->list(list(frozenset)),频繁项集支持度->dist
def apriori(datalist, minsupport = .5):
    # C1 -> L1 ---merge---> C2
    dataset = list(map(frozenset,[x for x in datalist]))
    supportdata = dict()
    Llist = list()
    C = genC1(dataset)
    while len(C) != 0:
        L, tmpfreq = genfreq(dataset,C,minsupport)
        Llist.append(L)
        supportdata.update(tmpfreq)
        C = mergeToNext(Llist[-1])
    return Llist, supportdata


def encode_units(x):
    if x <= 0:
        return 0
    if x >= 1:
        return 1

def test():
    datalist = loadDataSet()
    datalist = genData()
    Llist, supportdata = apriori(datalist)
    rulelist = genRules(Llist,supportdata)
    for L in Llist:
        for p in L:
            print(p,supportdata[p])
    for rule in rulelist:
        print(rule[0],'->',rule[1],'conf = ',rule[2])


def getFranceData():
    df = pd.read_excel('xxx/Online_Retail.xlsx')
    df = df[df['Country'].str.contains('France')]
    df.to_excel('xxx/FranceData.xlsx')


if __name__ == "__main__":
    # getFranceData()
    df = pd.read_excel('xxx/FranceData.xlsx')
    # print('Is all France' if len(df[df['Country'].str.contains('France')])==df.shape[0] else 'Not all France')
    df['Description'] = df['Description'].str.strip()
    df['InvoiceNo'] = df['InvoiceNo'].astype('str')
    df = df[~df['InvoiceNo'].str.contains('C')]
    #显示所有列
    pd.set_option('display.max_columns', None)
    #显示所有行
    pd.set_option('display.max_rows', None)
    basket = (df.groupby(['InvoiceNo', 'Description'])['Quantity'].sum()
         .unstack().reset_index().fillna(0)
         .set_index('InvoiceNo'))
    basket_sets = basket.applymap(encode_units)
    basket_sets.drop('POSTAGE', inplace=True, axis=1)
    datalist = list()
    for ind in basket_sets.index:
        singledata = list()
        for col in basket_sets.columns:
            if basket_sets.loc[ind,col] == 1:
                singledata.append(col)
        datalist.append(singledata)
    
    Llist, supportdata = apriori(datalist,minsupport = .07)
    rulelist = genRules(Llist,supportdata,minconf = .7)
    for rule in rulelist:
        print(rule[0],'->',rule[1],'conf = ',rule[2])
原文地址:https://www.cnblogs.com/kikokiko/p/13996120.html