FP-Growth in Spark MLLib

并行FP-Growth算法思路

FG-Tree example

上图的单线程形成的FP-Tree。

分布式算法事实上是对FP-Tree进行分割,分而治之

首先,假设我们只关心...|c这个conditional transaction,那么可以把每个transaction中的...|c保留,并发送到一个计算节点中,必然能在该计算节点构造出FG-Tree

root
 |   
f:3  c:1
 |
c:3

进而得到频繁集(f,c)->3.

同样,如果把所有transaction中的...|b保留,并发送到一个计算节点中,必然能在该几点构造出FG-tree

  root
  /     
f:2     c:1
 |         
c:1   b:1    b:1
 |
a:1
 |
b:1

进而得到(b)->3。

以上两个例子得到了两个tree,并且分别得到了部分结果。
事实上算法的思路就是把transaction的conditional transaction进行分割(分组),分割的依据就是conditional transaction的suffix(的hash,如果直接是suffix也可以,但是使得spark任务有过多task)。对每个分组分别构建FP-tree,然后在每个子树中获得部分结果,合并得到最终结果。

Spark Mllib中算法

遍历一次数据集输出F-List,类似wordcount,得出频繁出现的items,将F-List划分为G-List,即将频繁items进行分组:

  • F-List包含item全集I中的频繁item,F-List={f_1,...},f_i在Transaction中出现的频率>support阈值。
  • G-List={g_1,...}, g_i=hash_of(f_i)=H(f_i).

    实际上,计算的f_i hash值作为partition_id,在MLLib过程中将conditional transaction f'1,f'2,...|f_i 分发到partition_id=H(f_i)对应的计算节点。
// data即所有的transaction,每个trans是Item数组
def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
    // 计算support阈值
    val count = data.count()
    val minCount = math.ceil(minSupport * count).toLong
    val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
    val partitioner = new HashPartitioner(numParts)
    // 第一次遍历,统计frequency,过滤掉低于support阈值的item
    val freqItems:Array[Item] = genFreqItems(data, minCount, partitioner)
    // 第二次遍历
    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
    new FPGrowthModel(freqItemsets)
  }

各组构建FP-tree

再次遍历数据集,每个trans中的items按照frequency进行降序排列,并构造conditional transactions,例如一个trans={a,b,c,d,e},a的frequency最高,以此降低,构造其相应的conditional transactions:

a,b,c,d,e:
condition trans  ;  partition_id
      a,b,c,d|e  ;   partition = H(e)
      a,b,c|d    ;   partition = H(d)
      a,b|c      ;   partition = H(c)
      a|b        ;   partition = H(b)
      a          ;   partition = H(a)

对应code在genFreqItemsets.

  • 每个transaction的conditional transaction,并且按照suffix计算hash作为partition_id分组
  • 各个partition_id对应的[condition items]所有集合,即G-List,对G-List的agg即为构造FP-Tree过程
  • 在各个part中提取该part包含的频繁集。在part子树中,node x,若hash(x)=part_id,并且x到root路径能形成频繁集,则输出path(x->root)中的各个节点作为频繁集。
  • 将rank转为对应的item
private def genFreqItemsets[Item: ClassTag](
      data: RDD[Array[Item]],  // transactions
      minCount: Long,          // support threshold
      freqItems: Array[Item],  // FP-List
      partitioner: Partitioner): RDD[FreqItemset[Item]] = {
   // freqItems已经排序了,zip出每个Item的rank
    val itemToRank = freqItems.zipWithIndex.toMap
    // 形成partition_id->[condition items]
    data.flatMap { transaction =>
      // 计算conditional transactions
      genCondTransactions(transaction, itemToRank, partitioner)
    } 
    // 各个partition_id对应的[condition items]所有集合,即G-List,
    // 对G-List的agg即为构造FP-Tree过程
    .aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
      (tree, transaction) => tree.add(transaction, 1L),
      (tree1, tree2) => tree1.merge(tree2)) 
    // 在各个part中提取该part包含的频繁集
    .flatMap { case (part, tree) =>
      tree.extract(minCount, x => partitioner.getPartition(x) == part)
    } 
    // 将rank转为对应的item
    .map { case (ranks, count) =>
      new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
    }
  }

计算conditional transactions

  • itemToRank,rank越小对应的frequency是越大的
  • 每个trans中筛出frequent Item,并对rank排序,得到的item即按照frequency由大到小排序
    • FP_list={a,b,c,d,e,f}
    • 一个trans=[f,e,d,a,c], 那么将得到[0,2,3,4,5]
  • 构造conditional transaction
    • 例如0,2|3 计算3的partition_id(3), 形成partition_id(3)->[0,2,3]
private def genCondTransactions[Item: ClassTag](
      transaction: Array[Item],
      itemToRank: Map[Item, Int],
      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
    val output = mutable.Map.empty[Int, Array[Int]]
    // Filter the basket by frequent items pattern and sort their ranks.
    val filtered = transaction.flatMap(itemToRank.get)
    ju.Arrays.sort(filtered)
    val n = filtered.length
    var i = n - 1
    while (i >= 0) {
      val item = filtered(i)
      val part = partitioner.getPartition(item)
      if (!output.contains(part)) {
        output(part) = filtered.slice(0, i + 1)
      }
      i -= 1
    }
    output
  }
原文地址:https://www.cnblogs.com/luweiseu/p/7768838.html