词频统计

统计词频

需求分析

本次作业需要完成的是一个词频统计程序。  

需求分析:

  1. 统计文件的字符数(只需要统计Ascii码,汉字不用考虑)
  2. 统计文件的单词总数
  3. 统计文件的总行数(任何字符构成的行,都需要统计)
  4. 统计文件中各单词的出现次数,输出频率最高的10个。
  5. 对给定文件夹及其递归子文件夹下的所有文件进行统计
  6. 统计两个单词(词组)在一起的频率,输出频率最高的前10个。
  7. 在Linux系统下,进行性能分析,过程写到blog中(附加题)
StatusStages预估耗时实际耗时
Accepted 计划 Planning 10分钟 10分钟
Accepted 需求分析 Analysis 20分钟 20分钟
Accepted 具体设计 Design 30分钟 60分钟
Accepted 具体编码 Code 2个小时 4个小时
Accepted 测试 Test 1个小时 2个小时
Accepted 写报告 Report 1个小时 1个小时
总计   5个小时 约9个小时

具体设计

  • 扩展字符串的类,使得字符串可以在忽略大小写和后缀数字的情况下进行比较
  • 使用哈希表进行统计
  • 寻找词频前十的单词时,维护一个10个大小的“榜单”,线性遍历哈希表,用其中的每一个元素来更新“榜单”,时间复杂度O(N)
  • 读取文件采用缓冲区流式读取
  • 设计状态机对字符和单词进行提取

代码实现过程

  • 工具类的构建

    • string_plus类:比较两个字符串是否相等时忽略大小写和后缀数字,但仍要记录原字符串
    • string_pair类:两个string_plus放在一起,新增hash字段
    • hash_table类:哈希表,在查询并更新数据时,将字典序较小的string_plus字符串留在原地
  • 状态机的构造

    • char_consumer类:输入字符流,产生合法的单词流,并在流处理过程中完成行数和字符数的统计
    • word_consumer类:输入单词流,统计单词总数和词频、词组频
  • 输入输出

    • 文件输入与输出
    • 文件夹扫描

性能分析

  • 使用拉链法处理哈希表冲突,在哈希表足够大的情况下,查询字符串的速度接近O(1)
  • 使用modified-insertion-sort的方法来寻找前10大的元素,时间复杂度为O(n)
  • 哈希函数计算简单,且不易冲突
  • 大部分函数为内联函数

在WSL上,使用g++ -O2进行编译,运行测试集的时间约9s;在Windows下,使用VS 2015进行编译,运行测试集的时间约16s。

优化报告

最开始的版本运行测试集的时间在40-60s,对此我进行了下列优化:

  • 将std::map替换成自己写的hash_table。不仅将查询和更新的复杂度由O(lgn)降低为O(1),而且解决了std::map的key不可修改的问题。std::map的设计者为了保证二叉搜索树的序结构能够持久维持,禁止调用者修改key的值,但实际上,我们不能修改的,仅仅是key之间的序关系。在本问题中,Good和GOOD123是不同的key,但是在key的比较函数中被视为相同。我需要一种能更新key但是不更改key的序的方法,而std::map并未提供,所以只能将原key删除,再插入新key,这样造成了时间浪费。而自己写hash_table就解决了这个问题。

  • 优化了字符读取

    原来的读取函数是getline,现在换为read()

  • 将部分string改为char[]

经过这几步优化,代码性能得到大幅提高

VS性能分析

使用visual studio的分析工具进行性能分析

 

Linux下的性能分析

使用gprof进行性能分析

g++ stat.cpp -g -pg
./a.out Test/
gprof -p

