diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml
index cda054efe9..0e0bf8385c 100644
--- a/spark/ingestion/pom.xml
+++ b/spark/ingestion/pom.xml
@@ -189,14 +189,14 @@
     <dependency>
       <groupId>com.dimafeng</groupId>
       <artifactId>testcontainers-scala-scalatest_${scala.version}</artifactId>
-      <version>0.38.6</version>
+      <version>0.38.8</version>
       <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>com.dimafeng</groupId>
       <artifactId>testcontainers-scala-kafka_${scala.version}</artifactId>
-      <version>0.38.6</version>
+      <version>0.38.8</version>
       <scope>test</scope>
     </dependency>
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
index 3b8675f8ec..14e7356943 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
@@ -39,6 +39,7 @@ object BatchPipeline extends BasePipeline {
 
     val projection = inputProjection(config.source, featureTable.features, featureTable.entities)
     val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+    val metrics = new IngestionPipelineMetrics
 
     val input = config.source match {
       case source: BQSource =>
@@ -68,7 +69,7 @@ object BatchPipeline extends BasePipeline {
     }
 
     val validRows = projected
-      .mapPartitions(IngestionPipelineMetrics.incrementRead)
+      .mapPartitions(metrics.incrementRead)
       .filter(rowValidator.allChecks)
 
     validRows.write
@@ -84,7 +85,7 @@ object BatchPipeline extends BasePipeline {
       case Some(path) =>
         projected
           .filter(!rowValidator.allChecks)
-          .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+          .mapPartitions(metrics.incrementDeadLetters)
          .write
          .format("parquet")
          .mode(SaveMode.Append)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
index c6ed87df85..974c7d921d 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
@@ -56,8 +56,8 @@ object StreamingPipeline extends BasePipeline with Serializable {
 
     val featureTable = config.featureTable
     val projection = inputProjection(config.source, featureTable.features, featureTable.entities)
-    val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
-
+    val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
+    val metrics = new IngestionPipelineMetrics
     val validationUDF = createValidationUDF(sparkSession, config)
 
     val input = config.source match {
@@ -107,7 +107,7 @@ object StreamingPipeline extends BasePipeline with Serializable {
       implicit def rowEncoder: Encoder[Row] = RowEncoder(rowsAfterValidation.schema)
 
       rowsAfterValidation
-        .mapPartitions(IngestionPipelineMetrics.incrementRead)
+        .mapPartitions(metrics.incrementRead)
         .filter(if (config.doNotIngestInvalidRows) expr("_isValid") else rowValidator.allChecks)
         .write
         .format("feast.ingestion.stores.redis")
@@ -120,10 +120,9 @@ object StreamingPipeline extends BasePipeline with Serializable {
 
       config.deadLetterPath match {
         case Some(path) =>
-
           rowsAfterValidation
             .filter("!_isValid")
-            .mapPartitions(IngestionPipelineMetrics.incrementDeadletters)
+            .mapPartitions(metrics.incrementDeadLetters)
             .write
             .format("parquet")
             .mode(SaveMode.Append)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala b/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
index 060547e611..03d41eddd5 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/metrics/IngestionPipelineMetrics.scala
@@ -18,36 +18,39 @@ package feast.ingestion.metrics
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.metrics.source.IngestionPipelineMetricSource
+import org.apache.spark.sql.Row
 
-object IngestionPipelineMetrics {
-  def incrementDeadletters[A](rowIterator: Iterator[A]): Iterator[A] = {
+class IngestionPipelineMetrics extends Serializable {
+
+  def incrementDeadLetters(rowIterator: Iterator[Row]): Iterator[Row] = {
+    val materialized = rowIterator.toArray
     if (metricSource.nonEmpty)
-      metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(rowIterator.length)
+      metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(materialized.length)
 
-    rowIterator
+    materialized.toIterator
   }
 
-  def incrementRead[A](rowIterator: Iterator[A]): Iterator[A] = {
+  def incrementRead(rowIterator: Iterator[Row]): Iterator[Row] = {
+    val materialized = rowIterator.toArray
     if (metricSource.nonEmpty)
-      metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(rowIterator.length)
+      metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(materialized.length)
 
-    rowIterator
+    materialized.toIterator
   }
 
   private lazy val metricSource: Option[IngestionPipelineMetricSource] = {
-    this.synchronized {
-      if (
-        SparkEnv.get.metricsSystem
-          .getSourcesByName(IngestionPipelineMetricSource.sourceName)
-          .isEmpty
-      ) {
-        SparkEnv.get.metricsSystem.registerSource(new IngestionPipelineMetricSource)
+    val metricsSystem = SparkEnv.get.metricsSystem
+    IngestionPipelineMetricsLock.synchronized {
+      if (metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName).isEmpty) {
+        metricsSystem.registerSource(new IngestionPipelineMetricSource)
       }
     }
 
-    SparkEnv.get.metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
+    metricsSystem.getSourcesByName(IngestionPipelineMetricSource.sourceName) match {
       case Seq(head) => Some(head.asInstanceOf[IngestionPipelineMetricSource])
       case _ => None
     }
   }
 }
+
+private object IngestionPipelineMetricsLock
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala b/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
index 894014b6fd..880fc2b982 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/metrics/StatsdReporterWithTags.scala
@@ -125,8 +125,13 @@ class StatsdReporterWithTags(
   private def reportGauge(name: String, gauge: Gauge[_])(implicit socket: DatagramSocket): Unit =
     formatAny(gauge.getValue).foreach(v => send(fullName(name), v, GAUGE))
 
-  private def reportCounter(name: String, counter: Counter)(implicit socket: DatagramSocket): Unit =
-    send(fullName(name), format(counter.getCount), COUNTER)
+  private def reportCounter(name: String, counter: Counter)(implicit
+      socket: DatagramSocket
+  ): Unit = {
+    val snapshot = counter.getCount
+    send(fullName(name), format(snapshot), COUNTER)
+    counter.dec(snapshot) // reset counter
+  }
 
   private def reportHistogram(name: String, histogram: Histogram)(implicit
       socket: DatagramSocket
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
index f3054c0423..0e9a96eb42 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
@@ -23,7 +23,7 @@ import collection.JavaConverters._
 import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer}
 import com.google.protobuf.util.Timestamps
 import feast.proto.types.ValueProto.ValueType
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
 import org.joda.time.{DateTime, Seconds}
 import org.scalacheck._
 import org.scalatest._
@@ -46,17 +46,20 @@ case class TestRow(
 class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
 
   override val container = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379))
+  val statsDStub = new StatsDStub
 
   override def withSparkConfOverrides(conf: SparkConf): SparkConf = conf
     .set("spark.redis.host", container.host)
     .set("spark.redis.port", container.mappedPort(6379).toString)
+    .set("spark.metrics.conf.*.sink.statsd.port", statsDStub.port.toString)
 
   trait Scope {
     val jedis = new Jedis("localhost", container.mappedPort(6379))
     jedis.flushAll()
 
+    statsDStub.receivedMetrics // clean the buffer
+
     implicit def testRowEncoder: Encoder[TestRow] = ExpressionEncoder()
-    val statsDStub = new StatsDStub
 
     def rowGenerator(start: DateTime, end: DateTime, customerGen: Option[Gen[String]] = None) =
       for {
@@ -64,7 +67,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
         feature1 <- Gen.choose(0, 100)
         feature2 <- Gen.choose[Float](0, 1)
         eventTimestamp <- Gen
-          .choose(0, Seconds.secondsBetween(start, end).getSeconds)
+          .choose(0, Seconds.secondsBetween(start, end).getSeconds - 1)
           .map(start.withMillisOfSecond(0).plusSeconds)
       } yield TestRow(
         customer,
@@ -98,7 +101,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
       ),
       startTime = DateTime.parse("2020-08-01"),
       endTime = DateTime.parse("2020-09-01"),
-      metrics = Some(StatsDConfig(host="localhost", port=statsDStub.port))
+      metrics = Some(StatsDConfig(host = "localhost", port = statsDStub.port))
     )
   }
 
@@ -129,6 +132,14 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
 
       keyTTL shouldEqual -1
     })
+
+    SparkEnv.get.metricsSystem.report()
+    statsDStub.receivedMetrics should contain.allElementsOf(
+      Map(
+        "driver.ingestion_pipeline.read_from_source_count" -> rows.length,
+        "driver.redis_sink.feature_row_ingested_count" -> rows.length
+      )
+    )
   }
 
   "Parquet source file" should "be ingested in redis with expiry time equal to the largest of (event_timestamp + max_age) for" +
@@ -466,6 +477,13 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
           .toString
       )
       .count() should be(rows.length)
+
+    SparkEnv.get.metricsSystem.report()
+    statsDStub.receivedMetrics should contain.allElementsOf(
+      Map(
+        "driver.ingestion_pipeline.deadletter_count" -> rows.length
+      )
+    )
   }
 
   "Columns from source" should "be mapped according to configuration" in new Scope {
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala b/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
index 1a03073334..025a3b8be1 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala
@@ -36,7 +36,10 @@ class SparkSpec extends UnitSpec with BeforeAndAfter {
       "org.apache.spark.metrics.sink.StatsdSinkWithTags"
     )
     .set("spark.metrics.conf.*.sink.statsd.host", "localhost")
-    .set("spark.metrics.conf.*.sink.statsd.port", "8125")
+    .set("spark.metrics.conf.*.sink.statsd.period", "999") // disable scheduled reporting
+    .set("spark.metrics.conf.*.sink.statsd.unit", "minutes")
+    .set("spark.metrics.labels", "job_id=test")
+    .set("spark.metrics.namespace", "")
     .set("spark.sql.legacy.allowUntypedScalaUDF", "true")
     .set("spark.sql.execution.arrow.maxRecordsPerBatch", "50000")
 
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala b/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
index 29ddcb75d7..62a6d5a8e3 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/metrics/StatsDStub.scala
@@ -1,3 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package feast.ingestion.metrics
 
 import java.net.{DatagramPacket, DatagramSocket, SocketTimeoutException}
@@ -12,11 +28,11 @@ class StatsDStub {
 
   def receive: Array[String] = {
     val messages: ArrayBuffer[String] = ArrayBuffer()
-    var finished = false
+    var finished                      = false
 
     do {
       val buf = new Array[Byte](65535)
-      val p = new DatagramPacket(buf, buf.length)
+      val p   = new DatagramPacket(buf, buf.length)
 
       try {
        socket.receive(p)
      } catch {
@@ -28,4 +44,18 @@
 
     messages.toArray
   }
+
+  private val metricLine = """(.+):(.+)\|(.+)#(.+)""".r
+
+  def receivedMetrics: Map[String, Float] = {
+    receive
+      .flatMap {
+        case metricLine(name, value, type_, tags) =>
+          Seq(name -> value.toFloat)
+        case s: String =>
+          Seq()
+      }
+      .groupBy(_._1)
+      .mapValues(_.map(_._2).sum)
+  }
 }