From ba78d7c922d381edc995b7782723a6835d140fee Mon Sep 17 00:00:00 2001
From: Sari Nusier
Date: Thu, 26 Nov 2020 10:11:33 +0000
Subject: [PATCH 1/3] Added the option to set ttl from a column.

- If ttl.column is set, the connector will use the value in that column to set the ttl of the record.
- The ttl column will not be included in the fields of the record.
- The ability to use a constant ttl is still there, but ttl.column takes precedence when both are set.
---
 .../apache/spark/sql/redis/BinaryRedisPersistence.scala  | 2 +-
 .../apache/spark/sql/redis/HashRedisPersistence.scala    | 9 ++++++++-
 .../org/apache/spark/sql/redis/RedisPersistence.scala    | 3 ++-
 .../org/apache/spark/sql/redis/RedisSourceRelation.scala | 6 ++++--
 src/main/scala/org/apache/spark/sql/redis/redis.scala    | 1 +
 5 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
index c9b0a981..89503d80 100644
--- a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
@@ -25,7 +25,7 @@ class BinaryRedisPersistence extends RedisPersistence[Array[Byte]] {
   override def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit =
     pipeline.get(key.getBytes(UTF_8))
 
-  override def encodeRow(keyName: String, value: Row): Array[Byte] = {
+  override def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): Array[Byte] = {
     val fields = value.schema.fields.map(_.name)
     val valuesArray = fields.map(f => value.getAs[Any](f))
     SerializationUtils.serialize(valuesArray)
diff --git a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
index a22c4614..59e75da7 100644
--- a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
@@ -27,7 +27,7 @@ class HashRedisPersistence extends RedisPersistence[Any] {
     pipeline.hmget(key, requiredColumns: _*)
   }
 
-  override def encodeRow(keyName: String, value: Row): Map[String, String] = {
+  override def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): Map[String, String] = {
     val fields = value.schema.fields.map(_.name)
     val kvMap = value.getValuesMap[Any](fields)
     kvMap
@@ -39,6 +39,13 @@ class HashRedisPersistence extends RedisPersistence[Any] {
         // don't store key values
         k != keyName
       }
+      .filter { case (k, _) =>
+        // don't store TTLs
+        ttlName match {
+          case Some(ttl) => k != ttl
+          case None => true
+        }
+      }
       .map { case (k, v) =>
         k -> String.valueOf(v)
       }
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
index d69eef66..fab86855 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
@@ -18,9 +18,10 @@ trait RedisPersistence[T] extends Serializable {
    *
    * @param keyName field name that should be encoded in special way, e.g. in Redis keys.
    * @param value row to encode.
+   * @param ttlName field name to be used for setting the ttl and not added as a value
    * @return encoded row
    */
-  def encodeRow(keyName: String, value: Row): T
+  def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): T
 
   /**
    * Decode dataframe row stored in Redis.
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
index f2c84911..dd399647 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
@@ -78,6 +78,7 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
   private val persistence = RedisPersistence(persistenceModel)
   private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
   private val ttl = parameters.get(SqlOptionTTL).map(_.toInt).getOrElse(0)
+  private val ttlColumn: Option[String] = parameters.get(SqlOptionTTLColumn)
 
   /**
    * redis key pattern for rows, based either on the 'keys.pattern' or 'table' parameter
@@ -142,8 +143,9 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
       val conn = node.connect()
       foreachWithPipeline(conn, keys) { (pipeline, key) =>
         val row = rowsWithKey(key)
-        val encodedRow = persistence.encodeRow(keyName, row)
-        persistence.save(pipeline, key, encodedRow, ttl)
+        val encodedRow = persistence.encodeRow(keyName, row, ttlColumn)
+        val recordTTL = if (ttlColumn.isEmpty) ttl else row.getAs[Int](ttlColumn.get)
+        persistence.save(pipeline, key, encodedRow, recordTTL)
       }
       conn.close()
     }
diff --git a/src/main/scala/org/apache/spark/sql/redis/redis.scala b/src/main/scala/org/apache/spark/sql/redis/redis.scala
index 82b1c1d7..8499fd5d 100644
--- a/src/main/scala/org/apache/spark/sql/redis/redis.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/redis.scala
@@ -21,6 +21,7 @@ package object redis {
   val SqlOptionInferSchema = "infer.schema"
   val SqlOptionKeyColumn = "key.column"
   val SqlOptionTTL = "ttl"
+  val SqlOptionTTLColumn = "ttl.column"
   val SqlOptionMaxPipelineSize = "max.pipeline.size"
   val SqlOptionScanCount = "scan.count"
 

From 97bb7f9b6c019050d24f3d4291e1cf6ed3390849 Mon Sep 17 00:00:00 2001
From: Sari Nusier
Date: Fri, 27 Nov 2020 12:20:34 +0000
Subject: [PATCH 2/3] Renaming ttl column

---
 .../org/apache/spark/sql/redis/BinaryRedisPersistence.scala | 2 +-
 .../org/apache/spark/sql/redis/HashRedisPersistence.scala   | 4 ++--
 .../scala/org/apache/spark/sql/redis/RedisPersistence.scala | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
index 89503d80..2f666693 100644
--- a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
@@ -25,7 +25,7 @@ class BinaryRedisPersistence extends RedisPersistence[Array[Byte]] {
   override def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit =
     pipeline.get(key.getBytes(UTF_8))
 
-  override def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): Array[Byte] = {
+  override def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): Array[Byte] = {
     val fields = value.schema.fields.map(_.name)
     val valuesArray = fields.map(f => value.getAs[Any](f))
     SerializationUtils.serialize(valuesArray)
diff --git a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
index 59e75da7..6743c791 100644
--- a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
@@ -27,7 +27,7 @@ class HashRedisPersistence extends RedisPersistence[Any] {
     pipeline.hmget(key, requiredColumns: _*)
   }
 
-  override def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): Map[String, String] = {
+  override def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): Map[String, String] = {
     val fields = value.schema.fields.map(_.name)
     val kvMap = value.getValuesMap[Any](fields)
     kvMap
@@ -41,7 +41,7 @@ class HashRedisPersistence extends RedisPersistence[Any] {
       }
       .filter { case (k, _) =>
         // don't store TTLs
-        ttlName match {
+        ttlColumn match {
           case Some(ttl) => k != ttl
           case None => true
         }
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
index fab86855..bf5f1c58 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
@@ -18,10 +18,10 @@ trait RedisPersistence[T] extends Serializable {
    *
    * @param keyName field name that should be encoded in special way, e.g. in Redis keys.
    * @param value row to encode.
-   * @param ttlName field name to be used for setting the ttl and not added as a value
+   * @param ttlColumn field name to be used for setting the ttl and not added as a value
    * @return encoded row
    */
-  def encodeRow(keyName: String, value: Row, ttlName: Option[String] = None): T
+  def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): T
 
   /**
    * Decode dataframe row stored in Redis.

From a8bacdbcdd280f37b417592b751af0deca4a19d1 Mon Sep 17 00:00:00 2001
From: Sari Nusier
Date: Fri, 27 Nov 2020 14:47:33 +0000
Subject: [PATCH 3/3] Checking if both ttl and ttl.column are set

---
 .../org/apache/spark/sql/redis/RedisSourceRelation.scala | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
index dd399647..dc8e97a2 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
@@ -107,6 +107,12 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
       s"You should only use either one.")
   }
 
+  // check if both ttl column and ttl are set
+  if (ttlColumn.isDefined && ttl > 0) {
+    throw new IllegalArgumentException(s"Both options '$SqlOptionTTL' and '$SqlOptionTTLColumn' are set. " +
+      s"You should only use either one.")
+  }
+
   override def schema: StructType = {
     if (currentSchema == null) {
       currentSchema = userSpecifiedSchema.getOrElse {
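
Usage sketch: the following is a minimal, hypothetical example of how the new "ttl.column" option introduced by this series might be used from a Spark job. The "key.column" and "ttl" option names come from redis.scala above; the "org.apache.spark.sql.redis" format string and the "table" option are assumed from the spark-redis documentation, and the "expiry" column name and sample data are invented for illustration.

import org.apache.spark.sql.{SaveMode, SparkSession}

object TtlColumnExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ttl-column-example")
      .master("local[*]")
      .config("spark.redis.host", "localhost")
      .config("spark.redis.port", "6379")
      .getOrCreate()
    import spark.implicits._

    // "expiry" carries the per-row TTL in seconds. encodeRow filters it out of
    // the stored hash fields, and insert() passes it to save() in place of the
    // constant "ttl" value, via row.getAs[Int]("expiry").
    val people = Seq(("u1", "Alice", 60), ("u2", "Bob", 0))
      .toDF("id", "name", "expiry")

    people.write
      .format("org.apache.spark.sql.redis")
      .option("table", "person")
      .option("key.column", "id")
      .option("ttl.column", "expiry") // per-row TTL instead of .option("ttl", n)
      .mode(SaveMode.Append)
      .save()

    spark.stop()
  }
}

After PATCH 3/3, setting both "ttl" and "ttl.column" raises an IllegalArgumentException, so a job must pick one or the other. A row value of 0 presumably leaves the key without an expiry, mirroring the default of the constant "ttl" option.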