Spark反压实现原理解析

在Spark Streaming中要启用反压(back pressure)机制,需要将配置项spark.streaming.backpressure.enabled设置为true(默认值为false)。

具体实现

1.RateController

该类继承StreamingListener,是一个监听器

/**
 * Streaming listener that reacts to completed batches and keeps track of the
 * ingestion rate this stream should currently be held to. The rate itself is
 * derived by the supplied `RateEstimator`.
 *
 * @param streamUID     id used to look up this stream's record count in each completed batch
 * @param rateEstimator estimator asked for a new rate after each completed batch
 */
private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
    extends StreamingListener with Serializable {

  // Must stay ahead of the `var` declarations below, exactly as in the original layout.
  init()

  /** Hands a freshly computed rate limit to the concrete controller. */
  protected def publish(rate: Long): Unit

  @transient
  implicit private var executionContext: ExecutionContext = _

  @transient
  private var rateLimit: AtomicLong = _

  /**
   * (Re)creates the transient state: a single-threaded daemon executor for rate
   * updates and a rate limit starting at -1. Invoked from the constructor and
   * again from `readObject` after deserialization.
   */
  private def init(): Unit = {
    val updateExecutor = ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")
    executionContext = ExecutionContext.fromExecutorService(updateExecutor)
    rateLimit = new AtomicLong(-1L)
  }

  /** Java serialization hook: restores the default fields, then rebuilds transient state. */
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    init()
  }

  /**
   * Asks the estimator for a new rate on the update thread and, when it yields
   * one, stores it and publishes the stored value.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      rateEstimator.compute(time, elems, workDelay, waitDelay) match {
        case Some(estimate) =>
          rateLimit.set(estimate.toLong)
          // Publish the new rate limit.
          publish(getLatestRate())
        case None =>
      }
    }
  }

  /** Most recently stored rate limit (-1 until the first estimate arrives). */
  def getLatestRate(): Long = rateLimit.get()

  /**
   * Handles the batch-completed event. A new rate is computed only when the batch
   * reports its processing end time, processing delay, scheduling delay and an
   * input record count for this stream; otherwise the event is ignored.
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    // Number of input records this stream contributed to the batch, if reported.
    val recordCount = info.streamIdToInputInfo.get(streamUID).map(_.numRecords)

    (info.processingEndTime, info.processingDelay, info.schedulingDelay, recordCount) match {
      case (Some(finishedAt), Some(workDelay), Some(waitDelay), Some(elems)) =>
        // Pass the four measurements to the estimator to derive the new rate limit.
        computeAndPublish(finishedAt, elems, workDelay, waitDelay)
      case _ =>
    }
  }
}

object RateController {

  /** Reports the value of the `BACKPRESSURE_ENABLED` entry read from `conf`. */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backpressureOn: Boolean = conf.get(BACKPRESSURE_ENABLED)
    backpressureOn
  }
}

2.RateEstimator

RateEstimator是速率估算器的接口(trait);其伴生对象的create方法根据配置创建一个PIDRateEstimator实例

/**
 * A component that estimates the rate at which an `InputDStream` should ingest
 * records, based on updates at every batch completion.
 *
 * Please see `org.apache.spark.streaming.scheduler.RateController` for more details.
 */
private[streaming] trait RateEstimator extends Serializable {

  /**
   * Computes the number of records the stream attached to this `RateEstimator`
   * should ingest per second, given an update on the size and completion
   * times of the latest batch.
   *
   * @param time The timestamp of the current batch interval that just finished
   * @param elements The number of records that were processed in this batch
   * @param processingDelay The time in ms that took for the job to complete
   * @param schedulingDelay The time in ms that the job spent in the scheduling queue
   * @return The new rate in records per second wrapped in `Some`. The trait does not
   *         spell out when `None` is returned; presumably it means "no estimate for
   *         this update", which is how `PIDRateEstimator.compute` uses it.
   */
  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double]
}

object RateEstimator {

