[SPARK-8835][Streaming] Provide pluggable Congestion Strategies for Receiver-based Streams #9200

Closed

Changes from all commits
150 changes: 75 additions & 75 deletions core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -52,45 +52,45 @@ object RandomSampler {
def newDefaultRNG: Random = new XORShiftRandom

/**
* Default maximum gap-sampling fraction.
* For sampling fractions <= this value, the gap sampling optimization will be applied.
* Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
* optimal value for this will depend on the RNG. More expensive RNGs will tend to make
* the optimal value higher. The most reliable way to determine this value for a new RNG
* is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close
* in most cases, as an initial guess.
*/
val defaultMaxGapSamplingFraction = 0.4
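
// For intuition (an illustrative sketch, not part of this file): gap sampling
// replaces one RNG draw per input element with one draw per accepted element.
// For Bernoulli(f), the number of elements to skip between acceptances is
// geometric and can be sampled by inversion:
//
//   val u = math.max(rng.nextDouble(), rngEpsilon)  // guard against log(0)
//   val k = (math.log(u) / math.log1p(-f)).toInt    // log1p(-f) = log(1 - f)
//   // advance the iterator past k elements, then emit the next one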

/**
* Default epsilon for floating point numbers sampled from the RNG.
* The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG.
* To guard against errors from taking log(0), a positive epsilon lower bound is applied.
* A good value for this parameter is at or near the minimum positive floating
* point value returned by "nextDouble()" (or equivalent), for the RNG being used.
*/
val rngEpsilon = 5e-11

/**
* Sampling fraction arguments may be results of computation, and subject to floating
* point jitter. I check the arguments with this epsilon slop factor to prevent spurious
* warnings for cases such as summing some numbers to get a sampling fraction of 1.000000001
*/
val roundingEpsilon = 1e-6
}

/**
* :: DeveloperApi ::
* A sampler based on Bernoulli trials for partitioning a data sequence.
*
* @param lb lower bound of the acceptance range
* @param ub upper bound of the acceptance range
* @param complement whether to use the complement of the range specified, default to false
* @tparam T item type
*/
@DeveloperApi
class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false)
extends RandomSampler[T, T] {

/** epsilon slop to avoid failure from floating point jitter. */
require(
@@ -126,8 +126,8 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
}

/**
* Return a sampler whose range is the complement of the range specified for the current sampler.
*/
def cloneComplement(): BernoulliCellSampler[T] =
new BernoulliCellSampler[T](lb, ub, !complement)
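
As a usage illustration (a hypothetical snippet, not part of this diff): a cell sampler and its complement test complementary sub-ranges of the same per-element uniform draw, so seeding both identically yields a disjoint, exhaustive split.

val train = new BernoulliCellSampler[Int](0.0, 0.7)
val test = train.cloneComplement()
train.setSeed(42L)
test.setSeed(42L)
// With the same seed, each element is accepted by exactly one of the two.
val trainPart = train.sample(Iterator.range(1, 101)).toSeq  // ~70 elements
val testPart = test.sample(Iterator.range(1, 101)).toSeq    // the other ~30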

@@ -143,16 +143,16 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
* @tparam T item type
*/
@DeveloperApi
class BernoulliSampler[T: ClassTag](fraction: Double,
rng: Random = RandomSampler.newDefaultRNG)
extends RandomSampler[T, T] {

/** epsilon slop to avoid failure from floating point jitter */
require(
fraction >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction <= (1.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 1]")

override def setSeed(seed: Long): Unit = rng.setSeed(seed)
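
// The rng constructor parameter added in this diff lets callers inject a
// deterministic generator for reproducible tests, e.g. (hypothetical usage):
//   val s = new BernoulliSampler[Int](0.5, new XORShiftRandom(42L))
//   val first = s.sample(Iterator.range(0, 100)).toSeq
//   // constructing another sampler with the same seed reproduces `first`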

override def sample(items: Iterator[T]): Iterator[T] = {
@@ -174,15 +174,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
/**
* :: DeveloperApi ::
* A sampler for sampling with replacement, based on values drawn from Poisson distribution.
*
* @param fraction the sampling fraction (with replacement)
* @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low.
* @tparam T item type
*/
@DeveloperApi
class PoissonSampler[T: ClassTag](
fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {

def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)

@@ -209,9 +209,9 @@ class PoissonSampler[T: ClassTag](
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
} else {
items.flatMap { item =>
val count = rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
}
}
}
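
For intuition about the with-replacement semantics (a hypothetical check, not part of the diff): each element is emitted count ~ Poisson(fraction) times, so the expected output size is fraction times the input size, and fractions above 1.0 are meaningful.

val sampler = new PoissonSampler[Int](2.0)
sampler.setSeed(11L)
val out = sampler.sample(Iterator.range(0, 10000)).toSeq
// out.size is about 20000 in expectation; individual elements may repeat.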

@@ -221,10 +221,10 @@

private[spark]
class GapSamplingIterator[T: ClassTag](
var data: Iterator[T],
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {

require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
@@ -235,17 +235,17 @@ class GapSamplingIterator[T: ClassTag](
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
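// Rationale for the dispatch below (assumed here, not stated in this hunk):
// array- and ArrayBuffer-backed iterators implement drop(n) as a cheap
// positional skip, whereas calling drop on a general Iterator can allocate a
// wrapper per call, so the fallback advances the iterator element by element.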
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}

@@ -275,10 +275,10 @@

private[spark]
class GapSamplingReplacementIterator[T: ClassTag](
var data: Iterator[T],
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {

require(f > 0.0, s"Sampling fraction ($f) must be > 0")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
@@ -289,17 +289,17 @@ class GapSamplingReplacementIterator[T: ClassTag](
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}

@@ -317,10 +317,10 @@ class GapSamplingReplacementIterator[T: ClassTag](
}

/**
* Skip elements with replication factor zero (i.e. elements that won't be sampled).
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
* q is the probability of Poisson(0; f)
*/
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / (-f)).toInt
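
To see why this skip length is geometric (a worked check using the definitions above, with q = e^(-f)): P(K >= k) = q^k, so inverting u <= q^k gives k = floor(log(u) / log(q)) = floor(log(u) / (-f)). A quick empirical sanity check (hypothetical test code):

val rng = new scala.util.Random(7)
val f = 0.3
val ks = Seq.fill(100000)((math.log(math.max(rng.nextDouble(), 5e-11)) / -f).toInt)
val p0 = ks.count(_ == 0).toDouble / ks.size
// p0 should be close to 1 - math.exp(-f) = P(K = 0) = 1 - q, about 0.259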
@@ -335,10 +335,10 @@ class GapSamplingReplacementIterator[T: ClassTag](
private val q = math.exp(-f)

/**
* Sample from the Poisson distribution, conditioned such that the sampled value is >= 1.
* This is adapted from the algorithm for generating Poisson-distributed random variables:
* http://en.wikipedia.org/wiki/Poisson_distribution
*/
private def poissonGE1: Int = {
// simulate standard Poisson sampling conditioned on having produced
// at least one draw, so that the returned sample is >= 1
streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -76,7 +76,7 @@ private[streaming] trait BlockGeneratorListener
private[streaming] class BlockGenerator(
listener: BlockGeneratorListener,
receiverId: Int,
private[receiver] val conf: SparkConf,
clock: Clock = new SystemClock()
) extends RateLimiter(conf) with Logging {

Expand All @@ -101,7 +101,7 @@ private[streaming] class BlockGenerator(
private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value")

private[receiver] val blockIntervalTimer =
new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator")
private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10)
private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize)
@@ -226,6 +226,13 @@

def isStopped(): Boolean = state == StoppedAll

private[receiver] val congestionStrategy = CongestionStrategy.create(this)

private def pruneElementsOfBlock: Option[Iterator[Any] => Iterator[Any]] =
PartialFunction.condOpt(congestionStrategy) {
case s: DestructiveCongestionStrategy => (s.restrictCurrentBuffer _)
}

/** Change the buffer to which single records are added. */
private def updateCurrentBuffer(time: Long): Unit = {
try {
@@ -236,7 +243,8 @@
currentBuffer = new ArrayBuffer[Any]
val blockId = StreamBlockId(receiverId, time - blockIntervalMs)
listener.onGenerateBlock(blockId)
newBlock = new Block(blockId,
pruneElementsOfBlock.map(f => f(newBlockBuffer.to)).getOrElse(newBlockBuffer).to)
}
}
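
The pruning hook above relies on scala.PartialFunction.condOpt, which yields Some(result) only when the partial function is defined at its argument. A minimal standalone sketch of the pattern (simplified types, not the PR's code):

import scala.PartialFunction.condOpt

trait Strategy
trait Destructive extends Strategy { def restrict(it: Iterator[Any]): Iterator[Any] }

// Some(restriction) for destructive strategies, None for all others.
def pruneHook(s: Strategy): Option[Iterator[Any] => Iterator[Any]] =
  condOpt(s) { case d: Destructive => d.restrict _ }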

81 changes: 81 additions & 0 deletions streaming/src/main/scala/org/apache/spark/streaming/receiver/CongestionStrategy.scala
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.receiver

import org.apache.spark.SparkConf
import scala.collection.mutable.ArrayBuffer

/**
* These traits provide a strategy for dealing with a large amount of data seen
* at a Receiver, which can otherwise exhaust resources.
* See SPARK-7398.
* Any long blocking operation in this class will hurt the throughput.
*/
private[streaming] abstract class CongestionStrategy(conf: SparkConf) {

protected val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")

/**
* Called on every batch interval with the estimated maximum number of
* elements per second that can be processed, based on the processing
* speed observed over the last batch interval.
*/
def onBlockBoundUpdate(bound: Long): Unit

}

private[streaming] abstract class DestructiveCongestionStrategy(conf: SparkConf)
extends CongestionStrategy(conf) {

/**
* Given a data buffer intended for a block, return an iterator over an
* amount of data appropriate with respect to the back-pressure information
* provided through `onBlockBoundUpdate`.
*/
def restrictCurrentBuffer(currentBuffer: Iterator[Any]): Iterator[Any]

}

private[streaming] abstract class ThrottlingCongestionStrategy(
private[receiver] val rateLimiter: RateLimiter,
conf: SparkConf)
extends CongestionStrategy(conf)

object CongestionStrategy {

/**
* Return a new CongestionStrategy based on the value of
* `spark.streaming.backpressure.congestionStrategy`.
*
* Intended clients of this factory are receiver-based streams.
*
* @return An instance of CongestionStrategy
* @throws IllegalArgumentException if the configured congestion strategy
* does not match any known strategy.
*/
def create(blockGenerator: BlockGenerator): CongestionStrategy = {
blockGenerator.conf.get("spark.streaming.backpressure.congestionStrategy", "throttle") match {
case "drop" => new DropCongestionStrategy(blockGenerator.conf)
case "sample" => new SampleCongestionStrategy(blockGenerator.conf)
case "throttle" => new ThrottleCongestionStrategy(blockGenerator)
case strategy =>
throw new IllegalArgumentException(s"Unknown congestion strategy: $strategy")
}
}

}
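
The concrete strategies named by the factory (DropCongestionStrategy, SampleCongestionStrategy, ThrottleCongestionStrategy) are not shown in this excerpt. A minimal sketch of what a drop-style DestructiveCongestionStrategy could look like under the API above (the class name and pruning logic are illustrative, not the PR's implementation):

private[streaming] class ExampleDropStrategy(conf: SparkConf)
  extends DestructiveCongestionStrategy(conf) {

  @volatile private var latestBound: Long = Long.MaxValue

  override def onBlockBoundUpdate(bound: Long): Unit = latestBound = bound

  // Keep at most the number of elements the downstream can absorb in one
  // block interval, dropping the remainder of the buffer.
  override def restrictCurrentBuffer(currentBuffer: Iterator[Any]): Iterator[Any] =
    if (latestBound == Long.MaxValue) currentBuffer  // no bound observed yet
    else {
      val maxPerBlock = math.max(1L, latestBound * blockIntervalMs / 1000)
      currentBuffer.take(math.min(maxPerBlock, Int.MaxValue.toLong).toInt)
    }
}

A strategy is selected through the configuration read by create(), e.g. conf.set("spark.streaming.backpressure.congestionStrategy", "drop").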