[Python]实现简单决策树

基本思路:

  通过香农熵(信息增益)来决定每一层使用哪一个特征(即代码中的 label)做划分。划分后,若某个子集中的样本全部属于同一类别,则直接以该类别作为叶节点;若特征已经用完而类别仍不唯一,则通过多数表决法来决定该叶节点的类别。每次划分消耗一个特征,所以最多需要递归“特征个数”层。

 1 # -*- coding:utf-8 -*-
 2 import math
 3 import operator
 4 from collections import Counter
 5 
 6 def shannon_ent(dat):
 7   siz = len(dat)
 8   return 0.0 - reduce(lambda x, y: x + y,
 9     map(lambda each: float(each)/siz * math.log(float(each)/siz, 2),
10     Counter(map(lambda each: each[-1], dat)).values()))
11 
def split_dataset(dat, axis, val):
  """Return the rows of *dat* whose column *axis* equals *val*, minus that column.

  dat  -- sequence of rows (lists or tuples).
  axis -- non-negative index of the column to test and then drop.
  val  -- value the column must equal for a row to be kept.

  Always returns a real list (never a lazy iterator), so callers can take
  its length and index into it on both Python 2 and Python 3.  The input
  rows are not modified; each returned row is a new sequence.
  """
  return [each[:axis] + each[axis + 1:] for each in dat if each[axis] == val]
15 
def choose_best_feature(dat):
  """Return the index of the feature with the largest information gain.

  Each row of *dat* holds the feature values followed by the class label
  as its last element.  For every feature column the data set is
  partitioned by the distinct values of that column, and the gain is the
  entropy of *dat* minus the size-weighted entropy of the partitions.

  Returns -1 when *dat* is empty, has no feature columns, or when no
  feature yields a strictly positive gain; callers must handle that case
  before using the result as an index.  Ties keep the lowest index.
  """
  if not dat:
    return -1
  total = float(len(dat))
  feature_num = len(dat[0]) - 1  # last column is the class label
  base_ent = shannon_ent(dat)
  best_info_gain = 0.0
  best_feature = -1
  for i in range(feature_num):
    cur_ent = 0.0
    for val in set(each[i] for each in dat):
      # Split once per value and reuse the subset for both weight and entropy.
      subset = split_dataset(dat, i, val)
      cur_ent += len(subset) / total * shannon_ent(subset)
    info_gain = base_ent - cur_ent
    if info_gain > best_info_gain:
      best_info_gain, best_feature = info_gain, i
  return best_feature
30 
def majority_count(class_list):
  """Return the most frequent element of *class_list*.

  When several elements share the highest count, which of them is returned
  is decided by Counter.most_common and should not be relied upon.
  Raises IndexError if *class_list* is empty.
  """
  return Counter(class_list).most_common(1)[0][0]
34 
def create_tree(dat, label):
  """Recursively build a decision tree from *dat*.

  dat   -- non-empty sequence of rows; each row holds the feature values
           followed by the class label as its last element.
  label -- feature names, one per feature column and in the same order.
           This sequence is NOT modified.

  Returns either a class label (leaf) or a nested dict of the form
  {feature_name: {feature_value: subtree, ...}}.
  """
  class_list = [each[-1] for each in dat]
  # Every sample has the same class: this node is a pure leaf.
  if class_list.count(class_list[0]) == len(class_list):
    return class_list[0]
  # Only the class column is left: no feature to split on, so vote.
  if len(dat[0]) == 1:
    return majority_count(class_list)
  best_feature = choose_best_feature(dat)
  if best_feature < 0:
    # No feature gives a positive gain on its own (e.g. XOR-like data).
    # Using -1 as an index would select the class column, so split on the
    # first remaining feature instead; one feature is still consumed per
    # level, which keeps the recursion finite.
    best_feature = 0
  best_label = label[best_feature]
  # Build the reduced name list by slicing so the caller's list is untouched.
  sub_label = label[:best_feature] + label[best_feature + 1:]
  branches = {}
  for val in set(each[best_feature] for each in dat):
    branches[val] = create_tree(split_dataset(dat, best_feature, val), sub_label)
  return {best_label: branches}
52 
# Toy data set: each row is [no surfacing, flippers, class label].
d = [[1, 1, 'y'], [1, 1, 'y'], [1, 0, 'n'], [0, 1, 'n'], [0, 1, 'n']]
# Feature names, in the same order as the feature columns of ``d``.
l = ['no surfacing', 'flippers']

# Pass a copy of the names so the module-level list is still complete after
# the tree has been built.
print(create_tree(d, l[:]))
原文地址:https://www.cnblogs.com/kirai/p/6222832.html