结果如下

  %   cumulative   self              self     total           
 time   seconds   seconds    calls  Ts/call  Ts/call  name    
  0.00      0.00     0.00 183761420     0.00     0.00  char_consumer_::consume(int)
  0.00      0.00     0.00 89113619     0.00     0.00  string_plus::~string_plus()
  0.00      0.00     0.00 66559020     0.00     0.00  string_plus& std::forward<string_plus&>(std::remove_reference<string_plus&>::type&)
  0.00      0.00     0.00 63047280     0.00     0.00  __gnu_cxx::__aligned_membuf<std::pair<string_pair, int> >::_M_ptr()
  0.00      0.00     0.00 63047280     0.00     0.00  __gnu_cxx::__aligned_membuf<std::pair<string_pair, int> >::_M_addr()
  0.00      0.00     0.00 63047280     0.00     0.00  std::_List_node<std::pair<string_pair, int> >::_M_valptr()
  0.00      0.00     0.00 60172120     0.00     0.00  std::_List_iterator<std::pair<string_pair, int> >::operator->() const
  0.00      0.00     0.00 58217474     0.00     0.00  string_plus::operator==(string_plus const&) const
  0.00      0.00     0.00 58217474     0.00     0.00  __gnu_cxx::__enable_if<std::__is_char<char>::__value, bool>::__type std::operator==<char>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
  0.00      0.00     0.00 51378113     0.00     0.00  __gnu_cxx::__aligned_membuf<std::pair<string_plus, int> >::_M_ptr()

由以上两个性能调优工具可得以下结论

  • 核心的consume函数调用次数很多,应该尽量优化(虽然我认为这两个函数很简单,性能已经快到极致了)
  • 同步IO造成了大量的时间消耗,有时间的话,可以写一个多线程异步IO的版本
  • 与string相关的函数调用次数也很多,特别是string_plus的==函数,这个函数默认实现是直接比较两个字符串。然而,发现其是性能瓶颈后,可以对它进行优化。先判断两个字符串hash值是否相等,若不相等,直接返回false。
  • 与list相关的函数调用次数也不少,说明hash table的大小还不够,或者碰撞次数太多,需要增大表的大小或修改hash函数。

代码质量分析

代码使用C++进行编写,使用了较多std标准库,代码清晰易懂,易维护,不使用影响可读性的优化。

代码架构清晰,使用了单例模式和有限状态机的模型,使得代码直接和解对应。

变量名清晰易懂,没有魔法数。

消除了g++,clang++,msvc的warning。

代码可移植性好,在WSL和Windows上均可运行(在Windows上编译需导入一个头文件)

静态分析与测试

静态分析结果

clang

stats.cpp:307:3: warning: Potential leak of memory pointed to by 'line_buf'
  char_consumer.consume(0);
  ^~~~~~~~~~~~~
1 warning generated.

可能出现的内存泄漏。(已重构)

VS

"fout"可能是"0": 这不符合函数"fprintf"的规范。

使用了assert。

assertion failed!
    
if (ch >= 32 && ch <= 126) stats.characters++;
else ch=0;

在ch不是8字节数时将其赋值为0,避免在调用isascii时触发assert。

Debug和测试记录

测试数据:助教的测试数据集、空文件、几份代码文件、含Unicode的文章等

构造的测试文件节选

test
test123
Test
asfe sg sag sdag gd fs sd d ss sddf fd d d d d   d d gsf as fa sdf 

Debug

  • 统计字符数量、词频和行数出错

    • bug原因:状态管理不当,例如在文件结束时没有给char_consumer和word_consumer信号,使得它们把两个文件连在了一起;比较函数出错,比较两个字符串的前缀时,应该先将字符toupper;
  • 哈希表开大了之后,导致程序段错误

    • bug原因:在某个函数中将hash_table作为参数传入,导致栈溢出。

经验

  • 少用奇技淫巧,多写易维护的代码
  • 代码和变量命名上尽量在算法级描述问题的解决方式,而非在具体实现级别
  • 学会使用静态分析工具帮助debug
  • 可以使用strace来跟踪系统调用
  • 对于c++这种需要程序员自己管理内存的语言,尤其要注意内存安全
  • 常常记录bug,分析自己犯错的模式,并加以改正

附录:代码

/**
* Author: Nicekingwei
* Date: 2018/3/24
1. 统计文件的字符数(只需要统计Ascii码,汉字不用考虑)
2. 统计文件的单词总数
3. 统计文件的总行数(任何字符构成的行,都需要统计)
4. 统计文件中各单词的出现次数,输出频率最高的10个。
5. 对给定文件夹及其递归子文件夹下的所有文件进行统计
6. 统计两个单词(词组)在一起的频率,输出频率最高的前10个。
7. 在Linux系统下,进行性能分析,过程写到blog中(附加题) */

#include <iostream>
#include <fstream>
#include <ctype.h>
#include <string>
#include <list>
#include <assert.h>
#include <dirent.h>
#include <cstring>
#include <limits.h>
#define EOF_CHAR INT_MAX