  /**
   * Builds the `RateEstimator` selected by the
   * `spark.streaming.backpressure.rateEstimator` setting.
   *
   * `pid` is currently the only accepted value.
   *
   * @param conf          configuration holding the estimator name and the PID tuning entries
   * @param batchInterval batch duration; handed to the estimator in milliseconds
   * @return a `PIDRateEstimator` configured from `conf`
   * @throws IllegalArgumentException if the configured estimator is anything other than `pid`
   */
  def create(conf: SparkConf, batchInterval: Duration): RateEstimator = {
    // spark.streaming.backpressure.rateEstimator
    val estimatorName = conf.get(BACKPRESSURE_RATE_ESTIMATOR)
    if (estimatorName == "pid") {
      // spark.streaming.backpressure.pid.proportional (default 1.0)
      val kp = conf.get(BACKPRESSURE_PID_PROPORTIONAL)
      // spark.streaming.backpressure.pid.integral (default 0.2)
      val ki = conf.get(BACKPRESSURE_PID_INTEGRAL)
      // spark.streaming.backpressure.pid.derived (default 0.0)
      val kd = conf.get(BACKPRESSURE_PID_DERIVED)
      // spark.streaming.backpressure.pid.minRate (default 100)
      val rateFloor = conf.get(BACKPRESSURE_PID_MIN_RATE)
      new PIDRateEstimator(batchInterval.milliseconds, kp, ki, kd, rateFloor)
    } else {
      throw new IllegalArgumentException(s"Unknown rate estimator: $estimatorName")
    }
  }
}

3.PIDRateEstimator

PIDRateEstimator是RateEstimator唯一的实现类

这里利用了PID控制器的思想:将测量到的实际值与设定值进行比较,并根据两者之间的误差计算新的输入速率

image-20200804225609931

/**
 * Implements a proportional-integral-derivative (PID) controller which acts on
 * the speed of ingestion of elements into Spark Streaming. A PID controller works
 * by calculating an '''error''' between a measured output and a desired value. In the
 * case of Spark Streaming the error is the difference between the measured processing
 * rate (number of elements/processing delay) and the previous rate.
 *
 * @see <a href="https://en.wikipedia.org/wiki/PID_controller">PID controller (Wikipedia)</a>
 *
 * @param batchIntervalMillis the batch duration, in milliseconds
 * @param proportional how much the correction should depend on the current
 *        error. This term usually provides the bulk of correction and should be positive or zero.
 *        A value too large would make the controller overshoot the setpoint, while a small value
 *        would make the controller too insensitive. The default value is 1.
 * @param integral how much the correction should depend on the accumulation
 *        of past errors. This value should be positive or 0. This term accelerates the movement
 *        towards the desired value, but a large value may lead to overshooting. The default value
 *        is 0.2.
 * @param derivative how much the correction should depend on a prediction
 *        of future errors, based on current rate of change. This value should be positive or 0.
 *        This term is not used very often, as it impacts stability of the system. The default
 *        value is 0.
 * @param minRate what is the minimum rate that can be estimated.
 *        This must be greater than zero, so that the system always receives some data for rate
 *        estimation to work.
 */
