
tests passing
Signed-off-by: Oleksii Moskalenko <[email protected]>
pyalex committed Dec 24, 2020
1 parent 158a887 commit 372bd69
Showing 8 changed files with 92 additions and 33 deletions.
4 changes: 2 additions & 2 deletions spark/ingestion/pom.xml
@@ -189,14 +189,14 @@
<dependency>
<groupId>com.dimafeng</groupId>
<artifactId>testcontainers-scala-scalatest_${scala.version}</artifactId>
-      <version>0.38.6</version>
+      <version>0.38.8</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.dimafeng</groupId>
<artifactId>testcontainers-scala-kafka_${scala.version}</artifactId>
-      <version>0.38.6</version>
+      <version>0.38.8</version>
<scope>test</scope>
</dependency>

@@ -39,6 +39,7 @@ object BatchPipeline extends BasePipeline {
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+    val metrics = new IngestionPipelineMetrics

val input = config.source match {
case source: BQSource =>
@@ -68,7 +69,7 @@ }
}

val validRows = projected
-      .mapPartitions(IngestionPipelineMetrics.incrementRead)
+      .mapPartitions(metrics.incrementRead)
.filter(rowValidator.allChecks)

validRows.write
@@ -84,7 +85,7 @@
case Some(path) =>
projected
.filter(!rowValidator.allChecks)
-          .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+          .mapPartitions(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
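Note on the change above: metrics becomes a per-pipeline instance because it is captured by the mapPartitions closures, which Spark serializes and ships to every executor; the companion diff further down accordingly turns IngestionPipelineMetrics into a Serializable class. A minimal stand-alone sketch of that closure pattern (PartitionCounter and its members are hypothetical, not part of this commit):

import org.apache.spark.sql.Row

// Hypothetical stand-in for IngestionPipelineMetrics: the instance travels
// inside the task closure, while state that should not travel is rebuilt
// lazily on each executor.
class PartitionCounter extends Serializable {
  @transient private lazy val count =
    new java.util.concurrent.atomic.AtomicLong // executor-local, never serialized

  def tap(rows: Iterator[Row]): Iterator[Row] = {
    val materialized = rows.toArray // count without exhausting the iterator
    count.addAndGet(materialized.length)
    materialized.toIterator
  }
}

// Usage mirrors the pipeline code above:
//   val counter = new PartitionCounter
//   projected.mapPartitions(counter.tap).filter(rowValidator.allChecks)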
@@ -56,8 +56,8 @@ object StreamingPipeline extends BasePipeline with Serializable {
val featureTable = config.featureTable
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
-    val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
-
+    val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+    val metrics = new IngestionPipelineMetrics
val validationUDF = createValidationUDF(sparkSession, config)

val input = config.source match {
@@ -107,7 +107,7 @@
implicit def rowEncoder: Encoder[Row] = RowEncoder(rowsAfterValidation.schema)

rowsAfterValidation
-      .mapPartitions(IngestionPipelineMetrics.incrementRead)
+      .mapPartitions(metrics.incrementRead)
.filter(if (config.doNotIngestInvalidRows) expr("_isValid") else rowValidator.allChecks)
.write
.format("feast.ingestion.stores.redis")
@@ -120,10 +120,9 @@

config.deadLetterPath match {
      case Some(path) =>
-
        rowsAfterValidation
          .filter("!_isValid")
-          .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+          .mapPartitions(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
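For context on the streaming variant above: rowsAfterValidation carries a boolean _isValid column, so one frame feeds both sinks; the online-store write filters on _isValid when doNotIngestInvalidRows is set (otherwise on rowValidator.allChecks), and the dead-letter branch keeps the complement. A runnable sketch of that split on a made-up two-column frame:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object IsValidSplitDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
  import spark.implicits._

  val rowsAfterValidation = Seq(("a", true), ("b", false)).toDF("customer", "_isValid")

  rowsAfterValidation.filter(expr("_isValid")).show() // would go to the online store
  rowsAfterValidation.filter("!_isValid").show()      // would go to the dead-letter path
  spark.stop()
}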
@@ -18,36 +18,39 @@ package feast.ingestion.metrics

import org.apache.spark.SparkEnv
import org.apache.spark.metrics.source.IngestionPipelineMetricSource
+import org.apache.spark.sql.Row

-object IngestionPipelineMetrics {
-  def incrementDeadletters[A](rowIterator: Iterator[A]): Iterator[A] = {
+class IngestionPipelineMetrics extends Serializable {
+
+  def incrementDeadLetters(rowIterator: Iterator[Row]): Iterator[Row] = {
+    val materialized = rowIterator.toArray
    if (metricSource.nonEmpty)
-      metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(rowIterator.length)
+      metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(materialized.length)

-    rowIterator
+    materialized.toIterator
}

-  def incrementRead[A](rowIterator: Iterator[A]): Iterator[A] = {
+  def incrementRead(rowIterator: Iterator[Row]): Iterator[Row] = {
+    val materialized = rowIterator.toArray
    if (metricSource.nonEmpty)
-      metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(rowIterator.length)
+      metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(materialized.length)

-    rowIterator
+    materialized.toIterator
}

private lazy val metricSource: Option[IngestionPipelineMetricSource] = {
-    this.synchronized {
-      if (
-        SparkEnv.get.metricsSystem
-          .getSourcesByName(IngestionPipelineMetricSource.sourceName)
-          .isEmpty
-      ) {
-        SparkEnv.get.metricsSystem.registerSource(new IngestionPipelineMetricSource)
+    val metricsSystem = SparkEnv.get.metricsSystem
+    IngestionPipelineMetricsLock.synchronized {
+      if (metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName).isEmpty) {
+        metricsSystem.registerSource(new IngestionPipelineMetricSource)
      }
    }

-    SparkEnv.get.metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
+    metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
case Seq(head) => Some(head.asInstanceOf[IngestionPipelineMetricSource])
case _ => None
}
}
}

+private object IngestionPipelineMetricsLock
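Beyond the object-to-class move, this diff fixes a real bug: Iterator.length consumes the iterator, so whenever the metric source was registered, the old code reported a count and then handed Spark back an already-exhausted iterator, silently dropping the partition's rows. Registration also moves from this.synchronized (useless once every pipeline news up its own instance) to the shared IngestionPipelineMetricsLock, so concurrent tasks on one executor cannot double-register the source. A self-contained illustration of the iterator hazard (demo code, not from the commit); the cost of the fix is that each partition is materialized in memory before being counted:

object IteratorExhaustionDemo extends App {
  val it = Iterator(1, 2, 3)
  println(it.length)  // 3 -- but this call drains the iterator
  println(it.hasNext) // false: a downstream consumer now sees zero rows

  // The fixed pattern: materialize once, count the array, return a fresh iterator.
  val materialized = Iterator(1, 2, 3).toArray
  println(materialized.length)             // 3
  println(materialized.toIterator.hasNext) // true: the rows survive the count
}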
@@ -125,8 +125,13 @@ class StatsdReporterWithTags(
private def reportGauge(name: String, gauge: Gauge[_])(implicit socket: DatagramSocket): Unit =
formatAny(gauge.getValue).foreach(v => send(fullName(name), v, GAUGE))

-  private def reportCounter(name: String, counter: Counter)(implicit socket: DatagramSocket): Unit =
-    send(fullName(name), format(counter.getCount), COUNTER)
+  private def reportCounter(name: String, counter: Counter)(implicit
+      socket: DatagramSocket
+  ): Unit = {
+    val snapshot = counter.getCount
+    send(fullName(name), format(snapshot), COUNTER)
+    counter.dec(snapshot) // reset counter
+  }

private def reportHistogram(name: String, histogram: Histogram)(implicit
socket: DatagramSocket
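The dec(snapshot) turns the cumulative Dropwizard counter into delta reporting, which is what a StatsD counter ("|c") expects per flush; subtracting the snapshot rather than zeroing means increments that race in between getCount and the reset are preserved for the next report. A small demonstration with Dropwizard's Counter (demo only):

import com.codahale.metrics.Counter

object CounterResetDemo extends App {
  val counter = new Counter
  counter.inc(10)

  val snapshot = counter.getCount // 10, the delta reported to StatsD
  counter.inc(2)                  // an increment racing in after the snapshot
  counter.dec(snapshot)           // subtract only what was reported

  println(counter.getCount) // 2: the racing increment survives for the next flush
}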
@@ -23,7 +23,7 @@ import collection.JavaConverters._
import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer}
import com.google.protobuf.util.Timestamps
import feast.proto.types.ValueProto.ValueType
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.joda.time.{DateTime, Seconds}
import org.scalacheck._
import org.scalatest._
@@ -46,25 +46,28 @@ case class TestRow(
class BatchPipelineIT extends SparkSpec with ForAllTestContainer {

override val container = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379))
+  val statsDStub = new StatsDStub

override def withSparkConfOverrides(conf: SparkConf): SparkConf = conf
.set("spark.redis.host", container.host)
.set("spark.redis.port", container.mappedPort(6379).toString)
.set("spark.metrics.conf.*.sink.statsd.port", statsDStub.port.toString)

trait Scope {
val jedis = new Jedis("localhost", container.mappedPort(6379))
jedis.flushAll()

+    statsDStub.receivedMetrics // clean the buffer

implicit def testRowEncoder: Encoder[TestRow] = ExpressionEncoder()
-    val statsDStub = new StatsDStub

def rowGenerator(start: DateTime, end: DateTime, customerGen: Option[Gen[String]] = None) =
for {
customer <- customerGen.getOrElse(Gen.asciiPrintableStr)
feature1 <- Gen.choose(0, 100)
feature2 <- Gen.choose[Float](0, 1)
eventTimestamp <- Gen
-        .choose(0, Seconds.secondsBetween(start, end).getSeconds)
+        .choose(0, Seconds.secondsBetween(start, end).getSeconds - 1)
.map(start.withMillisOfSecond(0).plusSeconds)
} yield TestRow(
customer,
@@ -98,7 +101,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
),
startTime = DateTime.parse("2020-08-01"),
endTime = DateTime.parse("2020-09-01"),
-      metrics = Some(StatsDConfig(host="localhost", port=statsDStub.port))
+      metrics = Some(StatsDConfig(host = "localhost", port = statsDStub.port))
)
}

@@ -129,6 +132,14 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
keyTTL shouldEqual -1

})

+    SparkEnv.get.metricsSystem.report()
+    statsDStub.receivedMetrics should contain.allElementsOf(
+      Map(
+        "driver.ingestion_pipeline.read_from_source_count" -> rows.length,
+        "driver.redis_sink.feature_row_ingested_count" -> rows.length
+      )
+    )
}

"Parquet source file" should "be ingested in redis with expiry time equal to the largest of (event_timestamp + max_age) for" +
@@ -466,6 +477,13 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
.toString
)
.count() should be(rows.length)

+    SparkEnv.get.metricsSystem.report()
+    statsDStub.receivedMetrics should contain.allElementsOf(
+      Map(
+        "driver.ingestion_pipeline.deadletter_count" -> rows.length
+      )
+    )
}

"Columns from source" should "be mapped according to configuration" in new Scope {
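One subtle change above deserves a note: the event-timestamp generator now draws from 0 to getSeconds - 1. Gen.choose is inclusive on both bounds, so the old range could stamp a row exactly at endTime; the - 1 keeps every generated timestamp strictly inside [startTime, endTime), presumably so the new count assertions see exactly rows.length ingested rows. The new generator in stand-alone form (dates are illustrative):

import org.joda.time.{DateTime, Seconds}
import org.scalacheck.Gen

object TimestampGenDemo extends App {
  val start = DateTime.parse("2020-08-01")
  val end   = DateTime.parse("2020-09-01")

  // Gen.choose(lo, hi) includes hi, hence the - 1 to stay strictly before end.
  val eventTimestamp: Gen[DateTime] = Gen
    .choose(0, Seconds.secondsBetween(start, end).getSeconds - 1)
    .map(start.withMillisOfSecond(0).plusSeconds)

  println(eventTimestamp.sample)
}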
@@ -36,7 +36,10 @@ class SparkSpec extends UnitSpec with BeforeAndAfter {
"org.apache.spark.metrics.sink.StatsdSinkWithTags"
)
.set("spark.metrics.conf.*.sink.statsd.host", "localhost")
.set("spark.metrics.conf.*.sink.statsd.port", "8125")
.set("spark.metrics.conf.*.sink.statsd.period", "999") // disable scheduled reporting
.set("spark.metrics.conf.*.sink.statsd.unit", "minutes")
.set("spark.metrics.labels", "job_id=test")
.set("spark.metrics.namespace", "")
.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
.set("spark.sql.execution.arrow.maxRecordsPerBatch", "50000")

@@ -1,3 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package feast.ingestion.metrics

import java.net.{DatagramPacket, DatagramSocket, SocketTimeoutException}
@@ -12,11 +28,11 @@ class StatsDStub {

def receive: Array[String] = {
val messages: ArrayBuffer[String] = ArrayBuffer()
var finished = false

do {
val buf = new Array[Byte](65535)
val p = new DatagramPacket(buf, buf.length)
try {
socket.receive(p)
} catch {
@@ -28,4 +44,18 @@

messages.toArray
}

+  private val metricLine = """(.+):(.+)\|(.+)#(.+)""".r
+
+  def receivedMetrics: Map[String, Float] = {
+    receive
+      .flatMap {
+        case metricLine(name, value, type_, tags) =>
+          Seq(name -> value.toFloat)
+        case s: String =>
+          Seq()
+      }
+      .groupBy(_._1)
+      .mapValues(_.map(_._2).sum)
+  }
}
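receivedMetrics folds raw StatsD datagrams into a name-to-value map, summing because the same counter can arrive across several flushes (which is also why the Scope trait drains the buffer before each test). A stand-alone sketch of the same parsing; the sample datagram lines are invented, but follow the name:value|type|#tags shape the regex expects:

object StatsDParseDemo extends App {
  val metricLine = """(.+):(.+)\|(.+)#(.+)""".r

  val datagrams = Seq(
    "driver.ingestion_pipeline.read_from_source_count:5|c|#job_id=test",
    "driver.ingestion_pipeline.read_from_source_count:3|c|#job_id=test",
    "not a metric line" // silently dropped, as in the stub above
  )

  val metrics: Map[String, Float] = datagrams
    .flatMap {
      case metricLine(name, value, _, _) => Seq(name -> value.toFloat)
      case _                             => Seq.empty
    }
    .groupBy(_._1)
    .mapValues(_.map(_._2).sum)
    .toMap

  println(metrics) // Map(driver.ingestion_pipeline.read_from_source_count -> 8.0)
}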
