diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml index d78984837c..e055006964 100644 --- a/spark/ingestion/pom.xml +++ b/spark/ingestion/pom.xml @@ -182,14 +182,14 @@ com.dimafeng testcontainers-scala-scalatest_${scala.version} - 0.38.3 + 0.38.6 test com.dimafeng testcontainers-scala-kafka_${scala.version} - 0.38.3 + 0.38.6 test diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala index a54c83140f..1348914b86 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala @@ -72,6 +72,7 @@ object BatchPipeline extends BasePipeline { .option("namespace", featureTable.name) .option("project_name", featureTable.project) .option("timestamp_column", config.source.eventTimestampColumn) + .option("max_age", config.featureTable.maxAge.getOrElse(0)) .save() config.deadLetterPath match { diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala index ce2a1652e8..2c0c391843 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala @@ -17,7 +17,6 @@ package feast.ingestion import org.joda.time.DateTime - import org.json4s._ import org.json4s.jackson.JsonMethods.{parse => parseJSON} import org.json4s.ext.JavaEnumNameSerializer @@ -29,7 +28,7 @@ object IngestionJob { new JavaEnumNameSerializer[feast.proto.types.ValueProto.ValueType.Enum]() + ShortTypeHints(List(classOf[ProtoFormat], classOf[AvroFormat])) - val parser = new scopt.OptionParser[IngestionJobConfig]("IngestionJon") { + val parser = new scopt.OptionParser[IngestionJobConfig]("IngestionJob") { // ToDo: read version from Manifest head("feast.ingestion.IngestionJob", "0.9.0-SNAPSHOT") diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala index c922a1c096..8b0ae25d7a 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala @@ -91,7 +91,8 @@ case class FeatureTable( name: String, project: String, entities: Seq[Field], - features: Seq[Field] + features: Seq[Field], + maxAge: Option[Int] = None ) case class IngestionJobConfig( diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala index 1945d4aa0f..99d5e66c88 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala @@ -87,6 +87,7 @@ object StreamingPipeline extends BasePipeline with Serializable { .option("namespace", featureTable.name) .option("project_name", featureTable.project) .option("timestamp_column", config.source.eventTimestampColumn) + .option("max_age", config.featureTable.maxAge.getOrElse(0)) .save() config.deadLetterPath match { diff --git a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/HashTypePersistence.scala b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/HashTypePersistence.scala index b34f0667c0..00ab873630 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/HashTypePersistence.scala +++ 
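// Editor sketch (not part of the diff): maxAge is declared as Option[Int] = None, so job
// configs that omit it keep deserializing, and both pipelines forward it to the Redis
// sink as a plain writer option with 0 as the "no expiry" sentinel. Assuming a prepared
// `validRows` DataFrame and a parsed `config` (the format name is assumed here):
validRows.write
  .format("feast.ingestion.stores.redis")
  .option("namespace", config.featureTable.name)
  .option("project_name", config.featureTable.project)
  .option("timestamp_column", config.source.eventTimestampColumn)
  .option("max_age", config.featureTable.maxAge.getOrElse(0)) // seconds; 0 => never expire
  .save()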
b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/HashTypePersistence.scala @@ -16,16 +16,17 @@ */ package feast.ingestion.stores.redis -import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ -import redis.clients.jedis.{Pipeline, Response} import java.nio.charset.StandardCharsets +import java.util import com.google.common.hash.Hashing - -import scala.jdk.CollectionConverters._ import com.google.protobuf.Timestamp import feast.ingestion.utils.TypeConversion +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import redis.clients.jedis.{Pipeline, Response} + +import scala.jdk.CollectionConverters._ /** * Use Redis hash type as storage layout. Every feature is stored as separate entry in Hash. @@ -35,10 +36,10 @@ import feast.ingestion.utils.TypeConversion * Values are serialized with protobuf (`ValueProto`). */ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Serializable { - def encodeRow( - keyColumns: Array[String], - timestampField: String, - value: Row + + private def encodeRow( + value: Row, + maxExpiryTimestamp: java.sql.Timestamp ): Map[Array[Byte], Array[Byte]] = { val fields = value.schema.fields.map(_.name) val types = value.schema.fields.map(f => (f.name, f.dataType)).toMap @@ -51,49 +52,87 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser } .filter { case (k, _) => // don't store entities & timestamp - !keyColumns.contains(k) && k != config.timestampColumn + !config.entityColumns.contains(k) && k != config.timestampColumn } .map { case (k, v) => encodeKey(k) -> encodeValue(v, types(k)) } - val timestamp = Seq( + val timestampHash = Seq( ( - timestampField.getBytes, + timestampHashKey(config.namespace).getBytes, encodeValue(value.getAs[Timestamp](config.timestampColumn), TimestampType) ) ) - values ++ timestamp + val expiryUnixTimestamp = { + if (config.maxAge > 0) + value.getAs[java.sql.Timestamp](config.timestampColumn).getTime + config.maxAge * 1000 + else maxExpiryTimestamp.getTime + } + val expiryTimestamp = new java.sql.Timestamp(expiryUnixTimestamp) + val expiryTimestampHash = Seq( + ( + expiryTimestampHashKey(config.namespace).getBytes, + encodeValue(expiryTimestamp, TimestampType) + ) + ) + + values ++ timestampHash ++ expiryTimestampHash } - def encodeValue(value: Any, `type`: DataType): Array[Byte] = { + private def encodeValue(value: Any, `type`: DataType): Array[Byte] = { TypeConversion.sqlTypeToProtoValue(value, `type`).toByteArray } - def encodeKey(key: String): Array[Byte] = { + private def encodeKey(key: String): Array[Byte] = { val fullFeatureReference = s"${config.namespace}:$key" Hashing.murmur3_32.hashString(fullFeatureReference, StandardCharsets.UTF_8).asBytes() } - def save( + private def timestampHashKey(namespace: String): String = { + s"${config.timestampPrefix}:${namespace}" + } + + private def expiryTimestampHashKey(namespace: String): String = { + s"${config.expiryPrefix}:${namespace}" + } + + private def decodeTimestamp(encodedTimestamp: Array[Byte]): java.sql.Timestamp = { + new java.sql.Timestamp(Timestamp.parseFrom(encodedTimestamp).getSeconds * 1000) + } + + override def save( pipeline: Pipeline, key: Array[Byte], - value: Map[Array[Byte], Array[Byte]], - ttl: Int + row: Row, + expiryTimestamp: java.sql.Timestamp, + maxExpiryTimestamp: java.sql.Timestamp ): Unit = { - pipeline.hset(key, value.asJava) - if (ttl > 0) { - pipeline.expire(key, ttl) + val value = encodeRow(row, maxExpiryTimestamp).asJava + pipeline.hset(key, value) + if 
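// Editor sketch (not part of the diff): the expiry stored above under "_ex:<namespace>"
// is simply event timestamp + maxAge, falling back to the caller-supplied sentinel when
// no max age is configured. maxAge is in seconds, java.sql.Timestamp.getTime in millis:
def expiryFor(eventTime: java.sql.Timestamp,
              maxAge: Int,
              maxExpiryTimestamp: java.sql.Timestamp): java.sql.Timestamp =
  if (maxAge > 0) new java.sql.Timestamp(eventTime.getTime + maxAge * 1000L)
  else maxExpiryTimestamp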
(expiryTimestamp.equals(maxExpiryTimestamp)) { + pipeline.persist(key) + } else { + pipeline.expireAt(key, expiryTimestamp.getTime / 1000) } } - def getTimestamp( + override def get( pipeline: Pipeline, - key: Array[Byte], - timestampField: String - ): Response[Array[Byte]] = { - pipeline.hget(key, timestampField.getBytes) + key: Array[Byte] + ): Response[util.Map[Array[Byte], Array[Byte]]] = { + pipeline.hgetAll(key) } + override def storedTimestamp( + value: util.Map[Array[Byte], Array[Byte]] + ): Option[java.sql.Timestamp] = { + value.asScala.toMap + .map { case (key, value) => + (key.map(_.toChar).mkString, value) + } + .get(timestampHashKey(config.namespace)) + .map(value => decodeTimestamp(value)) + } } diff --git a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/Persistence.scala b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/Persistence.scala index 47161358c2..4c4b1690c0 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/Persistence.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/Persistence.scala @@ -16,26 +16,58 @@ */ package feast.ingestion.stores.redis +import java.sql.Timestamp +import java.util + import org.apache.spark.sql.Row import redis.clients.jedis.{Pipeline, Response} +/** + * Determine how a Spark row should be serialized and stored on Redis. + */ trait Persistence { - def encodeRow( - keyColumns: Array[String], - timestampField: String, - value: Row - ): Map[Array[Byte], Array[Byte]] + /** + * Persist a Spark row to Redis + * + * @param pipeline Redis pipeline + * @param key Redis key in serialized bytes format + * @param row Row representing the value to be persist + * @param expiryTimestamp Expiry timestamp for the row + * @param maxExpiryTimestamp No ttl should be set if the expiry timestamp + * is equal to the maxExpiryTimestamp + */ def save( pipeline: Pipeline, key: Array[Byte], - value: Map[Array[Byte], Array[Byte]], - ttl: Int + row: Row, + expiryTimestamp: Timestamp, + maxExpiryTimestamp: Timestamp ): Unit - def getTimestamp( + /** + * Returns a Redis response, which can be used by `storedTimestamp` and `newExpiryTimestamp` to + * derive the currently stored event timestamp, and the updated expiry timestamp. This method will + * be called prior to persisting the row to Redis, so that `RedisSinkRelation` can decide whether + * the currently stored value should be updated. + * + * @param pipeline Redis pipeline + * @param key Redis key in serialized bytes format + * @return Redis response representing the row value + */ + def get( pipeline: Pipeline, - key: Array[Byte], - timestampField: String - ): Response[Array[Byte]] + key: Array[Byte] + ): Response[util.Map[Array[Byte], Array[Byte]]] + + /** + * Returns the currently stored event timestamp for the key and the feature table associated with the ingestion job. + * + * @param value Response returned from `get` + * @return Stored event timestamp associated with the key. Returns `None` if + * the key is not present in Redis, or if timestamp information is + * unavailable on the stored value. 
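// Editor sketch: the intended use of this contract, compressed from RedisSinkRelation
// further down. `get` responses are resolved for a whole batch of keys first; then for
// each row the stored event timestamp decides whether `save` runs at all. All names
// below (persistence, storedValue, pipeline, key, row, config, expiry timestamps) are
// assumed to be in scope:
val storedTs: Option[java.sql.Timestamp] = persistence.storedTimestamp(storedValue)
val rowTs = row.getAs[java.sql.Timestamp](config.timestampColumn)
if (!storedTs.exists(_.after(rowTs)))  // write unless the stored row is strictly newer
  persistence.save(pipeline, key, row, expiryTimestamp, maxExpiryTimestamp)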
+ */ + def storedTimestamp(value: util.Map[Array[Byte], Array[Byte]]): Option[Timestamp] + } diff --git a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/RedisSinkRelation.scala b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/RedisSinkRelation.scala index d880a6461c..9be82fbc33 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/RedisSinkRelation.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/RedisSinkRelation.scala @@ -16,21 +16,24 @@ */ package feast.ingestion.stores.redis +import java.util + import com.google.protobuf.Timestamp -import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisEndpoint, RedisNode} -import redis.clients.jedis.util.JedisClusterCRC16 +import com.google.protobuf.util.Timestamps import com.redislabs.provider.redis.util.PipelineUtils.{foreachWithPipeline, mapWithPipeline} +import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisEndpoint, RedisNode} import feast.ingestion.utils.TypeConversion +import feast.proto.storage.RedisProto.RedisKeyV2 +import feast.proto.types.ValueProto import org.apache.spark.SparkEnv import org.apache.spark.metrics.source.RedisSinkMetricSource +import org.apache.spark.sql.functions.col import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.functions.col import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import redis.clients.jedis.util.JedisClusterCRC16 -import collection.JavaConverters._ -import feast.proto.storage.RedisProto.RedisKeyV2 -import feast.proto.types.ValueProto +import scala.collection.JavaConverters._ /** * High-level writer to Redis. Relies on `Persistence` implementation for actual storage layout. @@ -57,6 +60,8 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC override def schema: StructType = ??? 
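// Editor sketch: the sentinel declared just below is the protobuf "maximum" timestamp
// truncated to millis; rows written with this expiry make the sink call PERSIST rather
// than EXPIREAT, i.e. the key never expires and jedis.ttl(key) reports -1 (as the tests
// further down assert):
import com.google.protobuf.util.Timestamps

val noExpirySentinel = new java.sql.Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000)
// Timestamps.MAX_VALUE corresponds to 9999-12-31T23:59:59Z at second precision.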
+ val MAX_EXPIRED_TIMESTAMP = new java.sql.Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000) + val persistence: Persistence = new HashTypePersistence(config) override def insert(data: DataFrame, overwrite: Boolean): Unit = { @@ -75,27 +80,27 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC groupKeysByNode(redisConfig.hosts, rowsWithKey.keysIterator).foreach { case (node, keys) => val conn = node.connect() - // retrieve latest stored timestamp per key - val timestamps = mapWithPipeline(conn, keys) { (pipeline, key) => - persistence.getTimestamp(pipeline, key.toByteArray, timestampField) - } - - val timestampByKey = timestamps - .map(_.asInstanceOf[Array[Byte]]) - .map( - Option(_) - .map(Timestamp.parseFrom) - .map(t => new java.sql.Timestamp(t.getSeconds * 1000)) - ) - .zip(keys) - .map(_.swap) + // retrieve latest stored values + val storedValues = mapWithPipeline(conn, keys) { (pipeline, key) => + persistence.get(pipeline, key.toByteArray) + }.map(_.asInstanceOf[util.Map[Array[Byte], Array[Byte]]]) + + val timestamps = storedValues.map(persistence.storedTimestamp) + val timestampByKey = keys.zip(timestamps).toMap + + val expiryTimestampByKey = keys + .zip(storedValues) + .map { case (key, storedValue) => + (key, newExpiryTimestamp(rowsWithKey(key), storedValue)) + } .toMap foreachWithPipeline(conn, keys) { (pipeline, key) => val row = rowsWithKey(key) timestampByKey(key) match { - case Some(t) if !t.before(row.getAs[java.sql.Timestamp](config.timestampColumn)) => () + case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) => + () case _ => if (metricSource.nonEmpty) { val lag = System.currentTimeMillis() - row @@ -105,9 +110,13 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc() metricSource.get.METRIC_ROWS_LAG.update(lag) } - - val encodedRow = persistence.encodeRow(config.entityColumns, timestampField, row) - persistence.save(pipeline, key.toByteArray, encodedRow, ttl = 0) + persistence.save( + pipeline, + key.toByteArray, + row, + expiryTimestampByKey(key), + MAX_EXPIRED_TIMESTAMP + ) } } conn.close() @@ -142,10 +151,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC .build } - private def timestampField: String = { - s"${config.timestampPrefix}:${config.namespace}" - } - private lazy val metricSource: Option[RedisSinkMetricSource] = SparkEnv.get.metricsSystem.getSourcesByName(RedisSinkMetricSource.sourceName) match { case Seq(head) => Some(head.asInstanceOf[RedisSinkMetricSource]) @@ -169,4 +174,31 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC nodes.filter { node => node.startSlot <= slot && node.endSlot >= slot }.filter(_.idx == 0)(0) } + + private def newExpiryTimestamp( + row: Row, + value: util.Map[Array[Byte], Array[Byte]] + ): java.sql.Timestamp = { + val maxExpiryOtherFeatureTables: Long = value.asScala.toMap + .map { case (key, value) => + (key.map(_.toChar).mkString, value) + } + .filterKeys(_.startsWith(config.expiryPrefix)) + .filterKeys(_.split(":").last != config.namespace) + .values + .map(value => Timestamp.parseFrom(value).getSeconds * 1000) + .reduceOption(_ max _) + .getOrElse(0) + + val rowExpiry: Long = + if (config.maxAge > 0) + (row + .getAs[java.sql.Timestamp](config.timestampColumn) + .getTime + config.maxAge * 1000) + else MAX_EXPIRED_TIMESTAMP.getTime + + val maxExpiry = maxExpiryOtherFeatureTables max rowExpiry + new 
java.sql.Timestamp(maxExpiry) + + } } diff --git a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/SparkRedisConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/SparkRedisConfig.scala index 389607ce99..cac12a6c27 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/SparkRedisConfig.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/stores/redis/SparkRedisConfig.scala @@ -23,7 +23,9 @@ case class SparkRedisConfig( timestampColumn: String, iteratorGroupingSize: Int = 1000, timestampPrefix: String = "_ts", - repartitionByEntity: Boolean = true + repartitionByEntity: Boolean = true, + maxAge: Int = 0, + expiryPrefix: String = "_ex" ) object SparkRedisConfig { @@ -32,6 +34,7 @@ object SparkRedisConfig { val TS_COLUMN = "timestamp_column" val ENTITY_REPARTITION = "entity_repartition" val PROJECT_NAME = "project_name" + val MAX_AGE = "max_age" def parse(parameters: Map[String, String]): SparkRedisConfig = SparkRedisConfig( @@ -39,6 +42,7 @@ object SparkRedisConfig { projectName = parameters.getOrElse(PROJECT_NAME, "default"), entityColumns = parameters.getOrElse(ENTITY_COLUMNS, "").split(","), timestampColumn = parameters.getOrElse(TS_COLUMN, "event_timestamp"), - repartitionByEntity = parameters.getOrElse(ENTITY_REPARTITION, "true") == "true" + repartitionByEntity = parameters.getOrElse(ENTITY_REPARTITION, "true") == "true", + maxAge = parameters.get(MAX_AGE).map(_.toInt).getOrElse(0) ) } diff --git a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala index 6ccfe9ee34..70f8a1f718 100644 --- a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala +++ b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala @@ -17,9 +17,11 @@ package feast.ingestion import java.nio.file.Paths +import java.sql.Timestamp import collection.JavaConverters._ import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer} +import com.google.protobuf.util.Timestamps import feast.proto.types.ValueProto.ValueType import org.apache.spark.SparkConf import org.joda.time.{DateTime, Seconds} @@ -106,14 +108,253 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer { val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable) rows.foreach(r => { - val storedValues = jedis.hgetAll(encodeEntityKey(r, config.featureTable)).asScala.toMap + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap storedValues should beStoredRow( Map( featureKeyEncoder("feature1") -> r.feature1, featureKeyEncoder("feature2") -> r.feature2, - "_ts:test-fs" -> r.eventTimestamp + "_ts:test-fs" -> r.eventTimestamp, + "_ex:test-fs" -> new Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000) + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toInt + keyTTL shouldEqual -1 + + }) + } + + "Parquet source file" should "be ingested in redis with expiry time equal to the largest of (event_timestamp + max_age) for" + + "all feature tables associated with the entity" in new Scope { + val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay() + val endDate = new DateTime().withTimeAtStartOfDay() + val gen = rowGenerator(startDate, endDate) + val rows = generateDistinctRows(gen, 1000, groupByEntity) + val tempPath = storeAsParquet(sparkSession, rows) + val maxAge = 86400 * 2 + val configWithMaxAge = config.copy( + source = FileSource(tempPath, Map.empty, 
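// Editor sketch: the writer options land in SparkRedisConfig.parse above; "max_age"
// arrives as a string and a missing key keeps the old "no TTL" behaviour. Roughly (the
// "entity_columns" key name and the entity/timestamp values are illustrative only):
val sinkConfig = SparkRedisConfig.parse(
  Map(
    "namespace"        -> "test-fs",
    "project_name"     -> "default",
    "entity_columns"   -> "customer",
    "timestamp_column" -> "eventTimestamp",
    "max_age"          -> "172800" // 2 days, in seconds
  )
)
// sinkConfig.maxAge == 172800; sinkConfig.timestampPrefix == "_ts"; sinkConfig.expiryPrefix == "_ex"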
"eventTimestamp"), + featureTable = config.featureTable.copy(maxAge = Some(maxAge)), + startTime = startDate, + endTime = endDate + ) + + val ingestionTimeUnix = System.currentTimeMillis() + BatchPipeline.createPipeline(sparkSession, configWithMaxAge) + + val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable) + + rows.foreach(r => { + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + val expectedExpiryTimestamp = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * maxAge) + storedValues should beStoredRow( + Map( + featureKeyEncoder("feature1") -> r.feature1, + featureKeyEncoder("feature2") -> r.feature2, + "_ts:test-fs" -> r.eventTimestamp, + "_ex:test-fs" -> expectedExpiryTimestamp + ) ) + val keyTTL = jedis.ttl(encodedEntityKey).toLong + keyTTL should (be <= (expectedExpiryTimestamp.getTime - ingestionTimeUnix) / 1000 and be > 0L) + + }) + + val increasedMaxAge = 86400 * 3 + val configWithSecondFeatureTable = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy( + name = "test-fs-2", + maxAge = Some(increasedMaxAge) + ), + startTime = startDate, + endTime = endDate ) + + val secondIngestionTimeUnix = System.currentTimeMillis() + BatchPipeline.createPipeline(sparkSession, configWithSecondFeatureTable) + + val featureKeyEncoderSecondTable: String => String = + encodeFeatureKey(configWithSecondFeatureTable.featureTable) + + rows.foreach(r => { + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + val expectedExpiryTimestamp1 = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * maxAge) + val expectedExpiryTimestamp2 = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * increasedMaxAge) + storedValues should beStoredRow( + Map( + featureKeyEncoder("feature1") -> r.feature1, + featureKeyEncoder("feature2") -> r.feature2, + featureKeyEncoderSecondTable("feature1") -> r.feature1, + featureKeyEncoderSecondTable("feature2") -> r.feature2, + "_ts:test-fs" -> r.eventTimestamp, + "_ts:test-fs-2" -> r.eventTimestamp, + "_ex:test-fs" -> expectedExpiryTimestamp1, + "_ex:test-fs-2" -> expectedExpiryTimestamp2 + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toLong + keyTTL should (be <= (expectedExpiryTimestamp2.getTime - secondIngestionTimeUnix) / 1000 and be > (expectedExpiryTimestamp1.getTime - secondIngestionTimeUnix) / 1000) + + }) + } + + "Redis key TTL" should "not be updated, when a second feature table associated with the same entity is registered and ingested, if (event_timestamp + max_age) of the second " + + "Feature Table is not later than the expiry timestamp of the first feature table" in new Scope { + val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay() + val endDate = new DateTime().withTimeAtStartOfDay() + val gen = rowGenerator(startDate, endDate) + val rows = generateDistinctRows(gen, 1000, groupByEntity) + val tempPath = storeAsParquet(sparkSession, rows) + val maxAge = 86400 * 3 + val configWithMaxAge = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy(maxAge = Some(maxAge)), + startTime = startDate, + endTime = endDate + ) + + val ingestionTimeUnix = System.currentTimeMillis() + BatchPipeline.createPipeline(sparkSession, configWithMaxAge) + + val reducedMaxAge = 86400 * 2 + val configWithSecondFeatureTable = config.copy( + source = 
FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy( + name = "test-fs-2", + maxAge = Some(reducedMaxAge) + ), + startTime = startDate, + endTime = endDate + ) + + BatchPipeline.createPipeline(sparkSession, configWithSecondFeatureTable) + + val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable) + val featureKeyEncoderSecondTable: String => String = + encodeFeatureKey(configWithSecondFeatureTable.featureTable) + + rows.foreach(r => { + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + val expectedExpiryTimestamp1 = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * maxAge) + val expectedExpiryTimestamp2 = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * reducedMaxAge) + storedValues should beStoredRow( + Map( + featureKeyEncoder("feature1") -> r.feature1, + featureKeyEncoder("feature2") -> r.feature2, + featureKeyEncoderSecondTable("feature1") -> r.feature1, + featureKeyEncoderSecondTable("feature2") -> r.feature2, + "_ts:test-fs" -> r.eventTimestamp, + "_ts:test-fs-2" -> r.eventTimestamp, + "_ex:test-fs" -> expectedExpiryTimestamp1, + "_ex:test-fs-2" -> expectedExpiryTimestamp2 + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toLong + keyTTL should (be <= (expectedExpiryTimestamp1.getTime - ingestionTimeUnix) / 1000 and + be > (expectedExpiryTimestamp2.getTime - ingestionTimeUnix) / 1000) + + }) + } + + "Redis key TTL" should "be updated, when the same feature table is re-ingested, with a smaller max age" in new Scope { + val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay() + val endDate = new DateTime().withTimeAtStartOfDay() + val gen = rowGenerator(startDate, endDate) + val rows = generateDistinctRows(gen, 1000, groupByEntity) + val tempPath = storeAsParquet(sparkSession, rows) + val maxAge = 86400 * 3 + val configWithMaxAge = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy(maxAge = Some(maxAge)), + startTime = startDate, + endTime = endDate + ) + + val ingestionTimeUnix = System.currentTimeMillis() + BatchPipeline.createPipeline(sparkSession, configWithMaxAge) + + val reducedMaxAge = 86400 * 2 + val configWithUpdatedFeatureTable = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy( + maxAge = Some(reducedMaxAge) + ), + startTime = startDate, + endTime = endDate + ) + + BatchPipeline.createPipeline(sparkSession, configWithUpdatedFeatureTable) + + val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable) + + rows.foreach(r => { + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + val expiryTimestampAfterUpdate = + new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * reducedMaxAge) + storedValues should beStoredRow( + Map( + featureKeyEncoder("feature1") -> r.feature1, + featureKeyEncoder("feature2") -> r.feature2, + "_ts:test-fs" -> r.eventTimestamp, + "_ex:test-fs" -> expiryTimestampAfterUpdate + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toLong + keyTTL should (be <= (expiryTimestampAfterUpdate.getTime - ingestionTimeUnix) / 1000 and be > 0L) + + }) + } + + "Redis key TTL" should "be removed, when the same feature table is re-ingested without max age" in new Scope { + val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay() + val endDate = new 
DateTime().withTimeAtStartOfDay() + val gen = rowGenerator(startDate, endDate) + val rows = generateDistinctRows(gen, 1000, groupByEntity) + val tempPath = storeAsParquet(sparkSession, rows) + val maxAge = 86400 * 3 + val configWithMaxAge = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + featureTable = config.featureTable.copy(maxAge = Some(maxAge)), + startTime = startDate, + endTime = endDate + ) + + BatchPipeline.createPipeline(sparkSession, configWithMaxAge) + + val configWithoutMaxAge = config.copy( + source = FileSource(tempPath, Map.empty, "eventTimestamp"), + startTime = startDate, + endTime = endDate + ) + + BatchPipeline.createPipeline(sparkSession, configWithoutMaxAge) + + val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable) + + rows.foreach(r => { + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + storedValues should beStoredRow( + Map( + featureKeyEncoder("feature1") -> r.feature1, + featureKeyEncoder("feature2") -> r.feature2, + "_ts:test-fs" -> r.eventTimestamp, + "_ex:test-fs" -> new Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000) + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toInt + keyTTL shouldEqual -1 + }) } diff --git a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala index 39c44ad55a..20e41dc17b 100644 --- a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala +++ b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala @@ -17,6 +17,7 @@ package feast.ingestion import java.nio.file.Paths +import java.sql import java.util.Properties import com.dimafeng.testcontainers.{ @@ -30,6 +31,7 @@ import org.apache.spark.SparkConf import org.joda.time.DateTime import org.apache.kafka.clients.producer._ import com.example.protos.{AllTypesMessage, InnerMessage, TestMessage, VehicleType} +import com.google.protobuf.util.Timestamps import com.google.protobuf.{AbstractMessage, ByteString, Timestamp} import org.scalacheck.Gen import redis.clients.jedis.Jedis @@ -39,10 +41,8 @@ import feast.ingestion.helpers.RedisStorageHelper._ import feast.ingestion.helpers.DataHelper._ import feast.proto.storage.RedisProto.RedisKeyV2 import feast.proto.types.ValueProto -import org.apache.spark.sql.Row import org.apache.spark.sql.avro.to_avro import org.apache.spark.sql.functions.{col, struct} -import org.apache.spark.sql.types.StructType class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { val redisContainer = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379)) @@ -136,16 +136,86 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { query.processAllAvailable() rows.foreach { r => - val storedValues = jedis.hgetAll(encodeEntityKey(r, config.featureTable)).asScala.toMap + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap storedValues should beStoredRow( Map( featureKeyEncoder("unique_drivers") -> r.getUniqueDrivers, - "_ts:driver-fs" -> new java.sql.Timestamp(r.getEventTimestamp.getSeconds * 1000) + "_ts:driver-fs" -> new java.sql.Timestamp(r.getEventTimestamp.getSeconds * 1000), + "_ex:driver-fs" -> new java.sql.Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000) ) ) + val keyTTL = jedis.ttl(encodedEntityKey).toInt + keyTTL shouldEqual -1 } } + "Streaming pipeline" should "store messages from kafka to 
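// Editor sketch: in the streaming assertions below the event time is a protobuf
// Timestamp (whole seconds), so the expected "_ex:<namespace>" value is just
// (seconds + max_age) promoted to millis. Equivalent helper, names illustrative:
def expectedExpiry(eventTs: com.google.protobuf.Timestamp, maxAgeSeconds: Long): java.sql.Timestamp =
  new java.sql.Timestamp((eventTs.getSeconds + maxAgeSeconds) * 1000)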
redis with expiry time equal to the largest of (event_timestamp + max_age) for all feature " + + "tables associated with the entity" in new Scope { + val maxAge = 86400 + val configWithMaxAge = config.copy( + source = kafkaSource, + featureTable = config.featureTable.copy(maxAge = Some(maxAge)) + ) + val query = StreamingPipeline.createPipeline(sparkSession, configWithMaxAge).get + query.processAllAvailable() // to init kafka consumer + + val rows = generateDistinctRows(rowGenerator, 100, groupByEntity) + + val ingestionTimeUnix = System.currentTimeMillis() + rows.foreach(sendToKafka(kafkaSource.topic, _)) + + query.processAllAvailable() + + rows.foreach { r => + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + storedValues should beStoredRow( + Map( + featureKeyEncoder("unique_drivers") -> r.getUniqueDrivers, + "_ts:driver-fs" -> new java.sql.Timestamp(r.getEventTimestamp.getSeconds * 1000), + "_ex:driver-fs" -> new java.sql.Timestamp( + (r.getEventTimestamp.getSeconds + maxAge) * 1000 + ) + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toLong + keyTTL should (be <= (r.getEventTimestamp.getSeconds + maxAge - ingestionTimeUnix / 1000) and be > 0L) + } + + val kafkaSourceSecondFeatureTable = kafkaSource.copy(topic = "topic-2") + val configWithSecondFeatureTable = config.copy( + source = kafkaSourceSecondFeatureTable, + featureTable = config.featureTable.copy(name = "driver-fs-2") + ) + val querySecondFeatureTable = + StreamingPipeline.createPipeline(sparkSession, configWithSecondFeatureTable).get + querySecondFeatureTable.processAllAvailable() // to init kafka consumer + rows.foreach(sendToKafka(kafkaSourceSecondFeatureTable.topic, _)) + querySecondFeatureTable.processAllAvailable() + + val featureKeyEncoderSecondFeatureTable: String => String = + encodeFeatureKey(configWithSecondFeatureTable.featureTable) + rows.foreach { r => + val encodedEntityKey = encodeEntityKey(r, config.featureTable) + val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap + storedValues should beStoredRow( + Map( + featureKeyEncoder("unique_drivers") -> r.getUniqueDrivers, + featureKeyEncoderSecondFeatureTable("unique_drivers") -> r.getUniqueDrivers, + "_ts:driver-fs" -> new java.sql.Timestamp(r.getEventTimestamp.getSeconds * 1000), + "_ex:driver-fs" -> new java.sql.Timestamp( + (r.getEventTimestamp.getSeconds + maxAge) * 1000 + ), + "_ex:driver-fs-2" -> new java.sql.Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000) + ) + ) + val keyTTL = jedis.ttl(encodedEntityKey).toInt + keyTTL shouldEqual -1 + } + + } + "Streaming pipeline" should "store invalid proto messages to deadletter path" in new Scope { val configWithDeadletter = config.copy( source = kafkaSource, diff --git a/spark/ingestion/src/test/scala/feast/ingestion/helpers/RedisStorageHelper.scala b/spark/ingestion/src/test/scala/feast/ingestion/helpers/RedisStorageHelper.scala index 921d65d477..d15126d9d3 100644 --- a/spark/ingestion/src/test/scala/feast/ingestion/helpers/RedisStorageHelper.scala +++ b/spark/ingestion/src/test/scala/feast/ingestion/helpers/RedisStorageHelper.scala @@ -38,14 +38,15 @@ object RedisStorageHelper { m compose { (_: Map[Array[Byte], Array[Byte]]) - .map { case (k, v) => - if (k.length == 4) + .map { + case (k, v) if k.length == 4 => ( ByteBuffer.wrap(k).order(ByteOrder.LITTLE_ENDIAN).getInt.toHexString, ValueProto.Value.parseFrom(v).asScala ) - else + case (k, v) if k.startsWith("_ts".getBytes) || k.startsWith("_ex".getBytes) => (new 
String(k), Timestamp.parseFrom(v).asScala) + case (k, v) => (new String(k), ValueProto.Value.parseFrom(v).asScala) } } }
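// Editor sketch: the test helper above distinguishes hash fields purely by shape --
// murmur3_32 feature keys are exactly 4 bytes and decode as ValueProto.Value, while the
// bookkeeping fields are readable strings prefixed with "_ts"/"_ex" and decode as
// protobuf Timestamps. A standalone equivalent of that dispatch:
import com.google.protobuf.Timestamp
import feast.proto.types.ValueProto

def decodeField(key: Array[Byte], value: Array[Byte]): (String, Any) =
  if (key.length == 4)
    (java.nio.ByteBuffer.wrap(key).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt.toHexString,
     ValueProto.Value.parseFrom(value))
  else if (key.startsWith("_ts".getBytes) || key.startsWith("_ex".getBytes))
    (new String(key), Timestamp.parseFrom(value))
  else
    (new String(key), ValueProto.Value.parseFrom(value))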