private[streaming] class PIDRateEstimator(
    batchIntervalMillis: Long,
    proportional: Double, // proportional gain Kp (1.0 by default)
    integral: Double, // integral gain Ki (0.2 by default)
    derivative: Double, // derivative gain Kd (0.0 by default)
    minRate: Double // lower bound on the estimated rate (annotated as 100)
  ) extends RateEstimator with Logging {

  // Mutable controller state; only read and written inside `this.synchronized` in compute().
  private var firstRun: Boolean = true
  private var latestTime: Long = -1L
  private var latestRate: Double = -1D
  private var latestError: Double = -1D

  require(
    batchIntervalMillis > 0,
    s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.")
  require(
    proportional >= 0,
    s"Proportional term $proportional in PIDRateEstimator should be >= 0.")
  require(
    integral >= 0,
    s"Integral term $integral in PIDRateEstimator should be >= 0.")
  require(
    derivative >= 0,
    s"Derivative term $derivative in PIDRateEstimator should be >= 0.")
  require(
    minRate > 0,
    s"Minimum rate in PIDRateEstimator should be > 0")

  logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " +
    s"derivative = $derivative, min rate = $minRate")

  /**
   * Produces a new ingestion rate from the measurements of the batch that just finished.
   *
   * @param time timestamp of the finished batch, in milliseconds
   * @param numElements number of records processed in the batch
   * @param processingDelay time the batch took to process, in milliseconds
   * @param schedulingDelay time the batch waited before processing, in milliseconds
   * @return `Some(rate)` in elements/second, never below `minRate`; `None` on the first
   *         accepted update (which only seeds the state) and whenever `time` is not newer
   *         than the last accepted update or `numElements`/`processingDelay` is not positive
   */
  def compute(
      time: Long, // in milliseconds
      numElements: Long,
      processingDelay: Long, // in milliseconds
      schedulingDelay: Long // in milliseconds
    ): Option[Double] = {
    // The leading line break is written as an escape: a single-line string literal
    // cannot contain a physical newline.
    logTrace(s"\ntime = $time, # records = $numElements, " +
      s"processing time = $processingDelay, scheduling delay = $schedulingDelay")
    this.synchronized {
      if (time > latestTime && numElements > 0 && processingDelay > 0) {

        // in seconds, should be close to batchDuration
        val delaySinceUpdate = (time - latestTime).toDouble / 1000

        // in elements/second: the measured processing rate of this batch
        val processingRate = numElements.toDouble / processingDelay * 1000

        // In our system `error` is the difference between the desired rate and the measured rate
        // based on the latest batch information. We consider the desired rate to be latest rate,
        // which is what this estimator calculated for the previous batch.
        // in elements/second
        val error = latestRate - processingRate

        // The error integral, based on schedulingDelay as an indicator for accumulated errors.
        // A scheduling delay s corresponds to s * processingRate overflowing elements. Those
        // are elements that couldn't be processed in previous batches, leading to this delay.
        // In the following, we assume the processingRate didn't change too much.
        // From the number of overflowing elements we can calculate the rate at which they would be
        // processed by dividing it by the batch interval. This rate is our "historical" error,
        // or integral part, since if we subtracted this rate from the previous "calculated rate",
        // there wouldn't have been any overflowing elements, and the scheduling delay would have
        // been zero.
        // (in elements/second)
        val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis

        // in elements/(second ^ 2): change of the error over time, used as the
        // prediction of future error
        val dError = (error - latestError) / delaySinceUpdate

        // The corrected rate, clamped from below by minRate.
        val newRate = (latestRate - proportional * error -
                                    integral * historicalError -
                                    derivative * dError).max(minRate)
        logTrace(s"""
            | latestRate = $latestRate, error = $error
            | latestError = $latestError, historicalError = $historicalError
            | delaySinceUpdate = $delaySinceUpdate, dError = $dError
            """.stripMargin)

        latestTime = time
        // The first accepted update only seeds the state with the measured rate;
        // no estimate is returned for it.
        if (firstRun) {
          latestRate = processingRate
          latestError = 0D
          firstRun = false
          logTrace("First run, rate estimation skipped")
          None
        } else {
          latestRate = newRate
          latestError = error
          logTrace(s"New rate = $newRate")
          Some(newRate)
        }
      } else {
        logTrace("Rate estimation skipped")
        None
      }
    }
  }
}

4.ReceiverTracker

ReceiverTracker负责追踪Receiver的状态,并通过sendRateUpdate方法把新的速率上限(UpdateReceiverRateLimit消息)发送出去

/**
 * Sends a new maximum ingestion rate for the given stream to the tracker endpoint.
 * Nothing is sent unless the tracker has been started.
 */
def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized {
  if (isTrackerStarted) {
    val rateUpdate = UpdateReceiverRateLimit(streamUID, newRate)
    endpoint.send(rateUpdate)
  }
}
原文地址:https://www.cnblogs.com/jordan95225/p/13437089.html