diff --git a/exchange-common/src/main/scala/com/vesoft/exchange/common/config/Configs.scala b/exchange-common/src/main/scala/com/vesoft/exchange/common/config/Configs.scala index 4f7e481c..0f82bac6 100644 --- a/exchange-common/src/main/scala/com/vesoft/exchange/common/config/Configs.scala +++ b/exchange-common/src/main/scala/com/vesoft/exchange/common/config/Configs.scala @@ -16,6 +16,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataInputStream, FileSystem, Path} import org.apache.log4j.Logger +import scala.collection.JavaConversions.asScalaBuffer import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.collection.JavaConverters._ @@ -166,6 +167,12 @@ case class CaSignParam(caCrtFilePath: String, crtFilePath: String, keyFilePath: case class SelfSignParam(crtFilePath: String, keyFilePath: String, password: String) +case class UdfConfigEntry(sep: String, oldColNames: List[String], newColName: String) { + override def toString(): String = { + s"sep:$sep, oldColNames: $oldColNames, newColName: $newColName" + } +} + /** * */ @@ -431,6 +438,13 @@ object Configs { val repartitionWithNebula = getOrElse(tagConfig, "repartitionWithNebula", true) val ignoreIndex = getOrElse(tagConfig, "ignoreIndex", false) + val vertexUdf = if (tagConfig.hasPath("vertex.udf")) { + val sep = tagConfig.getString("vertex.udf.separator") + val cols: List[String] = tagConfig.getStringList("vertex.udf.oldColNames").toList + val newCol = tagConfig.getString("vertex.udf.newColName") + Some(UdfConfigEntry(sep, cols, newCol)) + } else None + LOG.info(s"name ${tagName} batch ${batch}") val entry = TagConfigEntry( tagName, @@ -445,7 +459,8 @@ object Configs { checkPointPath, repartitionWithNebula, enableTagless, - ignoreIndex + ignoreIndex, + vertexUdf ) LOG.info(s"Tag Config: ${entry}") tags += entry @@ -553,6 +568,20 @@ object Configs { val repartitionWithNebula = getOrElse(edgeConfig, "repartitionWithNebula", false) val ignoreIndex = getOrElse(edgeConfig, "ignoreIndex", false) + val srcUdf = if (edgeConfig.hasPath("source.udf")) { + val sep = edgeConfig.getString("source.udf.separator") + val cols: List[String] = edgeConfig.getStringList("source.udf.oldColNames").toList + val newCol = edgeConfig.getString("source.udf.newColName") + Some(UdfConfigEntry(sep, cols, newCol)) + } else None + + val dstUdf = if (edgeConfig.hasPath("target.udf")) { + val sep = edgeConfig.getString("target.udf.separator") + val cols: List[String] = edgeConfig.getStringList("target.udf.oldColNames").toList + val newCol = edgeConfig.getString("target.udf.newColName") + Some(UdfConfigEntry(sep, cols, newCol)) + } else None + val entry = EdgeConfigEntry( edgeName, sourceConfig, @@ -571,7 +600,9 @@ object Configs { partition, checkPointPath, repartitionWithNebula, - ignoreIndex + ignoreIndex, + srcUdf, + dstUdf ) LOG.info(s"Edge Config: ${entry}") edges += entry diff --git a/exchange-common/src/main/scala/com/vesoft/exchange/common/config/SchemaConfigs.scala b/exchange-common/src/main/scala/com/vesoft/exchange/common/config/SchemaConfigs.scala index 8eefaa56..99d1e7b0 100644 --- a/exchange-common/src/main/scala/com/vesoft/exchange/common/config/SchemaConfigs.scala +++ b/exchange-common/src/main/scala/com/vesoft/exchange/common/config/SchemaConfigs.scala @@ -62,7 +62,8 @@ case class TagConfigEntry(override val name: String, override val checkPointPath: Option[String], repartitionWithNebula: Boolean = true, enableTagless: Boolean = false, - ignoreIndex: Boolean = false) + ignoreIndex: Boolean = false, + vertexUdf: Option[UdfConfigEntry] = None) extends SchemaConfigEntry { require( name.trim.nonEmpty && vertexField.trim.nonEmpty @@ -77,7 +78,9 @@ case class TagConfigEntry(override val name: String, s"batch: $batch, " + s"partition: $partition, " + s"repartitionWithNebula: $repartitionWithNebula, " + - s"enableTagless: $enableTagless." + s"enableTagless: $enableTagless, " + + s"ignoreIndex: $ignoreIndex, " + + s"vertexUdf: $vertexUdf." } } @@ -117,7 +120,9 @@ case class EdgeConfigEntry(override val name: String, override val partition: Int, override val checkPointPath: Option[String], repartitionWithNebula: Boolean = false, - ignoreIndex: Boolean = false) + ignoreIndex: Boolean = false, + srcVertexUdf: Option[UdfConfigEntry] = None, + dstVertexUdf: Option[UdfConfigEntry] = None) extends SchemaConfigEntry { require( name.trim.nonEmpty && sourceField.trim.nonEmpty && targetField.trim.nonEmpty @@ -136,7 +141,10 @@ case class EdgeConfigEntry(override val name: String, s"target field: $targetField, " + s"target policy: $targetPolicy, " + s"batch: $batch, " + - s"partition: $partition." + s"partition: $partition, " + + s"ignoreIndex: $ignoreIndex, " + + s"srcVertexUdf: $srcVertexUdf" + + s"dstVertexUdf: $dstVertexUdf." } else { s"Edge name: $name, " + s"source: $dataSourceConfigEntry, " + @@ -147,7 +155,10 @@ case class EdgeConfigEntry(override val name: String, s"target field: $targetField, " + s"target policy: $targetPolicy, " + s"batch: $batch, " + - s"partition: $partition." + s"partition: $partition, " + + s"ignoreIndex: $ignoreIndex, " + + s"srcVertexUdf: $srcVertexUdf" + + s"dstVertexUdf: $dstVertexUdf." } } } diff --git a/nebula-exchange_spark_2.2/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala b/nebula-exchange_spark_2.2/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala index 44bcf0b3..7a217a81 100644 --- a/nebula-exchange_spark_2.2/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala +++ b/nebula-exchange_spark_2.2/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala @@ -5,7 +5,7 @@ package com.vesoft.nebula.exchange -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, SparkSession} import java.io.File import com.vesoft.exchange.Argument @@ -27,7 +27,8 @@ import com.vesoft.exchange.common.config.{ PostgreSQLSourceConfigEntry, PulsarSourceConfigEntry, SinkCategory, - SourceCategory + SourceCategory, + UdfConfigEntry } import com.vesoft.nebula.exchange.reader.{ CSVReader, @@ -51,8 +52,12 @@ import com.vesoft.exchange.common.processor.ReloadProcessor import com.vesoft.exchange.common.utils.SparkValidate import com.vesoft.nebula.exchange.processor.{EdgeProcessor, VerticesProcessor} import org.apache.log4j.Logger +import org.apache.spark.sql.functions.{col, concat_ws} +import org.apache.spark.sql.types.StringType import org.apache.spark.{SparkConf, SparkEnv} +import scala.collection.mutable.ListBuffer + final case class TooManyErrorsException(private val message: String) extends Exception(message) /** @@ -142,8 +147,13 @@ object Exchange { val fields = tagConfig.vertexField :: tagConfig.fields val data = createDataSource(spark, tagConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + val df = if (tagConfig.vertexUdf.isDefined) { + dataUdf(data.get, tagConfig.vertexUdf.get) + } else { + data.get + } + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagConfig.name}") @@ -152,7 +162,7 @@ object Exchange { val processor = new VerticesProcessor( spark, - repartition(data.get, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), + repartition(df, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), tagConfig, fieldKeys, nebulaKeys, @@ -161,7 +171,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info( s"import for tag ${tagConfig.name}: data total count: $count, total time: ${costTime}s") @@ -195,15 +205,23 @@ object Exchange { } val data = createDataSource(spark, edgeConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + var df = data.get + if (edgeConfig.srcVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.srcVertexUdf.get) + } + if (edgeConfig.dstVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.dstVertexUdf.get) + } + + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeConfig.name}") val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}") val processor = new EdgeProcessor( spark, - repartition(data.get, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), + repartition(df, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), edgeConfig, fieldKeys, nebulaKeys, @@ -212,7 +230,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info( s"import for edge ${edgeConfig.name}: data total count: $count, total time: ${costTime}s") @@ -363,4 +381,17 @@ object Exchange { frame } } + + private[this] def dataUdf(data: DataFrame, udfConfig: UdfConfigEntry): DataFrame = { + val oldCols = udfConfig.oldColNames + val sep = udfConfig.sep + val newCol = udfConfig.newColName + val originalFieldsNames = data.schema.fieldNames.toList + val finalColNames: ListBuffer[Column] = new ListBuffer[Column] + for (field <- originalFieldsNames) { + finalColNames.append(col(field)) + } + finalColNames.append(concat_ws(sep, oldCols.map(c => col(c)): _*).cast(StringType).as(newCol)) + data.select(finalColNames: _*) + } } diff --git a/nebula-exchange_spark_2.4/src/main/resources/application.conf b/nebula-exchange_spark_2.4/src/main/resources/application.conf index 131f0cf7..bbb8ddc2 100644 --- a/nebula-exchange_spark_2.4/src/main/resources/application.conf +++ b/nebula-exchange_spark_2.4/src/main/resources/application.conf @@ -97,10 +97,16 @@ sink: client } path: hdfs tag path 0 + fields: [parquet-field-0, parquet-field-1, parquet-field-2] nebula.fields: [nebula-field-0, nebula-field-1, nebula-field-2] vertex: { - field:parquet-field-0 + field:new-parquet-field + udf:{ + separator:"_" + oldColNames:[parquet-field-0] + newColNames:[new-parquet-field] + } #policy:hash } batch: 2000 @@ -367,10 +373,20 @@ nebula.fields: [nebula-field-0 nebula-field-1 nebula-field-2] source: { field:parquet-field-0 + udf:{ + separator:"_" + oldColNames:[parquet-field-0] + newColName:[new-parquet-field] + } #policy:hash } target: { field:parquet-field-1 + udf:{ + separator:"_" + oldColNames:[parquet-field-0] + newColName:[new-parquet-field] + } #policy:hash } ranking: parquet-field-2 diff --git a/nebula-exchange_spark_2.4/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala b/nebula-exchange_spark_2.4/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala index 073f8045..f6757f4c 100644 --- a/nebula-exchange_spark_2.4/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala +++ b/nebula-exchange_spark_2.4/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala @@ -5,7 +5,7 @@ package com.vesoft.nebula.exchange -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, SparkSession} import java.io.File import com.vesoft.exchange.Argument @@ -27,7 +27,8 @@ import com.vesoft.exchange.common.config.{ PostgreSQLSourceConfigEntry, PulsarSourceConfigEntry, SinkCategory, - SourceCategory + SourceCategory, + UdfConfigEntry } import com.vesoft.nebula.exchange.reader.{ CSVReader, @@ -51,8 +52,12 @@ import com.vesoft.exchange.common.processor.ReloadProcessor import com.vesoft.exchange.common.utils.SparkValidate import com.vesoft.nebula.exchange.processor.{EdgeProcessor, VerticesProcessor} import org.apache.log4j.Logger +import org.apache.spark.sql.functions.{col, concat_ws} +import org.apache.spark.sql.types.StringType import org.apache.spark.{SparkConf, SparkEnv} +import scala.collection.mutable.ListBuffer + final case class TooManyErrorsException(private val message: String) extends Exception(message) /** @@ -142,8 +147,13 @@ object Exchange { val fields = tagConfig.vertexField :: tagConfig.fields val data = createDataSource(spark, tagConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + val df = if (tagConfig.vertexUdf.isDefined) { + dataUdf(data.get, tagConfig.vertexUdf.get) + } else { + data.get + } + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagConfig.name}") @@ -152,7 +162,7 @@ object Exchange { val processor = new VerticesProcessor( spark, - repartition(data.get, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), + repartition(df, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), tagConfig, fieldKeys, nebulaKeys, @@ -161,7 +171,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info(s"import for tag ${tagConfig.name}, data count: $count, cost time: ${costTime}s") if (tagConfig.dataSinkConfigEntry.category == SinkCategory.CLIENT) { @@ -194,15 +204,23 @@ object Exchange { } val data = createDataSource(spark, edgeConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + var df = data.get + if (edgeConfig.srcVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.srcVertexUdf.get) + } + if (edgeConfig.dstVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.dstVertexUdf.get) + } + + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeConfig.name}") val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}") val processor = new EdgeProcessor( spark, - repartition(data.get, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), + repartition(df, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), edgeConfig, fieldKeys, nebulaKeys, @@ -211,7 +229,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info( s"import for edge ${edgeConfig.name}, data count: $count, cost time: ${costTime}s") @@ -362,4 +380,17 @@ object Exchange { frame } } + + private[this] def dataUdf(data: DataFrame, udfConfig: UdfConfigEntry): DataFrame = { + val oldCols = udfConfig.oldColNames + val sep = udfConfig.sep + val newCol = udfConfig.newColName + val originalFieldsNames = data.schema.fieldNames.toList + val finalColNames: ListBuffer[Column] = new ListBuffer[Column] + for (field <- originalFieldsNames) { + finalColNames.append(col(field)) + } + finalColNames.append(concat_ws(sep, oldCols.map(c => col(c)): _*).cast(StringType).as(newCol)) + data.select(finalColNames: _*) + } } diff --git a/nebula-exchange_spark_3.0/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala b/nebula-exchange_spark_3.0/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala index 58143415..6ee1917d 100644 --- a/nebula-exchange_spark_3.0/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala +++ b/nebula-exchange_spark_3.0/src/main/scala/com/vesoft/nebula/exchange/Exchange.scala @@ -5,7 +5,7 @@ package com.vesoft.nebula.exchange -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, SparkSession} import java.io.File import com.vesoft.exchange.Argument @@ -27,7 +27,8 @@ import com.vesoft.exchange.common.config.{ PostgreSQLSourceConfigEntry, PulsarSourceConfigEntry, SinkCategory, - SourceCategory + SourceCategory, + UdfConfigEntry } import com.vesoft.nebula.exchange.reader.{ CSVReader, @@ -51,8 +52,12 @@ import com.vesoft.exchange.common.processor.ReloadProcessor import com.vesoft.exchange.common.utils.SparkValidate import com.vesoft.nebula.exchange.processor.{EdgeProcessor, VerticesProcessor} import org.apache.log4j.Logger +import org.apache.spark.sql.functions.{col, concat_ws} +import org.apache.spark.sql.types.StringType import org.apache.spark.{SparkConf, SparkEnv} +import scala.collection.mutable.ListBuffer + final case class TooManyErrorsException(private val message: String) extends Exception(message) /** @@ -142,8 +147,13 @@ object Exchange { val fields = tagConfig.vertexField :: tagConfig.fields val data = createDataSource(spark, tagConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + val df = if (tagConfig.vertexUdf.isDefined) { + dataUdf(data.get, tagConfig.vertexUdf.get) + } else { + data.get + } + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagConfig.name}") @@ -152,7 +162,7 @@ object Exchange { val processor = new VerticesProcessor( spark, - repartition(data.get, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), + repartition(df, tagConfig.partition, tagConfig.dataSourceConfigEntry.category), tagConfig, fieldKeys, nebulaKeys, @@ -161,7 +171,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info(s"import for tag ${tagConfig.name}, data count: $count, cost time: ${costTime}s") if (tagConfig.dataSinkConfigEntry.category == SinkCategory.CLIENT) { @@ -194,15 +204,23 @@ object Exchange { } val data = createDataSource(spark, edgeConfig.dataSourceConfigEntry, fields) if (data.isDefined && !c.dry) { - data.get.cache() - val count = data.get.count() + var df = data.get + if (edgeConfig.srcVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.srcVertexUdf.get) + } + if (edgeConfig.dstVertexUdf.isDefined) { + df = dataUdf(df, edgeConfig.dstVertexUdf.get) + } + + df.cache() + val count = df.count() val startTime = System.currentTimeMillis() val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeConfig.name}") val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}") val processor = new EdgeProcessor( spark, - repartition(data.get, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), + repartition(df, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category), edgeConfig, fieldKeys, nebulaKeys, @@ -211,7 +229,7 @@ object Exchange { batchFailure ) processor.process() - data.get.unpersist() + df.unpersist() val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f") LOG.info( s"import for edge ${edgeConfig.name}, data count: $count, cost time: ${costTime}s") @@ -362,4 +380,17 @@ object Exchange { frame } } + + private[this] def dataUdf(data: DataFrame, udfConfig: UdfConfigEntry): DataFrame = { + val oldCols = udfConfig.oldColNames + val sep = udfConfig.sep + val newCol = udfConfig.newColName + val originalFieldsNames = data.schema.fieldNames.toList + val finalColNames: ListBuffer[Column] = new ListBuffer[Column] + for (field <- originalFieldsNames) { + finalColNames.append(col(field)) + } + finalColNames.append(concat_ws(sep, oldCols.map(c => col(c)): _*).cast(StringType).as(newCol)) + data.select(finalColNames: _*) + } }