c#抽取pdf文档标题(4)——机器学习以及决策树

        我的一位同事告诉我,pdf抽取标题,用机器学习可以完美解决问题,抽取的准确率比较高。于是,我看了一些资料,就动起手来,实践了下。

        我主要是根据以往历史块的特征生成一个决策树,然后利用这棵决策树,去判断一个新的块到底是不是标题。理论上,历史块的数量越庞大,那么结果越准确。其实经过实践不是这样的,我觉得影响结果判断的因素越少,而且库的数量达到一定数量后,判断越准确。这个记录块信息的历史库,就是供计算机学习的原料。

       首先看下,如何形成一个决策树?

 1  private static DecisionTreeID3<string> BuildTree()
 2         {
 3             //var blockList = Tools.SelectList("/config/Blocks/Block");
 4 
 5             var blockList = DBHelper.Select<BlockData>();
 6 
 7             string[,] da = new string[blockList.Count, 6];
 8 
 9             for (int i = 0; i < blockList.Count; i++)
10             {
11                 var index = blockList[i].Index;
12 
13                 if (index >= 1 && index <= 5)
14                 {
15                     da[i, 0] = "high";
16                 }
17                 else if (index >= 6 && index <= 12)
18                 {
19                     da[i, 0] = "middle";
20                 }
21                 else
22                 {
23                     da[i, 0] = "low";
24                 }
25                 var space = blockList[i].Space.ToString() == "非数字" ? 0 : (int)blockList[i].Space;
26 
27                 if (space >= 3 && space <= 10 || space >= 17 && space <= 20)
28                 {
29                     da[i, 1] = "high";
30                 }
31                 else if (space >= 11 && space <= 16)
32                 {
33                     da[i, 1] = "middle";
34                 }
35                 else
36                 {
37                     da[i, 1] = "low";
38                 }
39 
40                 var xSize = blockList[i].XSize;
41 
42                 if (xSize >= 11 && xSize <= 19 || xSize >= 400 && xSize <= 440 || xSize >= 250 && xSize <= 260)
43                 {
44                     da[i, 2] = "high";
45                 }
46                 else
47                 {
48                     da[i, 2] = "low";
49                 }
50 
51                 var ySize = blockList[i].YSize;
52 
53                 if (ySize >= 11 && ySize <= 19 || ySize >= 250 && ySize <= 290 || ySize >= 400 && ySize <= 440)
54                 {
55                     da[i, 3] = "high";
56                 }
57                 else
58                 {
59                     da[i, 3] = "low";
60                 }
61 
62                 var height = (int)blockList[i].Height;
63 
64                 if (height >= 6 && height <= 13 || height >= 22 && height <= 24)
65                 {
66                     da[i, 4] = "high";
67                 }
68                 else
69                 {
70                     da[i, 4] = "low";
71                 }
72                 da[i, 5] = blockList[i].IsTitle.ToString();
73             }
74 
75             var names = new string[] { "Index", "Space", "XSize", "YSize", "Height", "IsTitle" };
76             var tree = new DecisionTreeID3<string>(da, names, new string[] { "True", "False" });
77             tree.Learn();
78             return tree;
79         }

