Skip to content

Commit

Permalink
modify writer
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicole00 committed Aug 31, 2022
1 parent 2410de1 commit cfc673f
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ import org.apache.spark.sql.sources.{
DataSourceRegister,
RelationProvider
}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

class NebulaDataSource
extends RelationProvider
with CreatableRelationProvider
with DataSourceRegister {
with DataSourceRegister
with Serializable {
private val LOG = LoggerFactory.getLogger(this.getClass)

/**
Expand Down Expand Up @@ -58,7 +60,6 @@ class NebulaDataSource
data: DataFrame): BaseRelation = {

val nebulaOptions = getNebulaOptions(parameters, OperaType.WRITE)
val dataType = nebulaOptions.dataType
if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
LOG.warn(s"Currently do not support mode")
}
Expand All @@ -67,7 +68,27 @@ class NebulaDataSource
LOG.info(s"options ${parameters}")

val schema = data.schema
val writer: NebulaWriter =
data.foreachPartition(iterator => {
savePartition(nebulaOptions, schema, iterator)
})

new NebulaWriterResultRelation(sqlContext, data.schema)
}

/**
 * Build the [[NebulaOptions]] for an operation from the raw data source options.
 *
 * @param options     key/value options handed to the data source; wrapped in a
 *                    CaseInsensitiveMap before being passed on
 * @param operateType the kind of operation being configured (e.g. OperaType.WRITE)
 * @return the NebulaOptions constructed from the given options and operation type
 */
def getNebulaOptions(options: Map[String, String],
                     operateType: OperaType.Value): NebulaOptions =
  new NebulaOptions(CaseInsensitiveMap(options))(operateType)

private def savePartition(nebulaOptions: NebulaOptions,
schema: StructType,
iterator: Iterator[Row]): Unit = {
val dataType = nebulaOptions.dataType
val writer: NebulaWriter = {
if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
val vertexFiled = nebulaOptions.vertexField
val vertexIndex: Int = {
Expand Down Expand Up @@ -128,28 +149,15 @@ class NebulaDataSource
edgeFieldsIndex._3,
schema).asInstanceOf[NebulaWriter]
}

val wc: (TaskContext, Iterator[Row]) => NebulaCommitMessage = writer.writeData()
val rdd = data.rdd
val commitMessages = sqlContext.sparkContext.runJob(rdd, wc)

LOG.info(s"runJob finished...${commitMessages.length}")
for (msg <- commitMessages) {
if (msg.executeStatements.nonEmpty) {
LOG.error(s"failed execs:\n ${msg.executeStatements.toString()}")
} else {
LOG.info(s"execs for spark partition ${msg.partitionId} all succeed")
}
}
new NebulaWriterResultRelation(sqlContext, data.schema)
}
val message = writer.writeData(iterator)
LOG.debug(
s"spark partition id ${message.partitionId} write failed size: ${message.executeStatements.length}")
if (message.executeStatements.nonEmpty) {
LOG.error(s"failed execs:\n ${message.executeStatements.toString()}")
} else {
LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed")
}

/**
 * Build the [[NebulaOptions]] for an operation from the raw data source options.
 *
 * @param options     key/value options handed to the data source; wrapped in a
 *                    CaseInsensitiveMap before being passed on
 * @param operateType the kind of operation being configured (e.g. OperaType.WRITE)
 * @return the NebulaOptions constructed from the given options and operation type
 */
def getNebulaOptions(options: Map[String, String],
                     operateType: OperaType.Value): NebulaOptions =
  new NebulaOptions(CaseInsensitiveMap(options))(operateType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,

prepareSpace()

/**
 * Write all rows of one Spark partition into Nebula as edges.
 *
 * Each row is encoded with `rowEncoder` and handed to `write`; once the iterator
 * is exhausted, any edges still left in the buffer are submitted with `execute()`.
 *
 * The graph and meta providers are closed in a `finally` block so that they are
 * released even when encoding, `write` or `execute` throws; the original exception
 * still propagates to the caller.
 *
 * @param iterator rows of the current partition
 * @return a NebulaCommitMessage carrying the current partition id and the
 *         statements collected in `failedExecs`
 */
override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
  try {
    iterator.foreach(row => write(rowEncoder.toRow(row)))
    // flush whatever is still buffered after the last row
    if (edges.nonEmpty) {
      execute()
    }
  } finally {
    // nested so that metaProvider is closed even if closing graphProvider fails
    try graphProvider.close()
    finally metaProvider.close()
  }
  NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}

/**
* write one edge record to buffer
*/
Expand Down Expand Up @@ -93,17 +106,4 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,
edges.clear()
submit(exec)
}

/**
 * Function-returning form of the partition writer.
 *
 * Returns a `(TaskContext, Iterator[Row]) => NebulaCommitMessage` closure rather
 * than writing directly. The closure encodes and writes every row, submits any
 * edges still buffered, closes `graphProvider`, and reports the partition id
 * together with the statements collected in `failedExecs`.
 *
 * @return the per-partition write function
 */
override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
// `context` is not read; the partition id is taken from TaskContext.getPartitionId() below
(context, iterRow) => {
while (iterRow.hasNext) {
val internalRow = rowEncoder.toRow(iterRow.next())
write(internalRow)
}
// flush whatever is still buffered after the last row
if (edges.nonEmpty) {
execute()
}
// only graphProvider is closed here, and only on the non-exceptional path
graphProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema:

prepareSpace()

override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
(context, iterRow) => {
while (iterRow.hasNext) {
val internalRow = rowEncoder.toRow(iterRow.next())
write(internalRow)
}
if (vertices.nonEmpty) {
execute()
}
graphProvider.close()
NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
/**
 * Write all rows of one Spark partition into Nebula as vertices.
 *
 * Each row is encoded with `rowEncoder` and handed to `write`; once the iterator
 * is exhausted, any vertices still left in the buffer are submitted with `execute()`.
 *
 * The graph and meta providers are closed in a `finally` block so that they are
 * released even when encoding, `write` or `execute` throws; the original exception
 * still propagates to the caller.
 *
 * @param iterator rows of the current partition
 * @return a NebulaCommitMessage carrying the current partition id and the
 *         statements collected in `failedExecs`
 */
override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
  try {
    iterator.foreach(row => write(rowEncoder.toRow(row)))
    // flush whatever is still buffered after the last row
    if (vertices.nonEmpty) {
      execute()
    }
  } finally {
    // nested so that metaProvider is closed even if closing graphProvider fails
    try graphProvider.close()
    finally metaProvider.close()
  }
  NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
}

/**
* write one vertex row to buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,7 @@ abstract class NebulaWriter(nebulaOptions: NebulaOptions, schema: StructType) ex

def write(row: InternalRow): Unit

def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage
/**
 * Write dataframe data into Nebula for one partition.
 *
 * @param iterator the rows of the partition to write
 * @return a NebulaCommitMessage produced by the concrete writer for this partition
 */
def writeData(iterator: Iterator[Row]): NebulaCommitMessage

}

0 comments on commit cfc673f

Please sign in to comment.