Spark aggregate operator

Spark aggregate source code (from the Java API in JavaRDDLike.scala, which delegates to RDD.aggregate):

  /**
   * Aggregate the elements of each partition, and then the results for all the partitions, using
   * given combine functions and a neutral "zero value". This function can return a different result
   * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
   * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
   * allowed to modify and return their first argument instead of creating a new U to avoid memory
   * allocation.
   */
  def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
    combOp: JFunction2[U, U, U]): U =
    rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
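The Scaladoc above says the result type U may differ from the element type T, and that both functions may modify and return their first argument instead of allocating a new U. The following is a minimal plain-Java sketch of that contract, without Spark. The class MutableZeroSketch, the helpers aggregate and sumAndCount, and the long[]{sum, count} accumulator are all hypothetical. Each partition gets its own fresh copy of the zero value from a Supplier, roughly mirroring the way Spark ships a separate serialized copy of zeroValue to each task.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;

public class MutableZeroSketch {
    // Hypothetical plain-Java model of aggregate (no Spark): every partition folds
    // from its own fresh zero value, then partition results are merged with combOp.
    static <T, U> U aggregate(List<List<T>> partitions, Supplier<U> zero,
                              BiFunction<U, T, U> seqOp, BinaryOperator<U> combOp) {
        U result = zero.get();
        for (List<T> partition : partitions) {
            U acc = zero.get();
            for (T x : partition) {
                acc = seqOp.apply(acc, x);
            }
            result = combOp.apply(result, acc);
        }
        return result;
    }

    // U = long[]{sum, count} while T = Integer: the result type differs from the element type.
    static long[] sumAndCount(List<List<Integer>> partitions) {
        Supplier<long[]> zero = () -> new long[]{0, 0};
        // Both functions update their first argument in place and return it,
        // which the Scaladoc allows in order to avoid allocating a new U.
        BiFunction<long[], Integer, long[]> seqOp = (acc, x) -> {
            acc[0] += x;
            acc[1] += 1;
            return acc;
        };
        BinaryOperator<long[]> combOp = (a, b) -> {
            a[0] += b[0];
            a[1] += b[1];
            return a;
        };
        return aggregate(partitions, zero, seqOp, combOp);
    }

    public static void main(String[] args) {
        List<List<Integer>> partitions = Arrays.asList(
                Arrays.asList(2, 3, 2), Arrays.asList(5, 2, 6));
        long[] r = sumAndCount(partitions);
        System.out.println("sum=" + r[0] + ", count=" + r[1] + ", avg=" + (double) r[0] / r[1]);
    }
}
```

With the article's six numbers this gives a sum of 20 and a count of 6 however the list is split, because {0, 0} is neutral for both functions.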

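aggregate combines the elements of an RDD in two steps. First, seqOp folds the elements of type T in each partition into a value of type U. Then combOp merges the per-partition U values into a single U.
Note that zeroValue, which has type U, is used by both functions. It is the starting value of the seqOp fold in every partition, and it is also the starting value of combOp when the partition results are merged.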
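To make the zeroValue bookkeeping concrete, here is a minimal plain-Java model (no Spark) of the description above, using the seqOp (multiply) and combOp (add) from the sample code. The class ZeroValueSketch and the method productThenSum are hypothetical names, and the partitions are passed in explicitly instead of coming from an RDD.

```java
import java.util.Arrays;
import java.util.List;

public class ZeroValueSketch {
    // Hypothetical model of aggregate with seqOp = v1 * v2 and combOp = v1 + v2.
    static int productThenSum(List<List<Integer>> partitions, int zeroValue) {
        int result = zeroValue;              // zeroValue seeds the final combOp merge
        for (List<Integer> partition : partitions) {
            int acc = zeroValue;             // zeroValue seeds seqOp in every partition
            for (int x : partition) {
                acc = acc * x;               // seqOp
            }
            result = result + acc;           // combOp
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> one = Arrays.asList(Arrays.asList(2, 3, 2, 5, 2, 6));
        List<List<Integer>> two = Arrays.asList(Arrays.asList(2, 3, 2), Arrays.asList(5, 2, 6));
        System.out.println(productThenSum(one, 1)); // 1 + 720 = 721
        System.out.println(productThenSum(two, 2)); // 2 + 24 + 120 = 146
        System.out.println(productThenSum(two, 1)); // 1 + 12 + 60 = 73
    }
}
```

This model merges partition results in index order. In Spark the driver merges them as tasks finish, so combOp should give the same answer in any order, which addition does.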


Sample code:

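Note:

The result with a single partition differs from the result with several partitions. Because zeroValue seeds the seqOp fold of every partition and also seeds the final combOp merge, the number of partitions changes the result unless zeroValue is neutral for both functions.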
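For the functions used in the sample code (seqOp multiplies, combOp adds), this gives a closed form. With zeroValue $z$ and partitions $P_1, \dots, P_k$, the value $z$ enters the computation $k + 1$ times:

$$\text{result} = z + \sum_{i=1}^{k} \Bigl( z \cdot \prod_{x \in P_i} x \Bigr)$$

With one partition and $z = 1$ this is $1 + 1 \cdot 720 = 721$. With two partitions ($P_1 = \{2, 3, 2\}$, $P_2 = \{5, 2, 6\}$) and $z = 2$ it is $2 + 2 \cdot 12 + 2 \cdot 60 = 146$. Keeping $z = 1$ but using two partitions already changes the result to $1 + 12 + 60 = 73$, because 1 is neutral for multiplication but not for addition.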

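import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;

// Wrapper class and local SparkContext added so the snippet runs on its own;
// the class name and the local[2] master are illustrative choices.
public class AggregateDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("AggregateDemo");
        JavaSparkContext sc = new JavaSparkContext(conf);

        List<Integer> list = new ArrayList<>();
        list.add(2);
        list.add(3);
        list.add(2);
        list.add(5);
        list.add(2);
        list.add(6);

        // Single partition
        JavaRDD<Integer> rdd1 = sc.parallelize(list, 1);
        System.out.println("NumPartitions: " + rdd1.getNumPartitions());

        int result1 = rdd1.aggregate(1, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer v1, Integer v2) throws Exception {
                // seqOp: zeroValue * 2, then * 3, * 2, * 5, * 2, * 6, i.e. 1*2*3*2*5*2*6 = 720
                return v1 * v2;
            }
        }, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer v1, Integer v2) throws Exception {
                // combOp: zeroValue + the seqOp result of the only partition, i.e. 1 + 720 = 721
                return v1 + v2;
            }
        });
        System.out.println("result1: " + result1);

        // Multiple partitions
        JavaRDD<Integer> rdd2 = sc.parallelize(list, 2);
        System.out.println("NumPartitions: " + rdd2.getNumPartitions());
        // Tag each element with its partition index to see how the list was split
        JavaRDD<String> mapPartitionsWithIndex = rdd2.mapPartitionsWithIndex(new Function2<Integer, Iterator<Integer>, Iterator<String>>() {
            @Override
            public Iterator<String> call(Integer partId, Iterator<Integer> iterator) throws Exception {
                List<String> out = new ArrayList<>();
                while (iterator.hasNext()) {
                    out.add("partition" + partId + ":" + iterator.next());
                }
                return out.iterator();
            }
        }, true);
        mapPartitionsWithIndex.foreachPartition((VoidFunction<Iterator<String>>) iterator -> {
            while (iterator.hasNext()) {
                System.out.println(iterator.next());
            }
        });
        // Output (lines from different partitions may appear in a different order):
        //     partition0:2
        //     partition0:3
        //     partition0:2
        //     partition1:5
        //     partition1:2
        //     partition1:6

        int result2 = rdd2.aggregate(2, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer v1, Integer v2) throws Exception {
                // seqOp, with zeroValue changed to 2 this time:
                // partition0 holds 2, 3, 2: 2*2*3*2 = 24 (the leading 2 is zeroValue)
                // partition1 holds 5, 2, 6: 2*5*2*6 = 120 (the leading 2 is zeroValue)
                return v1 * v2;
            }
        }, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer v1, Integer v2) throws Exception {
                // combOp: 2 + 24 + 120 = 146 (the first 2 is zeroValue)
                return v1 + v2;
            }
        });
        System.out.println("result2: " + result2);

        sc.stop();
    }
}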
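The partition contents printed in the multi-partition example (2, 3, 2 in partition 0 and 5, 2, 6 in partition 1) follow from how parallelize slices a local collection into contiguous chunks. The plain-Java sketch below assumes the index formula start = i * n / numSlices, end = (i + 1) * n / numSlices (integer division), which is what Spark's ParallelCollectionRDD uses for ordinary, non-Range sequences. The class SliceSketch and the method slice are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SliceSketch {
    // Hypothetical helper: splits data into numSlices contiguous chunks, where chunk i
    // covers indices [i * n / numSlices, (i + 1) * n / numSlices) with integer division.
    static <T> List<List<T>> slice(List<T> data, int numSlices) {
        long n = data.size();
        List<List<T>> slices = new ArrayList<>();
        for (int i = 0; i < numSlices; i++) {
            int start = (int) (i * n / numSlices);
            int end = (int) ((i + 1) * n / numSlices);
            slices.add(new ArrayList<>(data.subList(start, end)));
        }
        return slices;
    }

    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(2, 3, 2, 5, 2, 6);
        System.out.println(slice(list, 2)); // [[2, 3, 2], [5, 2, 6]]
        System.out.println(slice(list, 4)); // [[2], [3, 2], [5], [2, 6]]
    }
}
```

Feeding these slices into the aggregate model shows how the same data and zeroValue give a different result for each partition count.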
Original article: https://www.cnblogs.com/zz-ksw/p/12162249.html