diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml
index cda054efe9..0e0bf8385c 100644
--- a/spark/ingestion/pom.xml
+++ b/spark/ingestion/pom.xml
@@ -189,14 +189,14 @@
<groupId>com.dimafeng</groupId>
<artifactId>testcontainers-scala-scalatest_${scala.version}</artifactId>
- <version>0.38.6</version>
+ <version>0.38.8</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.dimafeng</groupId>
<artifactId>testcontainers-scala-kafka_${scala.version}</artifactId>
- <version>0.38.6</version>
+ <version>0.38.8</version>
<scope>test</scope>
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
index 3b8675f8ec..14e7356943 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
@@ -39,6 +39,7 @@ object BatchPipeline extends BasePipeline {
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+ val metrics = new IngestionPipelineMetrics
val input = config.source match {
case source: BQSource =>
@@ -68,7 +69,7 @@ object BatchPipeline extends BasePipeline {
}
val validRows = projected
- .mapPartitions(IngestionPipelineMetrics.incrementRead)
+ .mapPartitions(metrics.incrementRead)
.filter(rowValidator.allChecks)
validRows.write
@@ -84,7 +85,7 @@ object BatchPipeline extends BasePipeline {
case Some(path) =>
projected
.filter(!rowValidator.allChecks)
- .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+ .mapPartitions(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
index c6ed87df85..974c7d921d 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
@@ -56,8 +56,8 @@ object StreamingPipeline extends BasePipeline with Serializable {
val featureTable = config.featureTable
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
- val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
-
+ val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+ val metrics = new IngestionPipelineMetrics
val validationUDF = createValidationUDF(sparkSession, config)
val input = config.source match {
@@ -107,7 +107,7 @@ object StreamingPipeline extends BasePipeline with Serializable {
implicit def rowEncoder: Encoder[Row] = RowEncoder(rowsAfterValidation.schema)
rowsAfterValidation
- .mapPartitions(IngestionPipelineMetrics.incrementRead)
+ .mapPartitions(metrics.incrementRead)
.filter(if (config.doNotIngestInvalidRows) expr("_isValid") else rowValidator.allChecks)
.write
.format("feast.ingestion.stores.redis")
@@ -120,10 +120,9 @@ object StreamingPipeline extends BasePipeline with Serializable {
config.deadLetterPath match {
case Some(path) =>
-
rowsAfterValidation
.filter("!_isValid")
- .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+ .mapPartitions(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
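
Both the batch and the streaming pipeline now create one `IngestionPipelineMetrics` per run and pass its instance methods to `mapPartitions`, so the serializable object travels in the task closure and the metric source is resolved lazily on each executor. A minimal sketch of that pattern, with a hypothetical `RowCounter` standing in for the real class and a plain `AtomicLong` standing in for the Codahale counter behind the metric source:

```scala
// Sketch only: RowCounter is a hypothetical stand-in for IngestionPipelineMetrics.
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.sql.{DataFrame, Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

object MetricsClosureSketch {

  class RowCounter extends Serializable {
    // lazy and only touched inside tasks, so it is initialised on the executor
    private lazy val counter = new AtomicLong()

    def incrementRead(rows: Iterator[Row]): Iterator[Row] = {
      val materialized = rows.toArray // an Iterator can only be traversed once
      counter.addAndGet(materialized.length.toLong)
      materialized.iterator
    }
  }

  def withReadCount(df: DataFrame): DataFrame = {
    val metrics = new RowCounter // one instance per pipeline run
    implicit val enc: Encoder[Row] = RowEncoder(df.schema)
    df.mapPartitions(metrics.incrementRead) // the instance is serialized into the task
  }
}
```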
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala b/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
index 060547e611..03d41eddd5 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
@@ -18,36 +18,39 @@ package feast.ingestion.metrics
import org.apache.spark.SparkEnv
import org.apache.spark.metrics.source.IngestionPipelineMetricSource
+import org.apache.spark.sql.Row
-object IngestionPipelineMetrics {
- def incrementDeadletters[A](rowIterator: Iterator[A]): Iterator[A] = {
+class IngestionPipelineMetrics extends Serializable {
+
+ def incrementDeadLetters(rowIterator: Iterator[Row]): Iterator[Row] = {
+ val materialized = rowIterator.toArray
if (metricSource.nonEmpty)
- metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(rowIterator.length)
+ metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(materialized.length)
- rowIterator
+ materialized.toIterator
}
- def incrementRead[A](rowIterator: Iterator[A]): Iterator[A] = {
+ def incrementRead(rowIterator: Iterator[Row]): Iterator[Row] = {
+ val materialized = rowIterator.toArray
if (metricSource.nonEmpty)
- metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(rowIterator.length)
+ metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(materialized.length)
- rowIterator
+ materialized.toIterator
}
private lazy val metricSource: Option[IngestionPipelineMetricSource] = {
- this.synchronized {
- if (
- SparkEnv.get.metricsSystem
- .getSourcesByName(IngestionPipelineMetricSource.sourceName)
- .isEmpty
- ) {
- SparkEnv.get.metricsSystem.registerSource(new IngestionPipelineMetricSource)
+ val metricsSystem = SparkEnv.get.metricsSystem
+ IngestionPipelineMetricsLock.synchronized {
+ if (metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName).isEmpty) {
+ metricsSystem.registerSource(new IngestionPipelineMetricSource)
}
}
- SparkEnv.get.metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
+ metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
case Seq(head) => Some(head.asInstanceOf[IngestionPipelineMetricSource])
case _ => None
}
}
}
+
+private object IngestionPipelineMetricsLock
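
Two details of the rewrite above are worth spelling out. First, `Iterator.length` traverses and therefore exhausts the iterator, so the old object-based version would hand an effectively empty iterator back to Spark whenever the metric source was registered; materialising to an array before counting avoids that. Second, source registration is now synchronised on the dedicated, JVM-wide `IngestionPipelineMetricsLock` object rather than on the metrics instance, which is deserialised separately for each task and therefore cannot serve as a shared lock. A quick plain-Scala check of the iterator behaviour:

```scala
// Standalone illustration of why the rows are materialised before counting.
object IteratorLengthDemo extends App {
  val it = Iterator(1, 2, 3)
  println(it.length)  // 3; this traversal consumes the iterator
  println(it.hasNext) // false: nothing left to hand downstream

  val materialized = Iterator(1, 2, 3).toArray
  println(materialized.length)           // 3
  println(materialized.iterator.hasNext) // true: a fresh iterator is still available
}
```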
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala b/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
index 894014b6fd..880fc2b982 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
@@ -125,8 +125,13 @@ class StatsdReporterWithTags(
private def reportGauge(name: String, gauge: Gauge[_])(implicit socket: DatagramSocket): Unit =
formatAny(gauge.getValue).foreach(v => send(fullName(name), v, GAUGE))
- private def reportCounter(name: String, counter: Counter)(implicit socket: DatagramSocket): Unit =
- send(fullName(name), format(counter.getCount), COUNTER)
+ private def reportCounter(name: String, counter: Counter)(implicit
+ socket: DatagramSocket
+ ): Unit = {
+ val snapshot = counter.getCount
+ send(fullName(name), format(snapshot), COUNTER)
+ counter.dec(snapshot) // reset counter
+ }
private def reportHistogram(name: String, histogram: Histogram)(implicit
socket: DatagramSocket
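
The `reportCounter` change makes each StatsD report carry a delta rather than a running total: the counter is snapshotted, the snapshot is sent, and the counter is decremented by that amount so the next report only includes increments accumulated since. A small sketch of that sequence on a Dropwizard `Counter`, with the actual `send` call elided:

```scala
// Sketch of the snapshot-send-reset sequence added in reportCounter above.
import com.codahale.metrics.Counter

object CounterResetDemo extends App {
  val counter = new Counter
  counter.inc(42)

  val snapshot = counter.getCount // 42
  // send(fullName(name), format(snapshot), COUNTER) would happen here
  counter.dec(snapshot)           // reset so only new increments are reported next time

  println(counter.getCount)       // 0
}
```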
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
index f3054c0423..0e9a96eb42 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
@@ -23,7 +23,7 @@ import collection.JavaConverters._
import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer}
import com.google.protobuf.util.Timestamps
import feast.proto.types.ValueProto.ValueType
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.joda.time.{DateTime, Seconds}
import org.scalacheck._
import org.scalatest._
@@ -46,17 +46,20 @@ case class TestRow(
class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
override val container = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379))
+ val statsDStub = new StatsDStub
override def withSparkConfOverrides(conf: SparkConf): SparkConf = conf
.set("spark.redis.host", container.host)
.set("spark.redis.port", container.mappedPort(6379).toString)
+ .set("spark.metrics.conf.*.sink.statsd.port", statsDStub.port.toString)
trait Scope {
val jedis = new Jedis("localhost", container.mappedPort(6379))
jedis.flushAll()
+ statsDStub.receivedMetrics // drain the buffer
+
implicit def testRowEncoder: Encoder[TestRow] = ExpressionEncoder()
- val statsDStub = new StatsDStub
def rowGenerator(start: DateTime, end: DateTime, customerGen: Option[Gen[String]] = None) =
for {
@@ -64,7 +67,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
feature1 <- Gen.choose(0, 100)
feature2 <- Gen.choose[Float](0, 1)
eventTimestamp <- Gen
- .choose(0, Seconds.secondsBetween(start, end).getSeconds)
+ .choose(0, Seconds.secondsBetween(start, end).getSeconds - 1)
.map(start.withMillisOfSecond(0).plusSeconds)
} yield TestRow(
customer,
@@ -98,7 +101,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
),
startTime = DateTime.parse("2020-08-01"),
endTime = DateTime.parse("2020-09-01"),
- metrics = Some(StatsDConfig(host="localhost", port=statsDStub.port))
+ metrics = Some(StatsDConfig(host = "localhost", port = statsDStub.port))
)
}
@@ -129,6 +132,14 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
keyTTL shouldEqual -1
})
+
+ SparkEnv.get.metricsSystem.report()
+ statsDStub.receivedMetrics should contain.allElementsOf(
+ Map(
+ "driver.ingestion_pipeline.read_from_source_count" -> rows.length,
+ "driver.redis_sink.feature_row_ingested_count" -> rows.length
+ )
+ )
}
"Parquet source file" should "be ingested in redis with expiry time equal to the largest of (event_timestamp + max_age) for" +
@@ -466,6 +477,13 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
.toString
)
.count() should be(rows.length)
+
+ SparkEnv.get.metricsSystem.report()
+ statsDStub.receivedMetrics should contain.allElementsOf(
+ Map(
+ "driver.ingestion_pipeline.deadletter_count" -> rows.length
+ )
+ )
}
"Columns from source" should "be mapped according to configuration" in new Scope {
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala b/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
index 1a03073334..025a3b8be1 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
@@ -36,7 +36,10 @@ class SparkSpec extends UnitSpec with BeforeAndAfter {
"org.apache.spark.metrics.sink.StatsdSinkWithTags"
)
.set("spark.metrics.conf.*.sink.statsd.host", "localhost")
- .set("spark.metrics.conf.*.sink.statsd.port", "8125")
+ .set("spark.metrics.conf.*.sink.statsd.period", "999") // disable scheduled reporting
+ .set("spark.metrics.conf.*.sink.statsd.unit", "minutes")
+ .set("spark.metrics.labels", "job_id=test")
+ .set("spark.metrics.namespace", "")
.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
.set("spark.sql.execution.arrow.maxRecordsPerBatch", "50000")
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala b/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
index 29ddcb75d7..62a6d5a8e3 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
@@ -1,3 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package feast.ingestion.metrics
import java.net.{DatagramPacket, DatagramSocket, SocketTimeoutException}
@@ -12,11 +28,11 @@ class StatsDStub {
def receive: Array[String] = {
val messages: ArrayBuffer[String] = ArrayBuffer()
- var finished = false
+ var finished = false
do {
val buf = new Array[Byte](65535)
- val p = new DatagramPacket(buf, buf.length)
+ val p = new DatagramPacket(buf, buf.length)
try {
socket.receive(p)
} catch {
@@ -28,4 +44,18 @@ class StatsDStub {
messages.toArray
}
+
+ private val metricLine = """(.+):(.+)\|(.+)#(.+)""".r
+
+ def receivedMetrics: Map[String, Float] = {
+ receive
+ .flatMap {
+ case metricLine(name, value, type_, tags) =>
+ Seq(name -> value.toFloat)
+ case s: String =>
+ Seq()
+ }
+ .groupBy(_._1)
+ .mapValues(_.map(_._2).sum)
+ }
}
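
For reference, `receivedMetrics` collapses the raw StatsD datagrams into a name-to-sum map. A rough, self-contained illustration follows; the exact wire format of the tagged lines is an assumption here, not taken from the PR:

```scala
// Assumed line format, roughly: <metric name>:<value>|<type>|#<tags>
object StatsDParseDemo extends App {
  val metricLine = """(.+):(.+)\|(.+)#(.+)""".r

  val datagrams = Seq(
    "driver.ingestion_pipeline.read_from_source_count:10|c|#job_id=test",
    "driver.ingestion_pipeline.read_from_source_count:5|c|#job_id=test",
    "not a metric line" // silently dropped, like in the stub
  )

  val parsed: Map[String, Float] = datagrams
    .flatMap {
      case metricLine(name, value, _, _) => Seq(name -> value.toFloat)
      case _                             => Seq.empty[(String, Float)]
    }
    .groupBy(_._1)
    .mapValues(_.map(_._2).sum)
    .toMap

  println(parsed) // Map(driver.ingestion_pipeline.read_from_source_count -> 15.0)
}
```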