diff --git a/src/tools/spark-sstfile-generator/.gitignore b/src/tools/spark-sstfile-generator/.gitignore
index 916e17c097a..98316955f9d 100644
--- a/src/tools/spark-sstfile-generator/.gitignore
+++ b/src/tools/spark-sstfile-generator/.gitignore
@@ -1 +1,2 @@
dependency-reduced-pom.xml
+*.iml
diff --git a/src/tools/spark-sstfile-generator/pom.xml b/src/tools/spark-sstfile-generator/pom.xml
index 9ca9a98e3fa..36f3b0df1f4 100644
--- a/src/tools/spark-sstfile-generator/pom.xml
+++ b/src/tools/spark-sstfile-generator/pom.xml
@@ -6,7 +6,7 @@
com.vesoft
sst.generator
- 1.0.0-beta
+ 1.0.0-rc2
1.8
@@ -21,8 +21,7 @@
1.4.0
3.9.2
3.7.1
- 1.4.0
- 1.0.0-beta
+ 1.0.0-rc2
1.0.0
@@ -49,11 +48,11 @@
- com/vesoft/tools/**
+ com/vesoft/tools/**
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
- META-INF/*.SF
- META-INF/*.DSA
- META-INF/*.RSA
@@ -80,6 +79,20 @@
shade
+
+
+ org.apache.spark:*
+ org.apache.hadoop:*
+ org.apache.hive:*
+ log4j:log4j
+ org.apache.orc:*
+ xml-apis:xml-apis
+ javax.inject:javax.inject
+ org.spark-project.hive:hive-exec
+ stax:stax-api
+ org.glassfish.hk2.external:aopalliance-repackaged
+
+
*:*
@@ -104,26 +117,212 @@
org.apache.spark
spark-core_2.11
${spark.version}
+
+
+ snappy-java
+ org.xerial.snappy
+
+
+ paranamer
+ com.thoughtworks.paranamer
+
+
+ slf4j-api
+ org.slf4j
+
+
+ commons-codec
+ commons-codec
+
+
+ avro
+ org.apache.avro
+
+
+ commons-lang
+ commons-lang
+
+
+ commons-collections
+ commons-collections
+
+
+ commons-compress
+ org.apache.commons
+
+
+ commons-math3
+ org.apache.commons
+
+
+ guava
+ com.google.guava
+
+
+ httpclient
+ org.apache.httpcomponents
+
+
+ slf4j-log4j12
+ org.slf4j
+
+
+ netty
+ io.netty
+
+
+ jackson-annotations
+ com.fasterxml.jackson.core
+
+
+ scala-reflect
+ org.scala-lang
+
+
+ scala-library
+ org.scala-lang
+
+
+ jackson-databind
+ com.fasterxml.jackson.core
+
+
+ scala-xml_2.11
+ org.scala-lang.modules
+
+
+ log4j
+ log4j
+
+
org.apache.spark
spark-sql_2.11
${spark.version}
+
+
+ snappy-java
+ org.xerial.snappy
+
+
+ jsr305
+ com.google.code.findbugs
+
+
+ slf4j-api
+ org.slf4j
+
+
+ jackson-core
+ com.fasterxml.jackson.core
+
+
+ joda-time
+ joda-time
+
+
+ commons-codec
+ commons-codec
+
+
+ snappy-java
+ org.xerial.snappy
+
+
org.apache.spark
spark-hive_2.11
${spark.version}
+
+
+ commons-codec
+ commons-codec
+
+
+ commons-logging
+ commons-logging
+
+
+ avro
+ org.apache.avro
+
+
+ commons-compress
+ org.apache.commons
+
+
+ commons-lang3
+ org.apache.commons
+
+
+ jackson-mapper-asl
+ org.codehaus.jackson
+
+
+ antlr-runtime
+ org.antlr
+
+
+ jackson-core-asl
+ org.codehaus.jackson
+
+
+ derby
+ org.apache.derby
+
+
+ httpclient
+ org.apache.httpcomponents
+
+
+ httpcore
+ org.apache.httpcomponents
+
+
org.apache.spark
spark-yarn_2.11
${spark.version}
+
+
+ guava
+ com.google.guava
+
+
+ commons-codec
+ commons-codec
+
+
+ commons-compress
+ org.apache.commons
+
+
+ activation
+ javax.activation
+
+
+ slf4j-api
+ org.slf4j
+
+
com.databricks
spark-csv_2.11
1.5.0
+
+
+ scala-library
+ org.scala-lang
+
+
+ univocity-parsers
+ com.univocity
+
+
org.scalatest
@@ -145,22 +344,26 @@
com.typesafe.scala-logging
scala-logging_2.11
${scala-logging.version}
+
+
+ scala-library
+ org.scala-lang
+
+
+ scala-reflect
+ org.scala-lang
+
+
+ slf4j-api
+ org.slf4j
+
+
com.github.scopt
scopt_2.11
${scopt.version}
-
- com.typesafe
- config
- ${config.version}
-
-
- com.vesoft
- client
- ${nebula.version}
-
mysql
mysql-connector-java
@@ -171,5 +374,10 @@
s2-geometry-library-java
${s2.version}
+
+ com.vesoft
+ client
+ ${nebula.version}
+
diff --git a/src/tools/spark-sstfile-generator/src/main/resources/application.conf b/src/tools/spark-sstfile-generator/src/main/resources/application.conf
index 86fcde30df5..dd2f00d7f2e 100644
--- a/src/tools/spark-sstfile-generator/src/main/resources/application.conf
+++ b/src/tools/spark-sstfile-generator/src/main/resources/application.conf
@@ -56,6 +56,10 @@
hive-field-1: nebula-field-1,
hive-field-2: nebula-field-2
}
+ vertex: {
+ field: hive-field-0
+ policy: "hash"
+ }
vertex: hive-field-0
partition: 32
}
@@ -72,8 +76,14 @@
hive-field-1: nebula-field-1,
hive-field-2: nebula-field-2
}
- source: hive-field-0
- target: hive-field-1
+ source: {
+ field: hive-field-0
+ policy: "hash"
+ }
+ target: {
+ field:hive-field-1
+ policy: "uuid"
+ }
ranking: hive-field-2
partition: 32
}
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
index 2a0e409b417..e54b4a3a6bd 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
@@ -12,15 +12,16 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.udf
import java.io.File
+import com.google.common.base.Optional
import com.google.common.geometry.{S2CellId, S2LatLng}
import com.google.common.net.HostAndPort
+import com.google.common.util.concurrent.{FutureCallback, Futures}
+import com.vesoft.nebula.client.graph.async.AsyncGraphClientImpl
import com.vesoft.nebula.graph.ErrorCode
-import com.vesoft.nebula.graph.client.GraphClientImpl
import org.apache.log4j.Logger
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._
-import scala.util.Random
import util.control.Breaks._
case class Argument(config: File = new File("application.conf"),
@@ -35,19 +36,26 @@ object SparkClientGenerator {
private[this] val LOG = Logger.getLogger(this.getClass)
- private[this] val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s"
- private[this] val INSERT_VALUE_TEMPLATE = "%d: (%s)"
- private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%d->%d: (%s)"
- private[this] val EDGE_VALUE_TEMPLATE = "%d->%d@%d: (%s)"
- private[this] val USE_TEMPLATE = "USE %s"
-
- private[this] val DEFAULT_BATCH = 2
+ private[this] val HASH_POLICY = "hash"
+ private[this] val UUID_POLICY = "uuid"
+ private[this] val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s"
+ private[this] val INSERT_VALUE_TEMPLATE = "%d: (%s)"
+ private[this] val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(%d): (%s)"
+ private[this] val ENDPOINT_TEMPLATE = "%s(%d)"
+ private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%d->%d: (%s)"
+ private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY = "%s->%s: (%s)"
+ private[this] val EDGE_VALUE_TEMPLATE = "%d->%d@%d: (%s)"
+ private[this] val EDGE_VALUE_TEMPLATE_WITH_POLICY = "%s->%s@%d: (%s)"
+ private[this] val USE_TEMPLATE = "USE %s"
+
+ private[this] val DEFAULT_BATCH = 64
private[this] val DEFAULT_PARTITION = -1
private[this] val DEFAULT_CONNECTION_TIMEOUT = 3000
private[this] val DEFAULT_CONNECTION_RETRY = 3
private[this] val DEFAULT_EXECUTION_RETRY = 3
private[this] val DEFAULT_EXECUTION_INTERVAL = 3000
private[this] val DEFAULT_EDGE_RANKING = 0L
+ private[this] val DEFAULT_ERROR_TIMES = 16
// GEO default config
private[this] val DEFAULT_MIN_CELL_LEVEL = 5
@@ -148,6 +156,8 @@ object SparkClientGenerator {
Some(config.getObject("tags"))
else None
+ class TooManyErrorException(e: String) extends Exception(e) {}
+
if (tagConfigs.isDefined) {
for (tagName <- tagConfigs.get.unwrapped.keySet.asScala) {
LOG.info(s"Processing Tag ${tagName}")
@@ -164,8 +174,18 @@ object SparkClientGenerator {
}
val fields = tagConfig.getObject("fields").unwrapped
+ val vertex = if (tagConfig.hasPath("vertex")) {
+ tagConfig.getString("vertex")
+ } else {
+ tagConfig.getString("vertex.field")
+ }
+
+ val policyOpt = if (tagConfig.hasPath("vertex.policy")) {
+ Some(tagConfig.getString("vertex.policy").toLowerCase)
+ } else {
+ None
+ }
- val vertex = tagConfig.getString("vertex")
val batch = getOrElse(tagConfig, "batch", DEFAULT_BATCH)
val partition = getOrElse(tagConfig, "partition", DEFAULT_PARTITION)
@@ -178,6 +198,14 @@ object SparkClientGenerator {
fields.asScala.keys.toList
}
+ val sourceColumn = sourceProperties.map { property =>
+ if (property == vertex) {
+ col(property).cast(LongType)
+ } else {
+ col(property)
+ }
+ }
+
val vertexIndex = sourceProperties.indexOf(vertex)
val nebulaProperties = properties.mkString(",")
@@ -186,8 +214,11 @@ object SparkClientGenerator {
val data = createDataSource(spark, pathOpt, tagConfig)
if (data.isDefined && !c.dry) {
+ val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagName}")
+ val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${tagName}")
+
repartition(data.get, partition)
- .select(sourceProperties.map(col): _*)
+ .select(sourceColumn: _*)
.withColumn(vertex, toVertexUDF(col(vertex)))
.map { row =>
(row.getLong(vertexIndex),
@@ -197,14 +228,16 @@ object SparkClientGenerator {
}(Encoders.tuple(Encoders.scalaLong, Encoders.STRING))
.foreachPartition { iterator: Iterator[(Long, String)] =>
val hostAndPorts = addresses.map(HostAndPort.fromString).asJava
- val client =
- new GraphClientImpl(hostAndPorts,
- connectionTimeout,
- connectionRetry,
- executionRetry)
-
- if (isSuccessfully(client.connect(user, pswd))) {
- if (isSuccessfully(client.execute(USE_TEMPLATE.format(space)))) {
+ val client = new AsyncGraphClientImpl(hostAndPorts,
+ connectionTimeout,
+ connectionRetry,
+ executionRetry)
+ client.setUser(user)
+ client.setPassword(pswd)
+
+ if (isSuccessfully(client.connect())) {
+ val switchSpaceCode = client.execute(USE_TEMPLATE.format(space)).get().get()
+ if (isSuccessfully(switchSpaceCode)) {
iterator.grouped(batch).foreach { tags =>
val exec = BATCH_INSERT_TEMPLATE.format(
Type.Vertex.toString,
@@ -212,20 +245,42 @@ object SparkClientGenerator {
nebulaProperties,
tags
.map { tag =>
- INSERT_VALUE_TEMPLATE.format(tag._1, tag._2)
+ if (policyOpt.isEmpty) {
+ INSERT_VALUE_TEMPLATE.format(tag._1, tag._2)
+ } else {
+ policyOpt.get match {
+ case HASH_POLICY =>
+ INSERT_VALUE_TEMPLATE_WITH_POLICY.format(HASH_POLICY,
+ tag._1,
+ tag._2)
+ case UUID_POLICY =>
+ INSERT_VALUE_TEMPLATE_WITH_POLICY.format(UUID_POLICY,
+ tag._1,
+ tag._2)
+ case _ => throw new IllegalArgumentException
+ }
+ }
}
.mkString(", ")
)
LOG.debug(s"Exec : ${exec}")
- breakable {
- for (time <- 1 to executionRetry
- if isSuccessfullyWithSleep(
- client.execute(exec),
- time * executionInterval + Random.nextInt(10) * 100L)(exec)) {
- break
+ val future = client.execute(exec)
+ Futures.addCallback(
+ future,
+ new FutureCallback[Optional[Integer]] {
+ override def onSuccess(result: Optional[Integer]): Unit = {
+ batchSuccess.add(1)
+ }
+
+ override def onFailure(t: Throwable): Unit = {
+ if (batchFailure.value > DEFAULT_ERROR_TIMES) {
+ throw new TooManyErrorException("too many error")
+ }
+ batchFailure.add(1)
+ }
}
- }
+ )
}
} else {
LOG.error(s"Switch ${space} Failed")
@@ -261,7 +316,18 @@ object SparkClientGenerator {
val fields = edgeConfig.getObject("fields").unwrapped
val isGeo = checkGeoSupported(edgeConfig)
- val target = edgeConfig.getString("target")
+ val target = if (edgeConfig.hasPath("target")) {
+ edgeConfig.getString("target")
+ } else {
+ edgeConfig.getString("target.field")
+ }
+
+ val targetPolicyOpt = if (edgeConfig.hasPath("target.policy")) {
+ Some(edgeConfig.getString("target.policy").toLowerCase)
+ } else {
+ None
+ }
+
val rankingOpt = if (edgeConfig.hasPath("ranking")) {
Some(edgeConfig.getString("ranking"))
} else {
@@ -274,7 +340,12 @@ object SparkClientGenerator {
val valueProperties = fields.asScala.keys.toList
val sourceProperties = if (!isGeo) {
- val source = edgeConfig.getString("source")
+ val source = if (edgeConfig.hasPath("source")) {
+ edgeConfig.getString("source")
+ } else {
+ edgeConfig.getString("source.field")
+ }
+
if (!fields.containsKey(source) ||
!fields.containsKey(target)) {
(fields.asScala.keySet + source + target).toList
@@ -293,6 +364,33 @@ object SparkClientGenerator {
}
}
+ val sourcePolicyOpt = if (edgeConfig.hasPath("source.policy")) {
+ Some(edgeConfig.getString("source.policy").toLowerCase)
+ } else {
+ None
+ }
+
+ val sourceColumn = if (!isGeo) {
+ val source = edgeConfig.getString("source")
+ sourceProperties.map { property =>
+ if (property == source || property == target) {
+ col(property).cast(LongType)
+ } else {
+ col(property)
+ }
+ }
+ } else {
+ val latitude = edgeConfig.getString("latitude")
+ val longitude = edgeConfig.getString("longitude")
+ sourceProperties.map { property =>
+ if (property == latitude || property == longitude) {
+ col(property).cast(DoubleType)
+ } else {
+ col(property)
+ }
+ }
+ }
+
val nebulaProperties = properties.mkString(",")
val data = createDataSource(spark, pathOpt, edgeConfig)
@@ -300,8 +398,11 @@ object SparkClientGenerator {
Encoders.tuple(Encoders.STRING, Encoders.scalaLong, Encoders.scalaLong, Encoders.STRING)
if (data.isDefined && !c.dry) {
+ val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeName}")
+ val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeName}")
+
repartition(data.get, partition)
- .select(sourceProperties.map(col): _*)
+ .select(sourceColumn: _*)
.map { row =>
val sourceField = if (!isGeo) {
val source = edgeConfig.getString("source")
@@ -330,13 +431,16 @@ object SparkClientGenerator {
}(encoder)
.foreachPartition { iterator: Iterator[(String, Long, Long, String)] =>
val hostAndPorts = addresses.map(HostAndPort.fromString).asJava
- val client =
- new GraphClientImpl(hostAndPorts,
- connectionTimeout,
- connectionRetry,
- executionRetry)
- if (isSuccessfully(client.connect(user, pswd))) {
- if (isSuccessfully(client.execute(USE_TEMPLATE.format(space)))) {
+ val client = new AsyncGraphClientImpl(hostAndPorts,
+ connectionTimeout,
+ connectionRetry,
+ executionRetry)
+
+ client.setUser(user)
+ client.setPassword(pswd)
+ if (isSuccessfully(client.connect())) {
+ val switchSpaceCode = client.switchSpace(space).get().get()
+ if (isSuccessfully(switchSpaceCode)) {
iterator.grouped(batch).foreach { edges =>
val values =
if (rankingOpt.isEmpty)
@@ -345,8 +449,29 @@ object SparkClientGenerator {
// TODO: (darion.yaphet) dataframe.explode() would be better ?
(for (source <- edge._1.split(","))
yield
- EDGE_VALUE_WITHOUT_RANKING_TEMPLATE
- .format(source.toLong, edge._2, edge._4)).mkString(", ")
+ if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) {
+ EDGE_VALUE_WITHOUT_RANKING_TEMPLATE
+ .format(source.toLong, edge._2, edge._4)
+ } else {
+ val source = sourcePolicyOpt.get match {
+ case HASH_POLICY =>
+ ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1)
+ case UUID_POLICY =>
+ ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1)
+ case _ => throw new IllegalArgumentException
+ }
+
+ val target = targetPolicyOpt.get match {
+ case HASH_POLICY =>
+ ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2)
+ case UUID_POLICY =>
+ ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2)
+ case _ => throw new IllegalArgumentException
+ }
+
+ EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY
+ .format(source, target, edge._4)
+ }).mkString(", ")
}
.toList
.mkString(", ")
@@ -356,8 +481,29 @@ object SparkClientGenerator {
// TODO: (darion.yaphet) dataframe.explode() would be better ?
(for (source <- edge._1.split(","))
yield
- EDGE_VALUE_TEMPLATE
- .format(source.toLong, edge._2, edge._3, edge._4))
+ if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) {
+ EDGE_VALUE_TEMPLATE
+ .format(source.toLong, edge._2, edge._3, edge._4)
+ } else {
+ val source = sourcePolicyOpt.get match {
+ case HASH_POLICY =>
+ ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1)
+ case UUID_POLICY =>
+ ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1)
+ case _ => throw new IllegalArgumentException
+ }
+
+ val target = targetPolicyOpt.get match {
+ case HASH_POLICY =>
+ ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2)
+ case UUID_POLICY =>
+ ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2)
+ case _ => throw new IllegalArgumentException
+ }
+
+ EDGE_VALUE_TEMPLATE_WITH_POLICY
+ .format(source, target, edge._3, edge._4)
+ })
.mkString(", ")
}
.toList
@@ -366,14 +512,22 @@ object SparkClientGenerator {
val exec = BATCH_INSERT_TEMPLATE
.format(Type.Edge.toString, edgeName, nebulaProperties, values)
LOG.debug(s"Exec : ${exec}")
- breakable {
- for (time <- 1 to executionRetry
- if isSuccessfullyWithSleep(
- client.execute(exec),
- time * executionInterval + Random.nextInt(10) * 100L)(exec)) {
- break
+ val future = client.execute(exec)
+ Futures.addCallback(
+ future,
+ new FutureCallback[Optional[Integer]] {
+ override def onSuccess(result: Optional[Integer]): Unit = {
+ batchSuccess.add(1)
+ }
+
+ override def onFailure(t: Throwable): Unit = {
+ if (batchFailure.value > DEFAULT_ERROR_TIMES) {
+ throw new TooManyErrorException("too many error")
+ }
+ batchFailure.add(1)
+ }
}
- }
+ )
}
} else {
LOG.error(s"Switch ${space} Failed")
@@ -400,7 +554,7 @@ object SparkClientGenerator {
*/
private[this] def createDataSource(session: SparkSession,
pathOpt: Option[String],
- config: Config) = {
+ config: Config): Option[DataFrame] = {
val `type` = config.getString("type")
pathOpt match {
@@ -480,7 +634,7 @@ object SparkClientGenerator {
* @param field The field name.
* @return
*/
- private[this] def extraValue(row: Row, field: String) = {
+ private[this] def extraValue(row: Row, field: String): Any = {
val index = row.schema.fieldIndex(field)
row.schema.fields(index).dataType match {
case StringType =>
@@ -605,7 +759,7 @@ object SparkClientGenerator {
* @param edgeConfig The config of edge.
* @return
*/
- private[this] def checkGeoSupported(edgeConfig: Config) = {
+ private[this] def checkGeoSupported(edgeConfig: Config): Boolean = {
!edgeConfig.hasPath("source") &&
edgeConfig.hasPath("latitude") &&
edgeConfig.hasPath("longitude")
@@ -634,7 +788,7 @@ object SparkClientGenerator {
* @param defaultValue The default value for the path.
* @return
*/
- private[this] def getOrElse[T](config: Config, path: String, defaultValue: T) = {
+ private[this] def getOrElse[T](config: Config, path: String, defaultValue: T): T = {
if (config.hasPath(path)) {
config.getAnyRef(path).asInstanceOf[T]
} else {
@@ -649,7 +803,7 @@ object SparkClientGenerator {
* @param lng The longitude of coordinate.
* @return
*/
- private[this] def indexCells(lat: Double, lng: Double) = {
+ private[this] def indexCells(lat: Double, lng: Double): IndexedSeq[Long] = {
val coordinate = S2LatLng.fromDegrees(lat, lng)
val s2CellId = S2CellId.fromLatLng(coordinate)
for (index <- DEFAULT_MIN_CELL_LEVEL to DEFAULT_MAX_CELL_LEVEL)