From c87f1b1338eb49cc5f36cad200458b530e288903 Mon Sep 17 00:00:00 2001 From: yaphet Date: Thu, 2 Jan 2020 15:22:34 +0800 Subject: [PATCH] Use async client in Spark writer (#1405) --- src/tools/spark-sstfile-generator/.gitignore | 1 + src/tools/spark-sstfile-generator/pom.xml | 242 ++++++++++++++-- .../src/main/resources/application.conf | 14 +- .../generator/v2/SparkClientGenerator.scala | 260 ++++++++++++++---- 4 files changed, 445 insertions(+), 72 deletions(-) diff --git a/src/tools/spark-sstfile-generator/.gitignore b/src/tools/spark-sstfile-generator/.gitignore index 916e17c097a..98316955f9d 100644 --- a/src/tools/spark-sstfile-generator/.gitignore +++ b/src/tools/spark-sstfile-generator/.gitignore @@ -1 +1,2 @@ dependency-reduced-pom.xml +*.iml diff --git a/src/tools/spark-sstfile-generator/pom.xml b/src/tools/spark-sstfile-generator/pom.xml index 9ca9a98e3fa..36f3b0df1f4 100644 --- a/src/tools/spark-sstfile-generator/pom.xml +++ b/src/tools/spark-sstfile-generator/pom.xml @@ -6,7 +6,7 @@ com.vesoft sst.generator - 1.0.0-beta + 1.0.0-rc2 1.8 @@ -21,8 +21,7 @@ 1.4.0 3.9.2 3.7.1 - 1.4.0 - 1.0.0-beta + 1.0.0-rc2 1.0.0 @@ -49,11 +48,11 @@ - com/vesoft/tools/** + com/vesoft/tools/** + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA @@ -80,6 +79,20 @@ shade + + + org.apache.spark:* + org.apache.hadoop:* + org.apache.hive:* + log4j:log4j + org.apache.orc:* + xml-apis:xml-apis + javax.inject:javax.inject + org.spark-project.hive:hive-exec + stax:stax-api + org.glassfish.hk2.external:aopalliance-repackaged + + *:* @@ -104,26 +117,212 @@ org.apache.spark spark-core_2.11 ${spark.version} + + + snappy-java + org.xerial.snappy + + + paranamer + com.thoughtworks.paranamer + + + slf4j-api + org.slf4j + + + commons-codec + commons-codec + + + avro + org.apache.avro + + + commons-lang + commons-lang + + + commons-collections + commons-collections + + + commons-compress + org.apache.commons + + + commons-math3 + org.apache.commons + + + guava + com.google.guava + + + httpclient + org.apache.httpcomponents + + + slf4j-log4j12 + org.slf4j + + + netty + io.netty + + + jackson-annotations + com.fasterxml.jackson.core + + + scala-reflect + org.scala-lang + + + scala-library + org.scala-lang + + + jackson-databind + com.fasterxml.jackson.core + + + scala-xml_2.11 + org.scala-lang.modules + + + log4j + log4j + + org.apache.spark spark-sql_2.11 ${spark.version} + + + snappy-java + org.xerial.snappy + + + jsr305 + com.google.code.findbugs + + + slf4j-api + org.slf4j + + + jackson-core + com.fasterxml.jackson.core + + + joda-time + joda-time + + + commons-codec + commons-codec + + + snappy-java + org.xerial.snappy + + org.apache.spark spark-hive_2.11 ${spark.version} + + + commons-codec + commons-codec + + + commons-logging + commons-logging + + + avro + org.apache.avro + + + commons-compress + org.apache.commons + + + commons-lang3 + org.apache.commons + + + jackson-mapper-asl + org.codehaus.jackson + + + antlr-runtime + org.antlr + + + jackson-core-asl + org.codehaus.jackson + + + derby + org.apache.derby + + + httpclient + org.apache.httpcomponents + + + httpcore + org.apache.httpcomponents + + org.apache.spark spark-yarn_2.11 ${spark.version} + + + guava + com.google.guava + + + commons-codec + commons-codec + + + commons-compress + org.apache.commons + + + activation + javax.activation + + + slf4j-api + org.slf4j + + com.databricks spark-csv_2.11 1.5.0 + + + scala-library + org.scala-lang + + + univocity-parsers + com.univocity + + org.scalatest @@ -145,22 
+344,26 @@ com.typesafe.scala-logging scala-logging_2.11 ${scala-logging.version} + + + scala-library + org.scala-lang + + + scala-reflect + org.scala-lang + + + slf4j-api + org.slf4j + + com.github.scopt scopt_2.11 ${scopt.version} - - com.typesafe - config - ${config.version} - - - com.vesoft - client - ${nebula.version} - mysql mysql-connector-java @@ -171,5 +374,10 @@ s2-geometry-library-java ${s2.version} + + com.vesoft + client + ${nebula.version} + diff --git a/src/tools/spark-sstfile-generator/src/main/resources/application.conf b/src/tools/spark-sstfile-generator/src/main/resources/application.conf index 86fcde30df5..dd2f00d7f2e 100644 --- a/src/tools/spark-sstfile-generator/src/main/resources/application.conf +++ b/src/tools/spark-sstfile-generator/src/main/resources/application.conf @@ -56,6 +56,10 @@ hive-field-1: nebula-field-1, hive-field-2: nebula-field-2 } + vertex: { + field: hive-field-0 + policy: "hash" + } vertex: hive-field-0 partition: 32 } @@ -72,8 +76,14 @@ hive-field-1: nebula-field-1, hive-field-2: nebula-field-2 } - source: hive-field-0 - target: hive-field-1 + source: { + field: hive-field-0 + policy: "hash" + } + target: { + field:hive-field-1 + policy: "uuid" + } ranking: hive-field-2 partition: 32 } diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala index 2a0e409b417..e54b4a3a6bd 100644 --- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala +++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala @@ -12,15 +12,16 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.udf import java.io.File +import com.google.common.base.Optional import com.google.common.geometry.{S2CellId, S2LatLng} import com.google.common.net.HostAndPort +import com.google.common.util.concurrent.{FutureCallback, Futures} +import com.vesoft.nebula.client.graph.async.AsyncGraphClientImpl import com.vesoft.nebula.graph.ErrorCode -import com.vesoft.nebula.graph.client.GraphClientImpl import org.apache.log4j.Logger import org.apache.spark.sql.types._ import scala.collection.JavaConverters._ -import scala.util.Random import util.control.Breaks._ case class Argument(config: File = new File("application.conf"), @@ -35,19 +36,26 @@ object SparkClientGenerator { private[this] val LOG = Logger.getLogger(this.getClass) - private[this] val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s" - private[this] val INSERT_VALUE_TEMPLATE = "%d: (%s)" - private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%d->%d: (%s)" - private[this] val EDGE_VALUE_TEMPLATE = "%d->%d@%d: (%s)" - private[this] val USE_TEMPLATE = "USE %s" - - private[this] val DEFAULT_BATCH = 2 + private[this] val HASH_POLICY = "hash" + private[this] val UUID_POLICY = "uuid" + private[this] val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s" + private[this] val INSERT_VALUE_TEMPLATE = "%d: (%s)" + private[this] val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(%d): (%s)" + private[this] val ENDPOINT_TEMPLATE = "%s(%d)" + private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%d->%d: (%s)" + private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY = "%s->%s: (%s)" + private[this] val EDGE_VALUE_TEMPLATE = "%d->%d@%d: (%s)" + private[this] val EDGE_VALUE_TEMPLATE_WITH_POLICY = "%s->%s@%d: 
(%s)" + private[this] val USE_TEMPLATE = "USE %s" + + private[this] val DEFAULT_BATCH = 64 private[this] val DEFAULT_PARTITION = -1 private[this] val DEFAULT_CONNECTION_TIMEOUT = 3000 private[this] val DEFAULT_CONNECTION_RETRY = 3 private[this] val DEFAULT_EXECUTION_RETRY = 3 private[this] val DEFAULT_EXECUTION_INTERVAL = 3000 private[this] val DEFAULT_EDGE_RANKING = 0L + private[this] val DEFAULT_ERROR_TIMES = 16 // GEO default config private[this] val DEFAULT_MIN_CELL_LEVEL = 5 @@ -148,6 +156,8 @@ object SparkClientGenerator { Some(config.getObject("tags")) else None + class TooManyErrorException(e: String) extends Exception(e) {} + if (tagConfigs.isDefined) { for (tagName <- tagConfigs.get.unwrapped.keySet.asScala) { LOG.info(s"Processing Tag ${tagName}") @@ -164,8 +174,18 @@ object SparkClientGenerator { } val fields = tagConfig.getObject("fields").unwrapped + val vertex = if (tagConfig.hasPath("vertex")) { + tagConfig.getString("vertex") + } else { + tagConfig.getString("vertex.field") + } + + val policyOpt = if (tagConfig.hasPath("vertex.policy")) { + Some(tagConfig.getString("vertex.policy").toLowerCase) + } else { + None + } - val vertex = tagConfig.getString("vertex") val batch = getOrElse(tagConfig, "batch", DEFAULT_BATCH) val partition = getOrElse(tagConfig, "partition", DEFAULT_PARTITION) @@ -178,6 +198,14 @@ object SparkClientGenerator { fields.asScala.keys.toList } + val sourceColumn = sourceProperties.map { property => + if (property == vertex) { + col(property).cast(LongType) + } else { + col(property) + } + } + val vertexIndex = sourceProperties.indexOf(vertex) val nebulaProperties = properties.mkString(",") @@ -186,8 +214,11 @@ object SparkClientGenerator { val data = createDataSource(spark, pathOpt, tagConfig) if (data.isDefined && !c.dry) { + val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagName}") + val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${tagName}") + repartition(data.get, partition) - .select(sourceProperties.map(col): _*) + .select(sourceColumn: _*) .withColumn(vertex, toVertexUDF(col(vertex))) .map { row => (row.getLong(vertexIndex), @@ -197,14 +228,16 @@ object SparkClientGenerator { }(Encoders.tuple(Encoders.scalaLong, Encoders.STRING)) .foreachPartition { iterator: Iterator[(Long, String)] => val hostAndPorts = addresses.map(HostAndPort.fromString).asJava - val client = - new GraphClientImpl(hostAndPorts, - connectionTimeout, - connectionRetry, - executionRetry) - - if (isSuccessfully(client.connect(user, pswd))) { - if (isSuccessfully(client.execute(USE_TEMPLATE.format(space)))) { + val client = new AsyncGraphClientImpl(hostAndPorts, + connectionTimeout, + connectionRetry, + executionRetry) + client.setUser(user) + client.setPassword(pswd) + + if (isSuccessfully(client.connect())) { + val switchSpaceCode = client.execute(USE_TEMPLATE.format(space)).get().get() + if (isSuccessfully(switchSpaceCode)) { iterator.grouped(batch).foreach { tags => val exec = BATCH_INSERT_TEMPLATE.format( Type.Vertex.toString, @@ -212,20 +245,42 @@ object SparkClientGenerator { nebulaProperties, tags .map { tag => - INSERT_VALUE_TEMPLATE.format(tag._1, tag._2) + if (policyOpt.isEmpty) { + INSERT_VALUE_TEMPLATE.format(tag._1, tag._2) + } else { + policyOpt.get match { + case HASH_POLICY => + INSERT_VALUE_TEMPLATE_WITH_POLICY.format(HASH_POLICY, + tag._1, + tag._2) + case UUID_POLICY => + INSERT_VALUE_TEMPLATE_WITH_POLICY.format(UUID_POLICY, + tag._1, + tag._2) + case _ => throw new IllegalArgumentException + } + } } .mkString(", ") 
) LOG.debug(s"Exec : ${exec}") - breakable { - for (time <- 1 to executionRetry - if isSuccessfullyWithSleep( - client.execute(exec), - time * executionInterval + Random.nextInt(10) * 100L)(exec)) { - break + val future = client.execute(exec) + Futures.addCallback( + future, + new FutureCallback[Optional[Integer]] { + override def onSuccess(result: Optional[Integer]): Unit = { + batchSuccess.add(1) + } + + override def onFailure(t: Throwable): Unit = { + if (batchFailure.value > DEFAULT_ERROR_TIMES) { + throw new TooManyErrorException("too many error") + } + batchFailure.add(1) + } } - } + ) } } else { LOG.error(s"Switch ${space} Failed") @@ -261,7 +316,18 @@ object SparkClientGenerator { val fields = edgeConfig.getObject("fields").unwrapped val isGeo = checkGeoSupported(edgeConfig) - val target = edgeConfig.getString("target") + val target = if (edgeConfig.hasPath("target")) { + edgeConfig.getString("target") + } else { + edgeConfig.getString("target.field") + } + + val targetPolicyOpt = if (edgeConfig.hasPath("target.policy")) { + Some(edgeConfig.getString("target.policy").toLowerCase) + } else { + None + } + val rankingOpt = if (edgeConfig.hasPath("ranking")) { Some(edgeConfig.getString("ranking")) } else { @@ -274,7 +340,12 @@ object SparkClientGenerator { val valueProperties = fields.asScala.keys.toList val sourceProperties = if (!isGeo) { - val source = edgeConfig.getString("source") + val source = if (edgeConfig.hasPath("source")) { + edgeConfig.getString("source") + } else { + edgeConfig.getString("source.field") + } + if (!fields.containsKey(source) || !fields.containsKey(target)) { (fields.asScala.keySet + source + target).toList @@ -293,6 +364,33 @@ object SparkClientGenerator { } } + val sourcePolicyOpt = if (edgeConfig.hasPath("source.policy")) { + Some(edgeConfig.getString("source.policy").toLowerCase) + } else { + None + } + + val sourceColumn = if (!isGeo) { + val source = edgeConfig.getString("source") + sourceProperties.map { property => + if (property == source || property == target) { + col(property).cast(LongType) + } else { + col(property) + } + } + } else { + val latitude = edgeConfig.getString("latitude") + val longitude = edgeConfig.getString("longitude") + sourceProperties.map { property => + if (property == latitude || property == longitude) { + col(property).cast(DoubleType) + } else { + col(property) + } + } + } + val nebulaProperties = properties.mkString(",") val data = createDataSource(spark, pathOpt, edgeConfig) @@ -300,8 +398,11 @@ object SparkClientGenerator { Encoders.tuple(Encoders.STRING, Encoders.scalaLong, Encoders.scalaLong, Encoders.STRING) if (data.isDefined && !c.dry) { + val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeName}") + val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeName}") + repartition(data.get, partition) - .select(sourceProperties.map(col): _*) + .select(sourceColumn: _*) .map { row => val sourceField = if (!isGeo) { val source = edgeConfig.getString("source") @@ -330,13 +431,16 @@ object SparkClientGenerator { }(encoder) .foreachPartition { iterator: Iterator[(String, Long, Long, String)] => val hostAndPorts = addresses.map(HostAndPort.fromString).asJava - val client = - new GraphClientImpl(hostAndPorts, - connectionTimeout, - connectionRetry, - executionRetry) - if (isSuccessfully(client.connect(user, pswd))) { - if (isSuccessfully(client.execute(USE_TEMPLATE.format(space)))) { + val client = new AsyncGraphClientImpl(hostAndPorts, + connectionTimeout, + connectionRetry, + 
executionRetry) + + client.setUser(user) + client.setPassword(pswd) + if (isSuccessfully(client.connect())) { + val switchSpaceCode = client.switchSpace(space).get().get() + if (isSuccessfully(switchSpaceCode)) { iterator.grouped(batch).foreach { edges => val values = if (rankingOpt.isEmpty) @@ -345,8 +449,29 @@ object SparkClientGenerator { // TODO: (darion.yaphet) dataframe.explode() would be better ? (for (source <- edge._1.split(",")) yield - EDGE_VALUE_WITHOUT_RANKING_TEMPLATE - .format(source.toLong, edge._2, edge._4)).mkString(", ") + if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) { + EDGE_VALUE_WITHOUT_RANKING_TEMPLATE + .format(source.toLong, edge._2, edge._4) + } else { + val source = sourcePolicyOpt.get match { + case HASH_POLICY => + ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1) + case UUID_POLICY => + ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1) + case _ => throw new IllegalArgumentException + } + + val target = targetPolicyOpt.get match { + case HASH_POLICY => + ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2) + case UUID_POLICY => + ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2) + case _ => throw new IllegalArgumentException + } + + EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY + .format(source, target, edge._4) + }).mkString(", ") } .toList .mkString(", ") @@ -356,8 +481,29 @@ object SparkClientGenerator { // TODO: (darion.yaphet) dataframe.explode() would be better ? (for (source <- edge._1.split(",")) yield - EDGE_VALUE_TEMPLATE - .format(source.toLong, edge._2, edge._3, edge._4)) + if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) { + EDGE_VALUE_TEMPLATE + .format(source.toLong, edge._2, edge._3, edge._4) + } else { + val source = sourcePolicyOpt.get match { + case HASH_POLICY => + ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1) + case UUID_POLICY => + ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1) + case _ => throw new IllegalArgumentException + } + + val target = targetPolicyOpt.get match { + case HASH_POLICY => + ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2) + case UUID_POLICY => + ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2) + case _ => throw new IllegalArgumentException + } + + EDGE_VALUE_TEMPLATE_WITH_POLICY + .format(source, target, edge._3, edge._4) + }) .mkString(", ") } .toList @@ -366,14 +512,22 @@ object SparkClientGenerator { val exec = BATCH_INSERT_TEMPLATE .format(Type.Edge.toString, edgeName, nebulaProperties, values) LOG.debug(s"Exec : ${exec}") - breakable { - for (time <- 1 to executionRetry - if isSuccessfullyWithSleep( - client.execute(exec), - time * executionInterval + Random.nextInt(10) * 100L)(exec)) { - break + val future = client.execute(exec) + Futures.addCallback( + future, + new FutureCallback[Optional[Integer]] { + override def onSuccess(result: Optional[Integer]): Unit = { + batchSuccess.add(1) + } + + override def onFailure(t: Throwable): Unit = { + if (batchFailure.value > DEFAULT_ERROR_TIMES) { + throw new TooManyErrorException("too many error") + } + batchFailure.add(1) + } } - } + ) } } else { LOG.error(s"Switch ${space} Failed") @@ -400,7 +554,7 @@ object SparkClientGenerator { */ private[this] def createDataSource(session: SparkSession, pathOpt: Option[String], - config: Config) = { + config: Config): Option[DataFrame] = { val `type` = config.getString("type") pathOpt match { @@ -480,7 +634,7 @@ object SparkClientGenerator { * @param field The field name. 
    * @return
    */
-  private[this] def extraValue(row: Row, field: String) = {
+  private[this] def extraValue(row: Row, field: String): Any = {
     val index = row.schema.fieldIndex(field)
     row.schema.fields(index).dataType match {
       case StringType =>
@@ -605,7 +759,7 @@
    * @param edgeConfig The config of edge.
    * @return
    */
-  private[this] def checkGeoSupported(edgeConfig: Config) = {
+  private[this] def checkGeoSupported(edgeConfig: Config): Boolean = {
     !edgeConfig.hasPath("source") &&
     edgeConfig.hasPath("latitude") &&
     edgeConfig.hasPath("longitude")
@@ -634,7 +788,7 @@
    * @param defaultValue The default value for the path.
    * @return
    */
-  private[this] def getOrElse[T](config: Config, path: String, defaultValue: T) = {
+  private[this] def getOrElse[T](config: Config, path: String, defaultValue: T): T = {
     if (config.hasPath(path)) {
       config.getAnyRef(path).asInstanceOf[T]
     } else {
@@ -649,7 +803,7 @@
    * @param lng The longitude of coordinate.
    * @return
    */
-  private[this] def indexCells(lat: Double, lng: Double) = {
+  private[this] def indexCells(lat: Double, lng: Double): IndexedSeq[Long] = {
     val coordinate = S2LatLng.fromDegrees(lat, lng)
     val s2CellId = S2CellId.fromLatLng(coordinate)
     for (index <- DEFAULT_MIN_CELL_LEVEL to DEFAULT_MAX_CELL_LEVEL)
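
The core of this change is the move from the blocking GraphClientImpl to AsyncGraphClientImpl: each batched INSERT is now submitted through client.execute(), and the outcome is recorded via a Guava FutureCallback into Spark accumulators instead of the old retry-and-sleep loop. The standalone sketch below condenses that pattern outside the generator's foreachPartition context. It is illustrative only: the graphd address, space, tag, statements, accumulator names, and the isSuccess helper are placeholders, and it assumes connect() and switchSpace() report ErrorCode values, as the generator's isSuccessfully checks imply.

import scala.collection.JavaConverters._
import scala.util.Try

import com.google.common.base.Optional
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.{FutureCallback, Futures}
import com.vesoft.nebula.client.graph.async.AsyncGraphClientImpl
import com.vesoft.nebula.graph.ErrorCode
import org.apache.spark.sql.SparkSession

object AsyncWriteSketch {

  // Mirrors the generator's isSuccessfully check: SUCCEEDED means the statement was accepted.
  private def isSuccess(code: Int): Boolean = code == ErrorCode.SUCCEEDED

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("async-write-sketch").getOrCreate()

    // Accumulators surface per-batch outcomes in the Spark UI, which is how the patch
    // reports progress now that execute() no longer blocks on every statement.
    val batchSuccess = spark.sparkContext.longAccumulator("batchSuccess.example")
    val batchFailure = spark.sparkContext.longAccumulator("batchFailure.example")

    // Placeholder graphd address; the generator reads addresses from application.conf
    // and runs this block inside foreachPartition on the executors.
    val hostAndPorts = List("127.0.0.1:3699").map(HostAndPort.fromString).asJava
    val client = new AsyncGraphClientImpl(hostAndPorts,
                                          3000, // connection timeout (ms)
                                          3,    // connection retries
                                          3)    // execution retries
    client.setUser("user")
    client.setPassword("password")

    if (isSuccess(client.connect())) {
      // switchSpace() also returns a future holding an Optional error code.
      val switchSpaceCode = client.switchSpace("test_space").get().get()
      if (isSuccess(switchSpaceCode)) {
        // Invented statements standing in for the batched INSERTs the generator builds.
        val statements = Seq("INSERT VERTEX person(name) VALUES 100: (\"Tom\")",
                             "INSERT VERTEX person(name) VALUES 200: (\"Bob\")")

        val futures = statements.map { statement =>
          val future = client.execute(statement)
          Futures.addCallback(
            future,
            new FutureCallback[Optional[Integer]] {
              override def onSuccess(result: Optional[Integer]): Unit = batchSuccess.add(1)
              override def onFailure(t: Throwable): Unit = batchFailure.add(1)
            }
          )
          future
        }

        // Block until every statement has completed; a failed future throws from get(),
        // so swallow it here and rely on the onFailure callback for the count.
        futures.foreach(future => Try(future.get()))
      }
    }

    println(s"success=${batchSuccess.value}, failure=${batchFailure.value}")
    spark.stop()
  }
}

The actual writer additionally throws a TooManyErrorException from onFailure once batchFailure passes DEFAULT_ERROR_TIMES (16), which is what ultimately fails the Spark task after repeated errors; that guard is left out of the sketch.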
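The other substantive addition is id policy support in the insert templates: a vertex or edge endpoint can now be declared in application.conf as a field plus a policy of "hash" or "uuid", and the generated statement wraps the value accordingly. Below is a minimal, dependency-free sketch of how the new template strings compose into a batch INSERT; the tag name, property list, values, and helper function are invented for the example, while the template strings themselves are copied from the patch.

object PolicyTemplateSketch {
  private val HASH_POLICY = "hash"
  private val UUID_POLICY = "uuid"
  // Template strings as introduced by the patch.
  private val BATCH_INSERT_TEMPLATE             = "INSERT %s %s(%s) VALUES %s"
  private val INSERT_VALUE_TEMPLATE             = "%d: (%s)"
  private val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(%d): (%s)"

  // Formats one vertex value clause, optionally wrapping the id with hash()/uuid().
  def vertexValue(id: Long, props: String, policy: Option[String]): String =
    policy match {
      case None              => INSERT_VALUE_TEMPLATE.format(id, props)
      case Some(HASH_POLICY) => INSERT_VALUE_TEMPLATE_WITH_POLICY.format(HASH_POLICY, id, props)
      case Some(UUID_POLICY) => INSERT_VALUE_TEMPLATE_WITH_POLICY.format(UUID_POLICY, id, props)
      case _                 => throw new IllegalArgumentException(s"unsupported policy: $policy")
    }

  def main(args: Array[String]): Unit = {
    val values = Seq(vertexValue(100L, "\"Tom\", 26", Some(HASH_POLICY)),
                     vertexValue(200L, "\"Bob\", 30", Some(HASH_POLICY))).mkString(", ")
    // Prints: INSERT VERTEX person(name, age) VALUES hash(100): ("Tom", 26), hash(200): ("Bob", 30)
    println(BATCH_INSERT_TEMPLATE.format("VERTEX", "person", "name, age", values))
  }
}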