diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala index 2aadccde..0d3a773c 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala @@ -215,13 +215,15 @@ private[connector] class WriteNebulaConfig(space: String, user: String, passwd: String, batch: Int, - writeMode: String) + writeMode: String, + overwrite: Boolean) extends Serializable { def getSpace = space def getBatch = batch def getUser = user def getPasswd = passwd def getWriteMode = writeMode + def isOverwrite = overwrite } /** @@ -242,8 +244,9 @@ class WriteNebulaVertexConfig(space: String, user: String, passwd: String, writeMode: String, - deleteEdge: Boolean) - extends WriteNebulaConfig(space, user, passwd, batch, writeMode) { + deleteEdge: Boolean, + overwrite: Boolean) + extends WriteNebulaConfig(space, user, passwd, batch, writeMode, overwrite) { def getTagName = tagName def getVidField = vidField def getVidPolicy = if (vidPolicy == null) "" else vidPolicy @@ -275,6 +278,9 @@ object WriteNebulaVertexConfig { /** whether delete the related edges of vertex */ var deleteEdge: Boolean = false + /** whether overwrite the exists vertex */ + var overwrite: Boolean = true + /** * set space name */ @@ -356,6 +362,14 @@ object WriteNebulaVertexConfig { this } + /** + * set whether overwrite the exists vertex + */ + def withOverwrite(overwrite: Boolean): WriteVertexConfigBuilder = { + this.overwrite = overwrite + this; + } + /** * check and get WriteNebulaVertexConfig */ @@ -370,7 +384,8 @@ object WriteNebulaVertexConfig { user, passwd, writeMode, - deleteEdge) + deleteEdge, + overwrite) } private def check(): Unit = { @@ -436,8 +451,9 @@ class WriteNebulaEdgeConfig(space: String, rankAsProp: Boolean, user: String, passwd: String, - writeMode: String) - extends WriteNebulaConfig(space, user, passwd, batch, writeMode) { + writeMode: String, + overwrite: Boolean) + extends WriteNebulaConfig(space, user, passwd, batch, writeMode, overwrite) { def getEdgeName = edgeName def getSrcFiled = srcFiled def getSrcPolicy = if (srcPolicy == null) "" else srcPolicy @@ -487,6 +503,9 @@ object WriteNebulaEdgeConfig { /** write mode for nebula, insert or update */ var writeMode: String = WriteMode.INSERT.toString + /** whether overwrite the exists edge */ + var overwrite: Boolean = true + /** * set space name */ @@ -600,6 +619,14 @@ object WriteNebulaEdgeConfig { this } + /** + * set whether overwrite the exists edge + */ + def withOverwrite(overwrite: Boolean): WriteEdgeConfigBuilder = { + this.overwrite = overwrite + this + } + /** * check configs and get WriteNebulaEdgeConfig */ @@ -618,7 +645,8 @@ object WriteNebulaEdgeConfig { rankAsProp, user, passwd, - writeMode) + writeMode, + overwrite) } private def check(): Unit = { diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala index 2bb630cf..6cfa6e42 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala @@ -139,6 +139,7 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String]) exten var rankAsProp: Boolean = _ var writeMode: WriteMode.Value = _ var deleteEdge: Boolean = _ + var overwrite: Boolean = _ if (operaType == OperaType.WRITE) { require(parameters.isDefinedAt(GRAPH_ADDRESS), @@ -175,6 +176,7 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String]) exten writeMode = WriteMode.withName(parameters.getOrElse(WRITE_MODE, DEFAULT_WRITE_MODE).toString.toLowerCase) deleteEdge = parameters.getOrElse(DELETE_EDGE, false).toString.toBoolean + overwrite = parameters.getOrElse(OVERWRITE, true).toString.toBoolean } def getReturnCols: List[String] = { @@ -260,6 +262,7 @@ object NebulaOptions { val RANK_AS_PROP: String = "rankAsProp" val WRITE_MODE: String = "writeMode" val DELETE_EDGE: String = "deleteEdge" + val OVERWRITE: String = "overwrite" val DEFAULT_TIMEOUT: Int = 3000 val DEFAULT_CONNECTION_TIMEOUT: Int = 3000 diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/Template.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/Template.scala index 748612df..b8a447f5 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/Template.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/Template.scala @@ -7,7 +7,9 @@ package com.vesoft.nebula.connector object NebulaTemplate { - private[connector] val BATCH_INSERT_TEMPLATE = "INSERT %s `%s`(%s) VALUES %s" + private[connector] val BATCH_INSERT_TEMPLATE = "INSERT %s `%s`(%s) VALUES %s" + private[connector] val BATCH_INSERT_NO_OVERWRITE_TEMPLATE = + "INSERT %s IF NOT EXISTS `%s`(%s) VALUES %s" private[connector] val VERTEX_VALUE_TEMPLATE = "%s: (%s)" private[connector] val VERTEX_VALUE_TEMPLATE_WITH_POLICY = "%s(\"%s\"): (%s)" private[connector] val ENDPOINT_TEMPLATE = "%s(\"%s\")" diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/writer/NebulaExecutor.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/writer/NebulaExecutor.scala index 5d3ecdf0..4c9323f7 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/writer/NebulaExecutor.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/writer/NebulaExecutor.scala @@ -7,6 +7,7 @@ package com.vesoft.nebula.connector.writer import com.vesoft.nebula.PropertyType import com.vesoft.nebula.connector.NebulaTemplate.{ + BATCH_INSERT_NO_OVERWRITE_TEMPLATE, BATCH_INSERT_TEMPLATE, DELETE_EDGE_TEMPLATE, DELETE_VERTEX_TEMPLATE, @@ -215,8 +216,8 @@ object NebulaExecutor { /** * construct insert statement for vertex */ - def toExecuteSentence(tagName: String, vertices: NebulaVertices): String = { - BATCH_INSERT_TEMPLATE.format( + def toExecuteSentence(tagName: String, vertices: NebulaVertices, overwrite: Boolean): String = { + (if (overwrite) BATCH_INSERT_TEMPLATE else BATCH_INSERT_NO_OVERWRITE_TEMPLATE).format( DataTypeEnum.VERTEX.toString, tagName, vertices.propertyNames, @@ -244,7 +245,7 @@ object NebulaExecutor { /** * construct insert statement for edge */ - def toExecuteSentence(edgeName: String, edges: NebulaEdges): String = { + def toExecuteSentence(edgeName: String, edges: NebulaEdges, overwrite: Boolean): String = { val values = edges.values .map { edge => val source = edges.getSourcePolicy match { @@ -278,7 +279,8 @@ object NebulaExecutor { EDGE_VALUE_TEMPLATE.format(source, target, edge.rank.get, edge.propertyValues) } .mkString(", ") - BATCH_INSERT_TEMPLATE.format(DataTypeEnum.EDGE.toString, edgeName, edges.propertyNames, values) + (if (overwrite) BATCH_INSERT_TEMPLATE else BATCH_INSERT_NO_OVERWRITE_TEMPLATE) + .format(DataTypeEnum.EDGE.toString, edgeName, edges.propertyNames, values) } /** diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala index 0ee3b8bb..7d59ed41 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -248,6 +248,7 @@ package object connector { .option(NebulaOptions.VID_AS_PROP, writeConfig.getVidAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) .option(NebulaOptions.DELETE_EDGE, writeConfig.getDeleteEdge) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) @@ -296,6 +297,7 @@ package object connector { .option(NebulaOptions.DST_AS_PROP, writeConfig.getDstAsProp) .option(NebulaOptions.RANK_AS_PROP, writeConfig.getRankAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala index 9b04b229..1410730d 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala @@ -83,7 +83,8 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions, def execute(): Unit = { val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) case WriteMode.DELETE => diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala index 8d418af1..a3710017 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala @@ -67,7 +67,10 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: def execute(): Unit = { val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaVertices) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, + nebulaVertices, + nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) case WriteMode.DELETE => diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala index 7a95f623..0ea2d454 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala @@ -140,12 +140,19 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { vertices.append(NebulaVertex("\"vid2\"", props2)) val nebulaVertices = NebulaVertices(propNames, vertices.toList, None) - val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, true) val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\": (" + props1 .mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")" assert(expectStatement.equals(vertexStatement)) + + val vertexWithoutOverwriteStatement = + NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, false) + val expectWithoutOverwriteStatement = "INSERT vertex IF NOT EXISTS `person`(`col_string`," + + "`col_fixed_string`,`col_bool`,`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) " + + "VALUES \"vid1\": (" + props1.mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")" + assert(expectWithoutOverwriteStatement.equals(vertexWithoutOverwriteStatement)) } test("test toExecuteSentence for vertex with hash policy") { @@ -167,7 +174,7 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { vertices.append(NebulaVertex("vid2", props2)) val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH)) - val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, true) val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES hash(\"vid1\"): (" + props1 @@ -201,12 +208,20 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2)) val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None) - val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges) + val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges, true) val expectStatement = "INSERT edge `friend`(`col_string`,`col_fixed_string`,`col_bool`,`col_int`" + ",`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\"->\"vid2\"@1: (" + props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + props2.mkString(", ") + ")" assert(expectStatement.equals(edgeStatement)) + + val edgeWithoutOverwriteStatement = + NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges, false) + val expectWithoutOverwriteStatement = "INSERT edge IF NOT EXISTS `friend`(`col_string`," + + "`col_fixed_string`,`col_bool`,`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) " + + "VALUES \"vid1\"->\"vid2\"@1: (" + props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + + props2.mkString(", ") + ")" + assert(expectWithoutOverwriteStatement.equals(edgeWithoutOverwriteStatement)) } test("test toUpdateExecuteSentence for vertex") { diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala index 8c616727..6af7663f 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -310,6 +310,7 @@ package object connector { .option(NebulaOptions.BATCH, writeConfig.getBatch) .option(NebulaOptions.VID_AS_PROP, writeConfig.getVidAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) @@ -358,6 +359,7 @@ package object connector { .option(NebulaOptions.DST_AS_PROP, writeConfig.getDstAsProp) .option(NebulaOptions.RANK_AS_PROP, writeConfig.getRankAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala index a849f5b1..14087319 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala @@ -95,7 +95,8 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions, def execute(): Unit = { val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) case WriteMode.DELETE => diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala index 28c009a8..e021cdd1 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala @@ -80,7 +80,10 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: private def execute(): Unit = { val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaVertices) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, + nebulaVertices, + nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) case WriteMode.DELETE => diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala index 79c4ba87..12fceb96 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -209,6 +209,7 @@ package object connector { .option(NebulaOptions.VID_AS_PROP, writeConfig.getVidAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) .option(NebulaOptions.DELETE_EDGE, writeConfig.getDeleteEdge) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) @@ -257,6 +258,7 @@ package object connector { .option(NebulaOptions.DST_AS_PROP, writeConfig.getDstAsProp) .option(NebulaOptions.RANK_AS_PROP, writeConfig.getRankAsProp) .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.OVERWRITE, writeConfig.isOverwrite) .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala index 9c28df82..de159887 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala @@ -83,7 +83,8 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions, def execute(): Unit = { val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) case WriteMode.DELETE => diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala index 22a5d311..517e98a9 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala @@ -67,7 +67,10 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: def execute(): Unit = { val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) val exec = nebulaOptions.writeMode match { - case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaVertices) + case WriteMode.INSERT => + NebulaExecutor.toExecuteSentence(nebulaOptions.label, + nebulaVertices, + nebulaOptions.overwrite) case WriteMode.UPDATE => NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) case WriteMode.DELETE => diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala index 7a95f623..0ea2d454 100644 --- a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala @@ -140,12 +140,19 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { vertices.append(NebulaVertex("\"vid2\"", props2)) val nebulaVertices = NebulaVertices(propNames, vertices.toList, None) - val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, true) val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\": (" + props1 .mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")" assert(expectStatement.equals(vertexStatement)) + + val vertexWithoutOverwriteStatement = + NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, false) + val expectWithoutOverwriteStatement = "INSERT vertex IF NOT EXISTS `person`(`col_string`," + + "`col_fixed_string`,`col_bool`,`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) " + + "VALUES \"vid1\": (" + props1.mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")" + assert(expectWithoutOverwriteStatement.equals(vertexWithoutOverwriteStatement)) } test("test toExecuteSentence for vertex with hash policy") { @@ -167,7 +174,7 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { vertices.append(NebulaVertex("vid2", props2)) val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH)) - val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices, true) val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES hash(\"vid1\"): (" + props1 @@ -201,12 +208,20 @@ class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2)) val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None) - val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges) + val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges, true) val expectStatement = "INSERT edge `friend`(`col_string`,`col_fixed_string`,`col_bool`,`col_int`" + ",`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\"->\"vid2\"@1: (" + props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + props2.mkString(", ") + ")" assert(expectStatement.equals(edgeStatement)) + + val edgeWithoutOverwriteStatement = + NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges, false) + val expectWithoutOverwriteStatement = "INSERT edge IF NOT EXISTS `friend`(`col_string`," + + "`col_fixed_string`,`col_bool`,`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) " + + "VALUES \"vid1\"->\"vid2\"@1: (" + props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + + props2.mkString(", ") + ")" + assert(expectWithoutOverwriteStatement.equals(edgeWithoutOverwriteStatement)) } test("test toUpdateExecuteSentence for vertex") {