把数据库中的块信息,通过转换,变成二维数组,而且每个特征值被转为离散的值,之前的值是几乎连续的值,它有多少个,无法确定,转为离散的值,才能控制决策树的规模。下面,我们看看决策树类 DecisionTreeID3:

  1  public class DecisionTreeID3<T> where T : IEquatable<T>
  2     {
  3         T[,] Data;
  4         string[] Names;
  5         int Category;
  6         T[] CategoryLabels;
  7         public DecisionTreeNode<T> Root;
  8         public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
  9         {
 10             Data = data;
 11             Names = names;
 12             Category = data.GetLength(1) - 1;//类别变量需要放在最后一列
 13             CategoryLabels = categoryLabels;
 14         }
 15         public void Learn()
 16         {
 17             int nRows = Data.GetLength(0);
 18             int nCols = Data.GetLength(1);
 19             int[] rows = new int[nRows];
 20             int[] cols = new int[nCols];
 21             for (int i = 0; i < nRows; i++) rows[i] = i;
 22             for (int i = 0; i < nCols; i++) cols[i] = i;
 23             Root = new DecisionTreeNode<T>(-1, default(T));
 24             Learn(rows, cols, Root);
 25 
 26             DisplayNode(Root);
 27         }
 28 
 29         public bool Search(string[] test, DecisionTreeNode<T> Node = null)
 30         {
 31             bool isResult = false;
 32 
 33             if (Node == null) Node = Root;
 34 
 35             foreach (var item in Node.Children)
 36             {
 37                 var label = item.Label;
 38                 if (label < test.Length - 1 && test[label] != item.Value.ToString()) continue;
 39                 else
 40                 {
 41                     if (label == test.Length - 1 && item.Value.ToString() == "True")
 42                     {
 43                         isResult = true;
 44                         return isResult;
 45                     }
 46                     else
 47                     {
 48                         isResult = Search(test, item);
 49                     }
 50                 }
 51             }
 52             return isResult;
 53         }
 54 
 55         public StringBuilder sb = new StringBuilder();
 56 
 57         public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
 58         {
 59             if (Node.Label != -1)
 60             {
 61                 string nodeStr = string.Format("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
 62                 sb.AppendLine(nodeStr);
 63             }
 64             foreach (var item in Node.Children)
 65                 DisplayNode(item, depth + 1);
 66         }
 67         private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root, int depth = 0)
 68         {
 69             var categoryValues = GetAttribute(Data, Category, pnRows);
 70             var categoryCount = categoryValues.Distinct().Count();
 71             if (categoryCount == 1)
 72             {
 73                 var node = new DecisionTreeNode<T>(Category, categoryValues.First());
 74                 Root.Children.Add(node);
 75             }
 76             else
 77             {
 78                 if (depth > 10) return;
 79 
 80                 if (pnRows.Length == 0) return;
 81                 else if (pnCols.Length == 1)
 82                 {
 83                     //投票~
 84                     //多数票表决制
 85                     var Vote = categoryValues.GroupBy(i => i).OrderBy(i => i.Count()).First();
 86                     var node = new DecisionTreeNode<T>(Category, Vote.First());
 87                     Root.Children.Add(node);
 88                 }
 89                 else
 90                 {
 91                     //var maxCol = MaxEntropy(pnRows, pnCols);
 92 
 93                     //按c4.5算法
 94                     var maxCol = MaxEntropyRate(pnRows, pnCols);
 95 
 96                     var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
 97                     string currentPrefix = Names[maxCol];
 98                     foreach (var attr in attributes)
 99                     {
100                         int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
101                         int[] cols = pnCols.Where(i => i != maxCol).ToArray();
102                         var node = new DecisionTreeNode<T>(maxCol, attr);
103                         Root.Children.Add(node);
104                         Learn(rows, cols, node, depth + 1);//递归生成决策树
105                     }
106                 }
107             }
108         }
109         public double AttributeInfo(int attrCol, int[] pnRows)
110         {
111             var tuples = AttributeCount(attrCol, pnRows);
112             var sum = (double)pnRows.Length;
113             double Entropy = 0.0;
114             foreach (var tuple in tuples)
115             {
116                 int[] count = new int[CategoryLabels.Length];
117                 foreach (var irow in pnRows)
118                     if (Data[irow, attrCol].Equals(tuple.Item1))
119                     {
120                         int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
121                         count[index]++;//目前仅支持类别变量在最后一列
122                     }
123                 double k = 0.0;
124                 for (int i = 0; i < count.Length; i++)
125                 {
126                     double frequency = count[i] / (double)tuple.Item2;
127                     double t = -frequency * Log2(frequency);
128                     k += t;
129                 }
130                 double freq = tuple.Item2 / sum;
131                 Entropy += freq * k;
132             }
133             return Entropy;
134         }
135 
136         public double AttributeInfoRate(int attrCol, int[] pnRows)
137         {
138             var tuples = AttributeCount(attrCol, pnRows);
139             var sum = (double)pnRows.Length;
140             double SplitE = 0.0;
141 
142             foreach (var tuple in tuples)
143             {
144                 double frequency = tuple.Item2 / (double)sum;
145                 double t = -frequency * Log2(frequency);
146                 SplitE += t;
147             }
148             return SplitE;
149         }
150 
151         public double CategoryInfo(int[] pnRows)
152         {
153             var tuples = AttributeCount(Category, pnRows);
154             var sum = (double)pnRows.Length;
155             double Entropy = 0.0;
156             foreach (var tuple in tuples)
157             {
158                 double frequency = tuple.Item2 / sum;
159                 double t = -frequency * Log2(frequency);
160                 Entropy += t;
161             }
162             return Entropy;
163         }
164         private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
165         {
166             foreach (var irow in pnRows)
167                 yield return data[irow, col];
168         }
169         private static double Log2(double x)
170         {
171             return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
172         }
173         /// <summary>
174         /// 计算增益率
175         /// </summary>
176         /// <param name="pnRows"></param>
177         /// <param name="pnCols"></param>
178         /// <returns></returns>
179         public int MaxEntropy(int[] pnRows, int[] pnCols)
180         {
181             double cateEntropy = CategoryInfo(pnRows);
182             int maxAttr = 0;
183             double max = double.MinValue;
184             foreach (var icol in pnCols)
185                 if (icol != Category)
186                 {
187                     double Gain = cateEntropy - AttributeInfo(icol, pnRows);
188                     if (max < Gain)
189                     {
190                         max = Gain;
191                         maxAttr = icol;
192                     }
193                 }
194             return maxAttr;
195         }
196         /// <summary>
197         /// 计算增益率最大的属性
198         /// </summary>
199         /// <param name="pnRows"></param>
200         /// <param name="pnCols"></param>
201         /// <returns></returns>
202         public int MaxEntropyRate(int[] pnRows, int[] pnCols)
203         {
204             double cateEntropy = CategoryInfo(pnRows);
205             int maxAttr = 0;
206             double max = double.MinValue;
207             foreach (var icol in pnCols)
208                 if (icol != Category)
209                 {
210                     double Gain = cateEntropy - AttributeInfo(icol, pnRows);
211 
212                     double SplitE = AttributeInfoRate(icol, pnRows);
213 
214                     double GrainRation = Gain / SplitE;
215 
216                     if (max < GrainRation)
217                     {
218                         max = GrainRation;
219                         maxAttr = icol;
220                     }
221                 }
222             return maxAttr;
223         }
224 
225         public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
226         {
227             var tuples = from n in GetAttribute(Data, col, pnRows)
228                          group n by n into i
229                          select Tuple.Create(i.First(), i.Count());
230             return tuples;
231         }
232     }
233 
234     public sealed class DecisionTreeNode<T>
235     {
236         public int Label { get; set; }
237         public T Value { get; set; }
238         public List<DecisionTreeNode<T>> Children { get; set; }
239         public DecisionTreeNode(int label, T value)
240         {
241             Label = label;
242             Value = value;
243             Children = new List<DecisionTreeNode<T>>();
244         }
245     }

       这个类里面包含着两个算法,C4.5和ID3,C4.5是在ID3的基础上进行改进的一种算法。我采取了C4.5的算法,在94行。C4.5 算法,是用信息增益率来选择属性。ID3选择属性用的是子树的信息增益,这里可以用很多方法来定义信息,ID3使用的是熵(entropy, 熵是一种不纯度度量准则),也就是熵的变化值,而C4.5用的是信息增益率。  此处信息量比较大,可以参考 http://shiyanjun.cn/archives/428.html 这篇文章。

       决策树建好后,我们开始调用:

1            var tree = BuildTree();
2            //打印树
3             tree.sb.ToString();
4 
5             //用树来预测
6             var test = new string[] { "True", "False", "True", "False", "False", "" };
7           
8             bool isTitle = tree.Search(test);

第三行,是把树型结构输出来,最后两行是判断一个块信息是否是标题。这个数组当然也是数值转换为离散值后的结果。

有一点必须得明确,就是决策树得剪裁,否则有可能导致内存泄漏。决策类中的78行,如果树的层次结构超过了10层,就停止生长了。其实在规则过滤和决策树预测,我选择了规则过滤,因为用决策树的结果,经测试,准确率并不高,有可能是我才开始用,没有把握精髓,所以我保守选择。

原文地址:https://www.cnblogs.com/wangqiang3311/p/7743906.html