diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
index aae3c939..8d978501 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
@@ -23,12 +23,14 @@ import org.apache.spark.sql.sources.{
   DataSourceRegister,
   RelationProvider
 }
+import org.apache.spark.sql.types.StructType
 import org.slf4j.LoggerFactory
 
 class NebulaDataSource
     extends RelationProvider
     with CreatableRelationProvider
-    with DataSourceRegister {
+    with DataSourceRegister
+    with Serializable {
   private val LOG = LoggerFactory.getLogger(this.getClass)
 
   /**
@@ -58,7 +60,6 @@ class NebulaDataSource
                               data: DataFrame): BaseRelation = {
     val nebulaOptions = getNebulaOptions(parameters, OperaType.WRITE)
 
-    val dataType = nebulaOptions.dataType
     if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
       LOG.warn(s"Currently do not support mode")
     }
@@ -67,7 +68,27 @@ class NebulaDataSource
     LOG.info(s"options ${parameters}")
 
     val schema = data.schema
-    val writer: NebulaWriter =
+    data.foreachPartition(iterator => {
+      savePartition(nebulaOptions, schema, iterator)
+    })
+
+    new NebulaWriterResultRelation(sqlContext, data.schema)
+  }
+
+  /**
+    * construct nebula options with DataSourceOptions
+    */
+  def getNebulaOptions(options: Map[String, String],
+                       operateType: OperaType.Value): NebulaOptions = {
+    val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType)
+    nebulaOptions
+  }
+
+  private def savePartition(nebulaOptions: NebulaOptions,
+                            schema: StructType,
+                            iterator: Iterator[Row]): Unit = {
+    val dataType = nebulaOptions.dataType
+    val writer: NebulaWriter = {
       if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
         val vertexFiled = nebulaOptions.vertexField
         val vertexIndex: Int = {
@@ -128,28 +149,15 @@ class NebulaDataSource
           edgeFieldsIndex._3,
           schema).asInstanceOf[NebulaWriter]
       }
-
-    val wc: (TaskContext, Iterator[Row]) => NebulaCommitMessage = writer.writeData()
-    val rdd = data.rdd
-    val commitMessages = sqlContext.sparkContext.runJob(rdd, wc)
-
-    LOG.info(s"runJob finished...${commitMessages.length}")
-    for (msg <- commitMessages) {
-      if (msg.executeStatements.nonEmpty) {
-        LOG.error(s"failed execs:\n ${msg.executeStatements.toString()}")
-      } else {
-        LOG.info(s"execs for spark partition ${msg.partitionId} all succeed")
-      }
     }
-    new NebulaWriterResultRelation(sqlContext, data.schema)
-  }
+    val message = writer.writeData(iterator)
+    LOG.debug(
+      s"spark partition id ${message.partitionId} write failed size: ${message.executeStatements.length}")
+    if (message.executeStatements.nonEmpty) {
+      LOG.error(s"failed execs:\n ${message.executeStatements.toString()}")
+    } else {
+      LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed")
+    }
 
-  /**
-    * construct nebula options with DataSourceOptions
-    */
-  def getNebulaOptions(options: Map[String, String],
-                       operateType: OperaType.Value): NebulaOptions = {
-    val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType)
-    nebulaOptions
   }
 }
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala
index 8d75431b..a849f5b1 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala
@@ -48,6 +48,19 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,
 
   prepareSpace()
 
+  override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
+    while (iterator.hasNext) {
+      val internalRow = rowEncoder.toRow(iterator.next())
+      write(internalRow)
+    }
+    if (edges.nonEmpty) {
+      execute()
+    }
+    graphProvider.close()
+    metaProvider.close()
+    NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
+  }
+
   /**
    * write one edge record to buffer
    */
@@ -93,17 +106,4 @@ class NebulaEdgeWriter(nebulaOptions: NebulaOptions,
     edges.clear()
     submit(exec)
   }
-
-  override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
-    (context, iterRow) => {
-      while (iterRow.hasNext) {
-        val internalRow = rowEncoder.toRow(iterRow.next())
-        write(internalRow)
-      }
-      if (edges.nonEmpty) {
-        execute()
-      }
-      graphProvider.close()
-      NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
-    }
 }
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala
index 838ac198..28c009a8 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala
@@ -40,18 +40,18 @@ class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema:
 
   prepareSpace()
 
-  override def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage =
-    (context, iterRow) => {
-      while (iterRow.hasNext) {
-        val internalRow = rowEncoder.toRow(iterRow.next())
-        write(internalRow)
-      }
-      if (vertices.nonEmpty) {
-        execute()
-      }
-      graphProvider.close()
-      NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
+  override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = {
+    while (iterator.hasNext) {
+      val internalRow = rowEncoder.toRow(iterator.next())
+      write(internalRow)
     }
+    if (vertices.nonEmpty) {
+      execute()
+    }
+    graphProvider.close()
+    metaProvider.close()
+    NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList)
+  }
 
   /**
    * write one vertex row to buffer
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala
index 764f4f50..ca3f3725 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala
@@ -68,5 +68,7 @@ abstract class NebulaWriter(nebulaOptions: NebulaOptions, schema: StructType) ex
 
   def write(row: InternalRow): Unit
 
-  def writeData(): (TaskContext, Iterator[Row]) => NebulaCommitMessage
+  /** write dataframe data into nebula for each partition */
+  def writeData(iterator: Iterator[Row]): NebulaCommitMessage
+
 }
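
The core of this change is moving the write from a driver-side `sqlContext.sparkContext.runJob(rdd, wc)` to an executor-side `data.foreachPartition`, where each partition builds its own writer, flushes its buffer, closes its `graphProvider`/`metaProvider` connections, and reports a `NebulaCommitMessage`. Because the partition closure now captures `NebulaDataSource` state, the class must mix in `Serializable`. The sketch below illustrates only the per-partition pattern, not the connector's real API: `SketchCommitMessage`, `savePartition`, and the batch size are hypothetical stand-ins.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}

// Hypothetical stand-in for the connector's NebulaCommitMessage (illustration only).
final case class SketchCommitMessage(partitionId: Int, failedExecs: List[String])

object PartitionWriteSketch {

  // Mirrors the shape of the PR's savePartition: the whole body runs on an
  // executor, so everything the closure captures must be serializable --
  // the reason NebulaDataSource now extends Serializable.
  def savePartition(iterator: Iterator[Row]): SketchCommitMessage = {
    val failed = scala.collection.mutable.ListBuffer.empty[String]
    val buffer = scala.collection.mutable.ListBuffer.empty[Row]
    while (iterator.hasNext) {
      buffer += iterator.next()              // stand-in for write(internalRow)
      if (buffer.size >= 512) buffer.clear() // stand-in for execute() flushing a batch
    }
    if (buffer.nonEmpty) buffer.clear()      // final flush, as in `if (edges.nonEmpty) execute()`
    SketchCommitMessage(TaskContext.getPartitionId(), failed.toList)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("partition-write-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")

    // Executor-side write, replacing the old driver-side
    // runJob(rdd, wc) plus the commitMessages loop on the driver.
    df.foreachPartition { iterator: Iterator[Row] =>
      val msg = savePartition(iterator)
      if (msg.failedExecs.nonEmpty)
        println(s"partition ${msg.partitionId} failed size: ${msg.failedExecs.length}")
    }
    spark.stop()
  }
}
```

One consequence of this design, visible in the writer diffs: failure reporting becomes partition-local (logged where the partition ran) instead of aggregated on the driver, and each `writeData(iterator)` must now close its own providers, since no driver-side code runs after the job to clean them up.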