connector writer for spark2.2
Nicole00 committed Aug 31, 2022
1 parent 209da4b commit d93cdc3
Showing 9 changed files with 69 additions and 460 deletions.
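The diff below moves the Spark 2.2 write path off the driver: instead of turning the DataFrame into an RDD and collecting NebulaCommitMessages with sparkContext.runJob, createRelation now calls data.foreachPartition, and each executor runs savePartition and logs its own commit result. As orientation only, here is a minimal, self-contained sketch of that pattern in plain Spark; ForeachPartitionWriteSketch, CommitMessage, and savePartition below are illustrative stand-ins, not the connector's real classes.

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ListBuffer

object ForeachPartitionWriteSketch extends Serializable {

  case class CommitMessage(partitionId: Int, failedStatements: List[String])

  // Runs on the executors: drain this partition's rows, batch them to the target
  // store, and report only this partition's failures.
  def savePartition(rows: Iterator[Row]): CommitMessage = {
    val failed = ListBuffer.empty[String]
    rows.foreach { _ =>
      // buffer the row and flush full batches to the sink here
    }
    CommitMessage(TaskContext.getPartitionId(), failed.toList)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("writer-sketch").master("local[*]").getOrCreate()
    val df = spark.range(0, 100).toDF("id")

    // With Spark 2.2 the write runs inside foreachPartition, so success/failure is
    // handled on each executor rather than collected on the driver via runJob.
    df.foreachPartition { rows: Iterator[Row] =>
      val msg = savePartition(rows)
      if (msg.failedStatements.nonEmpty) {
        println(s"partition ${msg.partitionId}: ${msg.failedStatements.size} failed statements")
      }
    }
    spark.stop()
  }
}

The extends Serializable on the sketch object mirrors the with Serializable this commit adds to NebulaDataSource: once savePartition is called from inside a foreachPartition closure, the enclosing instance is captured and shipped to the executors.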
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -23,12 +23,14 @@ import org.apache.spark.sql.sources.{
DataSourceRegister,
RelationProvider
}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

class NebulaDataSource
extends RelationProvider
with CreatableRelationProvider
with DataSourceRegister {
with DataSourceRegister
with Serializable {
private val LOG = LoggerFactory.getLogger(this.getClass)

/**
@@ -58,7 +60,6 @@ class NebulaDataSource
data: DataFrame): BaseRelation = {

val nebulaOptions = getNebulaOptions(parameters, OperaType.WRITE)
val dataType = nebulaOptions.dataType
if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
LOG.warn(s"Currently do not support mode")
}
@@ -67,7 +68,27 @@ class NebulaDataSource
LOG.info(s"options ${parameters}")

val schema = data.schema
val writer: NebulaWriter =
data.foreachPartition(iterator => {
savePartition(nebulaOptions, schema, iterator)
})

new NebulaWriterResultRelation(sqlContext, data.schema)
}

/**
* construct nebula options with DataSourceOptions
*/
def getNebulaOptions(options: Map[String, String],
operateType: OperaType.Value): NebulaOptions = {
val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType)
nebulaOptions
}

private def savePartition(nebulaOptions: NebulaOptions,
schema: StructType,
iterator: Iterator[Row]): Unit = {
val dataType = nebulaOptions.dataType
val writer: NebulaWriter = {
if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
val vertexFiled = nebulaOptions.vertexField
val vertexIndex: Int = {
@@ -128,28 +149,15 @@ class NebulaDataSource
edgeFieldsIndex._3,
schema).asInstanceOf[NebulaWriter]
}

val wc: (TaskContext, Iterator[Row]) => NebulaCommitMessage = writer.writeData()
val rdd = data.rdd
val commitMessages = sqlContext.sparkContext.runJob(rdd, wc)

LOG.info(s"runJob finished...${commitMessages.length}")
for (msg <- commitMessages) {
if (msg.executeStatements.nonEmpty) {
LOG.error(s"failed execs:\n ${msg.executeStatements.toString()}")
} else {
LOG.info(s"execs for spark partition ${msg.partitionId} all succeed")
}
}
new NebulaWriterResultRelation(sqlContext, data.schema)
}
val message = writer.writeData(iterator)
LOG.debug(
s"spark partition id ${message.partitionId} write failed size: ${message.executeStatements.length}")
if (message.executeStatements.nonEmpty) {
LOG.error(s"failed execs:\n ${message.executeStatements.toString()}")
} else {
LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed")
}

/**
* construct nebula options with DataSourceOptions
*/
def getNebulaOptions(options: Map[String, String],
operateType: OperaType.Value): NebulaOptions = {
val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType)
nebulaOptions
}
}
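
Before the remaining files, a hypothetical spark-shell style invocation of the write path defined above. Only the provider's fully qualified class name is taken from the code in this diff; every option key below is an illustrative placeholder (the real keys live in NebulaOptions, which this commit does not touch), and the connection and authentication options are omitted.

import org.apache.spark.sql.SaveMode
import spark.implicits._

// Assume a DataFrame whose columns match the target tag's properties.
val df = Seq(("1", "Tom", 23), ("2", "Jerry", 18)).toDF("id", "name", "age")

df.write
  .format("com.vesoft.nebula.connector.NebulaDataSource") // provider class shown above
  .mode(SaveMode.Append)       // Ignore / ErrorIfExists only log a warning in createRelation
  .option("type", "vertex")    // placeholder key: write vertices rather than edges
  .option("spaceName", "test") // placeholder key: target graph space
  .option("vertexField", "id") // placeholder key: column holding the vertex id
  .save()                      // runs createRelation -> foreachPartition -> savePartition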
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -7,9 +7,7 @@ package com.vesoft.nebula.connector

import com.vesoft.nebula.connector.ssl.SSLSignType
import com.vesoft.nebula.connector.writer.NebulaExecutor
import org.apache.commons.codec.digest.MurmurHash2
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.{
DataFrame,
DataFrameReader,
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -48,6 +48,19 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,

prepareSpace()

override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
while (iterator.hasNext) {
val internalRow = rowEncoder.toRow(iterator.next())
write(internalRow)
}
if (edges.nonEmpty) {
execute()
}
graphProvider.close()
metaProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}

/**
* write one edge record to buffer
*/
@@ -93,17 +106,4 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,
edges.clear()
submit(exec)
}

override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
(context, iterRow) => {
while (iterRow.hasNext) {
val internalRow = rowEncoder.toRow(iterRow.next())
write(internalRow)
}
if (edges.nonEmpty) {
execute()
}
graphProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}
}
@@ -1,4 +1,4 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -40,6 +40,19 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema:

prepareSpace()

override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
while (iterator.hasNext) {
val internalRow = rowEncoder.toRow(iterator.next())
write(internalRow)
}
if (vertices.nonEmpty) {
execute()
}
graphProvider.close()
metaProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}

/**
* write one vertex row to buffer
*/
@@ -78,17 +91,4 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema:
vertices.clear()
submit(exec)
}

override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
(context, iterRow) => {
while (iterRow.hasNext) {
val internalRow = rowEncoder.toRow(iterRow.next())
write(internalRow)
}
if (vertices.nonEmpty) {
execute()
}
graphProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}
}
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
@@ -68,5 +68,7 @@ abstract class NebulaWriter(nebulaOptions: NebulaOptions, schema: StructType) ex

def write(row: InternalRow): Unit

def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage
/** Write DataFrame data into Nebula, invoked once per Spark partition. */
def writeData(iterator: Iterator[Row]): NebulaCommitMessage

}
@@ -1,4 +1,4 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/