using namespace std;
typedef unsigned long long u64;
const size_t buffer_size = 1024 * 128;
const size_t hash_size = 1024 * 1024;


/*
 * hash table
 */
template<typename K>
struct hash_map {
    list <pair<K, size_t> > table[hash_size];

    /// increase the number of a key with default count 0
    void increase(const K &key) {
      size_t h = key.hash % hash_size;
      for (typename list<pair<K, size_t> >::iterator it = table[h].begin(); it != table[h].end(); it++) {
        if (it->first == key) {
          it->second++;
          return;
        }
      }
      table[h].push_back(make_pair(key, 1));
    }

    /// increase the number of a key, and update that key without changing hash code
    void increase_update(const K &key) {
      size_t h = key.hash % hash_size;
      for (typename list<pair<K, size_t> >::iterator it = table[h].begin(); it != table[h].end(); it++) {
        if (it->first == key) {
          if (key.compare(it->first)) {
              // update
              it->first = key;
          }
          it->second++;
          return;
        }
      }
      table[h].push_back(make_pair(key, 1));
    }

    /// find a key in hash table
    pair<K, size_t> find(const K &key) {
      size_t h = key.hash % hash_size;
      for (typename list<pair<K, size_t> >::iterator it = table[h].begin(); it != table[h].end(); it++) {
        if (it->first == key) {
          return *it;
        }
      }
      return make_pair(K(), -1);
    }
};

/*
 * string_plus
 * add refinment string and hash code to a string
 */
struct string_plus {
    string str;
    string cmp_str;
    u64 hash;

    string_plus() {}

    string_plus(const char *s) {
      hash = 0;
      str = s;
      size_t size = str.size();
      cmp_str.reserve(size);
      
      // shrink
      while (isdigit(s[size - 1])) size--;

      // hash
      for (size_t i = 0; i < size; i++) {
        char ch = (s[i] >= 'a' && s[i] <= 'z') ? s[i] - 'a' + 'A' : s[i];
        cmp_str.push_back(ch);
        hash = (hash * 147 + ch);
      }
    }

    inline bool compare(const string_plus &s) const {
      return str < s.str;
    }

    inline bool operator==(const string_plus &x) const {
      if(hash!=x.hash) return false;
      return cmp_str == x.cmp_str;
    }
};

struct string_pair : public pair<string_plus, string_plus> {
    u64 hash;

    string_pair() { hash = 0; }

    string_pair(const pair<string_plus, string_plus> &p) {
      first = p.first;
      second = p.second;
      hash = p.first.hash * p.second.hash;
    }

};

struct stats_ {
    u64 characters;
    u64 words;
    u64 lines;

    hash_map<string_plus> words_count_map;
    hash_map<string_pair> phrase_count_map;

    stats_() {
      characters = words = lines = 0;
    }

} stats;

struct word_consumer_ {

    string_plus last_ref;

    // consume a word
    void consume(const char *s) {
      string_plus ref(s);
      stats.words++;
      
      // word
      stats.words_count_map.increase_update(ref);
      
      // phrase
      if (!last_ref.str.empty())
        stats.phrase_count_map.increase(make_pair(last_ref, ref));
      
      last_ref = ref;
    }

} word_consumer;

struct char_consumer_ {

    enum {
        S_WHITE, S_ALPHA, S_NUMBER
    } state;

    char word_buf[buffer_size];
    int word_len;
    int prev;

    // judge if a word is valid
    inline bool is_valid_word() {
      if (word_len < 4) return false;
      for (size_t i = 0; i < 4; i++)
        if (!isalpha(word_buf[i]))
          return false;
      return true;
    }

    // consume a char
    inline void consume(int ch) {
      
      // count lines
      if(ch!=EOF_CHAR) {
        if( (prev=='\n') || (prev==EOF_CHAR) ) {
          stats.lines++;
        }
      } 

      // record
      prev = ch;

      // count chars
      if (ch >= 32 && ch <= 126) {
        stats.characters++;
      } else {
        ch = 0;
      }

      assert(word_len <= buffer_size);

      // dfa
      if (isalpha(ch)) {
        switch (state) {
          case S_WHITE:
            // start a new word
            word_len = 1;
            word_buf[0] = ch;
            state = S_ALPHA;
            break;
          case S_ALPHA:
            word_buf[word_len++] = ch;
            break;
          default:
            break;
        }
      } else if (isdigit(ch)) {
        switch (state) {
          case S_ALPHA:
            word_buf[word_len++] = ch;
            state = S_ALPHA;
            break;
          default:
            state = S_NUMBER;
            break;
        }
      } else {
        if (is_valid_word()) {
          word_buf[word_len++] = 0;
          word_consumer.consume(word_buf);
          word_len = 0;
        }
        state = S_WHITE;
      }
    }
} char_consumer;


template<typename T, int N>
struct top {
    pair<T, size_t> res[N];

    // find the top N items in hash table
    top(hash_map<T>& col) {
      for (size_t i = 0; i < N; i++) res[i].second = 0;

      // for each item
      for (size_t i = 0; i < hash_size; i++) {
        list <pair<T, size_t> > &slot = col.table[i];
        for (typename list<pair<T, size_t> >::iterator it = slot.begin(); it != slot.end(); it++) {
          // update the top N records
          size_t pos = N;
          while (pos >= 1 && it->second > res[pos-1].second) pos--;
          for (size_t i = N - 1; i > pos; i--) {
              res[i] = res[i - 1];
          }
          // important condition to avoid index out of bound 
          if (pos <= N - 1) res[pos] = *it;
        }
      }
    }

};

/// buffer of read
char read_buf[buffer_size];

void readfile(string filename) {

#ifdef DEBUG
  cout << filename << "\n";
#endif

  // important clear-up
  word_consumer.last_ref = "";
  char_consumer.prev = EOF_CHAR;

  ifstream fin(filename.c_str());
  while (!fin.eof()) {
    fin.read(read_buf, buffer_size);
    long long n = fin.gcount();
    for (long long i = 0; i < n; i++) char_consumer.consume(read_buf[i]);
  }
  char_consumer.consume(EOF_CHAR);
  fin.close();
}

/*
 * write back the result
 */
void write_result() {
#ifdef DEBUG
  cout<<"write\n";
#endif
  FILE *fout;
  fout = fopen("result.txt", "w+");
  assert(fout != NULL);
  fprintf(fout, "char_number :%llu\n", stats.characters);
  fprintf(fout, "line_number :%llu\n", stats.lines);
  fprintf(fout, "word_number :%llu\n", stats.words);

  // get the top 10  
  top<string_plus, 10> top_words(stats.words_count_map);
  top<string_pair, 10> top_phrase(stats.phrase_count_map);
  
  fprintf(fout, "\nthe top ten frequency of word :\n");
  for (size_t i = 0; i < 10; i++) {
    string_plus s1 = top_words.res[i].first;
    if (!s1.str.empty())
      fprintf(fout, "%s\t%d\n", s1.str.c_str(), top_words.res[i].second);
  }

  // phrase
  fprintf(fout, "\n\nthe top ten frequency of phrase :\n");
  for (size_t i = 0; i < 10; i++) {
    string_pair p = top_phrase.res[i].first;
    if (p.first.str.empty()) continue;
    pair<string_plus, int> ref1 = stats.words_count_map.find(p.first);
    pair<string_plus, int> ref2 = stats.words_count_map.find(p.second);
    assert(ref1.second != -1 && ref2.second != -1);
    fprintf(fout, "%s %s\t%d\n", ref1.first.str.c_str(), ref2.first.str.c_str(), top_phrase.res[i].second);
  }

  fclose(fout);
}

/*
 * search the directory recursively
 */
void search(string file) {
  DIR *dir = opendir(file.c_str());
  dirent *info;

  if (dir) {
    string dir_path = file + "/*.*";
    while ((info = readdir(dir)) != NULL) {
      // ignore '.' and '..'
      if ((info->d_name[0] == '.') &&
          ((info->d_name[1] == 0) || (info->d_name[1] == '.' && info->d_name[2] == 0)))
        continue;
      search(string().assign(file).append("/").append(info->d_name));
    }
    closedir(dir);
  } else {
    readfile(file);
  }
}

int main(int argc, const char *argv[]) {
  if (argc <= 1) {
    printf("fatal error: no such file or directory\n");
    return -1;
  }
  string arg = argv[1];
  search(arg);
  write_result();
  return 0;
}
原文地址:https://www.cnblogs.com/nicekingwei/p/8658863.html