diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
index b54cf954..aa4078f3 100644
--- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
+++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
@@ -164,4 +164,25 @@ object NebulaSparkReaderExample {
     vertex.show(20)
     println("vertex count: " + vertex.count())
   }
+
+  def readEdgeWithNgql(spark: SparkSession): Unit = {
+    LOG.info("start to read nebula edge with ngql")
+    val config =
+      NebulaConnectionConfig
+        .builder()
+        .withMetaAddress("127.0.0.1:9559")
+        .withGraphAddress("127.0.0.1:9669")
+        .withConenctionRetry(2)
+        .build()
+    val nebulaReadConfig: ReadNebulaConfig = ReadNebulaConfig
+      .builder()
+      .withSpace("test")
+      .withLabel("friend")
+      .withNgql("match (v)-[e:friend]-(v2) return e")
+      .build()
+    val edge = spark.read.nebula(config, nebulaReadConfig).loadEdgesToDfByNgql()
+    edge.printSchema()
+    edge.show(20)
+    println("edge count: " + edge.count())
+  }
 }
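
The example above loads the matched edges into an ordinary DataFrame, so anything Spark SQL offers applies to the result. A minimal, hypothetical continuation (the temp-view name is illustrative, and the `_srcId`/`_dstId`/`_rank` columns are the connector's usual edge schema — verify against `edge.printSchema()`):

```scala
// continuation of readEdgeWithNgql: query the loaded edges with Spark SQL
edge.createOrReplaceTempView("friend_edges") // view name is illustrative
spark.sql("SELECT COUNT(*) FROM friend_edges WHERE _rank = 0").show()
```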
diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala
index 9e8a0543..f6fa9629 100644
--- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala
+++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala
@@ -670,20 +670,45 @@ object WriteNebulaEdgeConfig {
  * you can set noColumn to true to read no vertex col, and you can set returnCols to read the specific cols,
  * if the returnCols is empty, then read all the columns.
  * you can set partitionNum to define spark partition nums to read nebula graph.
  */
-class ReadNebulaConfig(space: String,
-                       label: String,
-                       returnCols: List[String],
-                       noColumn: Boolean,
-                       partitionNum: Int,
-                       limit: Int)
-    extends Serializable {
-  def getSpace        = space
-  def getLabel        = label
-  def getReturnCols   = returnCols
-  def getNoColumn     = noColumn
-  def getPartitionNum = partitionNum
-  def getLimit        = limit
+class ReadNebulaConfig extends Serializable {
+  var getSpace: String = _
+  var getLabel: String = _
+  var getReturnCols: List[String] = _
+  var getNoColumn: Boolean = _
+  var getPartitionNum: Int = _
+  var getLimit: Int = _
+  var getNgql: String = _ // TODO: add filter support
+  def this(space: String,
+           label: String,
+           returnCols: List[String],
+           noColumn: Boolean,
+           partitionNum: Int,
+           limit: Int) = {
+    this()
+    this.getSpace = space
+    this.getLabel = label
+    this.getReturnCols = returnCols
+    this.getNoColumn = noColumn
+    this.getPartitionNum = partitionNum
+    this.getLimit = limit
+  }
+
+  def this(space: String,
+           label: String,
+           returnCols: List[String],
+           noColumn: Boolean,
+           ngql: String,
+           limit: Int) = {
+    this()
+    this.getNgql = ngql
+    this.getSpace = space
+    this.getLabel = label
+    this.getReturnCols = returnCols
+    this.getNoColumn = noColumn
+    this.getLimit = limit
+    this.getPartitionNum = 1
+  }
 }
 
 /**
@@ -699,6 +724,7 @@ object ReadNebulaConfig {
     var noColumn: Boolean = false
     var partitionNum: Int = 100
     var limit: Int = 1000
+    var ngql: String = _
 
     def withSpace(space: String): ReadConfigBuilder = {
       this.space = space
@@ -740,9 +766,18 @@ object ReadNebulaConfig {
       this
     }
 
+    def withNgql(ngql: String): ReadConfigBuilder = {
+      this.ngql = ngql
+      this
+    }
+
     def build(): ReadNebulaConfig = {
       check()
-      new ReadNebulaConfig(space, label, returnCols.toList, noColumn, partitionNum, limit)
+      if (ngql != null && !ngql.isEmpty) {
+        new ReadNebulaConfig(space, label, returnCols.toList, noColumn, ngql, limit)
+      } else {
+        new ReadNebulaConfig(space, label, returnCols.toList, noColumn, partitionNum, limit)
+      }
     }
 
     private def check(): Unit = {
diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
index a2026d97..21195c5c 100644
--- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
+++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
@@ -108,11 +108,18 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
   var partitionNums: String = _
   var noColumn: Boolean = _
   var limit: Int = _
+  var ngql: String = _
   if (operaType == OperaType.READ) {
     returnCols = parameters(RETURN_COLS)
     noColumn = parameters.getOrElse(NO_COLUMN, false).toString.toBoolean
     partitionNums = parameters(PARTITION_NUMBER)
     limit = parameters.getOrElse(LIMIT, DEFAULT_LIMIT).toString.toInt
+    ngql = parameters.getOrElse(NGQL, EMPTY_STRING)
+    if (ngql != EMPTY_STRING) {
+      require(parameters.isDefinedAt(GRAPH_ADDRESS),
+              s"option $GRAPH_ADDRESS is required for ngql and cannot be blank")
+      graphAddress = parameters(GRAPH_ADDRESS)
+    }
   }
 
   /** write parameters */
@@ -235,6 +242,9 @@ object NebulaOptions {
   val PARTITION_NUMBER: String = "partitionNumber"
   val LIMIT: String = "limit"
 
+  /** read by ngql */
+  val NGQL: String = "ngql"
+
   /** write config */
   val RATE_LIMIT: String = "rateLimit"
   val VID_POLICY: String = "vidPolicy"
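
Note the `build()` dispatch just added above: once `withNgql` is set, the ngql-specific constructor is selected and `getPartitionNum` is pinned to 1, regardless of any partition setting on the builder. A small sketch of the observable behavior (assuming the builder's usual `withPartitionNum` setter, which is outside this hunk):

```scala
val ngqlConfig: ReadNebulaConfig = ReadNebulaConfig
  .builder()
  .withSpace("test")
  .withLabel("friend")
  .withNgql("match (v)-[e:friend]-(v2) return e")
  .withPartitionNum(100) // ignored on the ngql path
  .build()
// the ngql constructor forces a single partition
assert(ngqlConfig.getPartitionNum == 1)
```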
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
index 307441de..c298a890 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
@@ -9,7 +9,11 @@ import java.util.Map.Entry
 import java.util.Optional
 
 import com.vesoft.nebula.connector.exception.IllegalOptionException
-import com.vesoft.nebula.connector.reader.{NebulaDataSourceEdgeReader, NebulaDataSourceVertexReader}
+import com.vesoft.nebula.connector.reader.{
+  NebulaDataSourceEdgeReader,
+  NebulaDataSourceNgqlEdgeReader,
+  NebulaDataSourceVertexReader
+}
 import com.vesoft.nebula.connector.writer.{NebulaDataSourceEdgeWriter, NebulaDataSourceVertexWriter}
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -46,6 +50,8 @@ class NebulaDataSource
 
     if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
       new NebulaDataSourceVertexReader(nebulaOptions)
+    } else if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) {
+      new NebulaDataSourceNgqlEdgeReader(nebulaOptions)
     } else {
       new NebulaDataSourceEdgeReader(nebulaOptions)
     }
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
index 78a6b5ed..e9f122c1 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
@@ -110,6 +110,45 @@ package object connector {
       dfReader.load()
     }
 
+    /**
+      * Reading edges from Nebula Graph by ngql
+      * @return DataFrame
+      */
+    def loadEdgesToDfByNgql(): DataFrame = {
+      assert(connectionConfig != null && readConfig != null,
+             "nebula config is not set, please call nebula() before loadEdgesToDfByNgql")
+
+      val dfReader = reader
+        .format(classOf[NebulaDataSource].getName)
+        .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+        .option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
+        .option(NebulaOptions.LABEL, readConfig.getLabel)
+        .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
+        .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn)
+        .option(NebulaOptions.LIMIT, readConfig.getLimit)
+        .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
+        .option(NebulaOptions.NGQL, readConfig.getNgql)
+        .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+        .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress)
+        .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+        .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+        .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+        .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+        .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL)
+
+      if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) {
+        dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+        SSLSignType.withName(connectionConfig.getSignType) match {
+          case SSLSignType.CA =>
+            dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+          case SSLSignType.SELF =>
+            dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+        }
+      }
+
+      dfReader.load()
+    }
+
     /**
       * read nebula vertex edge to graphx's vertex
       * use hash() for String type vertex id.
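
`loadEdgesToDfByNgql` is a convenience wrapper over the DataSource V2 options assembled above. A hedged sketch of the equivalent direct invocation — it omits the timeout/retry/SSL options the wrapper also forwards, and `returnCols`/`partitionNumber` are passed explicitly because `NebulaOptions` reads those keys unconditionally (addresses and the statement reuse the example's values; other mandatory keys may apply):

```scala
// direct DataSource V2 read; a non-empty NebulaOptions.ngql routes this
// request to NebulaDataSourceNgqlEdgeReader in createReader
val df = spark.read
  .format(classOf[NebulaDataSource].getName)
  .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
  .option(NebulaOptions.SPACE_NAME, "test")
  .option(NebulaOptions.LABEL, "friend")
  .option(NebulaOptions.RETURN_COLS, "")       // mandatory key; empty reads all columns
  .option(NebulaOptions.PARTITION_NUMBER, "1") // mandatory key; ngql reads use one partition anyway
  .option(NebulaOptions.NGQL, "match (v)-[e:friend]-(v2) return e")
  .option(NebulaOptions.META_ADDRESS, "127.0.0.1:9559")
  .option(NebulaOptions.GRAPH_ADDRESS, "127.0.0.1:9669") // required whenever ngql is set
  .load()
```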
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgePartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgePartitionReader.scala
new file mode 100644
index 00000000..c6da1078
--- /dev/null
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgePartitionReader.scala
@@ -0,0 +1,142 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.Value
+import com.vesoft.nebula.client.graph.data.{Relationship, ResultSet, ValueWrapper}
+import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter
+import com.vesoft.nebula.connector.nebula.GraphProvider
+import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.JavaConversions.asScalaBuffer
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+/**
+  * create reader by ngql
+  */
+class NebulaNgqlEdgePartitionReader extends InputPartitionReader[InternalRow] {
+
+  private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
+
+  private var nebulaOptions: NebulaOptions = _
+  private var graphProvider: GraphProvider = _
+  private var schema: StructType = _
+  private var resultSet: ResultSet = _
+  private var edgeIterator: Iterator[ListBuffer[ValueWrapper]] = _
+
+  def this(nebulaOptions: NebulaOptions, schema: StructType) {
+    this()
+    this.schema = schema
+    this.nebulaOptions = nebulaOptions
+    this.graphProvider = new GraphProvider(
+      nebulaOptions.getGraphAddress,
+      nebulaOptions.timeout,
+      nebulaOptions.enableGraphSSL,
+      nebulaOptions.sslSignType,
+      nebulaOptions.caSignParam,
+      nebulaOptions.selfSignParam
+    )
+    // TODO: throw an explicit exception when the session fails to build
+    graphProvider.switchSpace(nebulaOptions.user, nebulaOptions.passwd, nebulaOptions.spaceName)
+    resultSet = graphProvider.submit(nebulaOptions.ngql)
+    edgeIterator = query()
+  }
+
+  def query(): Iterator[ListBuffer[ValueWrapper]] = {
+    val edges: ListBuffer[ListBuffer[ValueWrapper]] = new ListBuffer[ListBuffer[ValueWrapper]]
+    val properties = nebulaOptions.getReturnCols
+    for (i <- 0 until resultSet.rowsSize()) {
+      val rowValues = resultSet.rowValues(i).values()
+      for (j <- 0 until rowValues.size()) {
+        val value = rowValues.get(j)
+        val valueType = value.getValue.getSetField
+        if (valueType == Value.EVAL) {
+          val relationship = value.asRelationship()
+          if (checkLabel(relationship)) {
+            edges.append(convertToEdge(relationship, properties))
+          }
+        } else if (valueType == Value.LVAL) {
+          val list: mutable.Buffer[ValueWrapper] = value.asList()
+          edges.appendAll(
+            list.toStream
+              .filter(e => checkLabel(e.asRelationship()))
+              .map(e => convertToEdge(e.asRelationship(), properties))
+          )
+        } else {
+          LOG.error(s"unexpected value type ${valueType} when reading edges")
+          throw new RuntimeException("convert value type failed")
+        }
+      }
+    }
+    edges.iterator
+  }
+
+  def checkLabel(relationship: Relationship): Boolean = {
+    this.nebulaOptions.label.equals(relationship.edgeName())
+  }
+
+  def convertToEdge(relationship: Relationship,
+                    properties: List[String]): ListBuffer[ValueWrapper] = {
+    val edge: ListBuffer[ValueWrapper] = new ListBuffer[ValueWrapper]
+    edge.append(relationship.srcId())
+    edge.append(relationship.dstId())
+    // field id 3 of the Value union is ival; it carries the edge ranking here
+    edge.append(new ValueWrapper(new Value(3, relationship.ranking()), "utf-8"))
+    if (properties == null || properties.isEmpty) {
+      return edge
+    }
+    for (i <- properties.indices) {
+      edge.append(relationship.properties().get(properties(i)))
+    }
+    edge
+  }
+
+  override def next(): Boolean = {
+    edgeIterator.hasNext
+  }
+
+  override def get(): InternalRow = {
+    val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema)
+    val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+
+    val edge = edgeIterator.next()
+    for (i <- getters.indices) {
+      val value: ValueWrapper = edge(i)
+      if (value.isNull) {
+        mutableRow.setNullAt(i)
+      } else if (value.isString) {
+        getters(i).apply(value.asString(), mutableRow, i)
+      } else if (value.isDate) {
+        getters(i).apply(value.asDate(), mutableRow, i)
+      } else if (value.isTime) {
+        getters(i).apply(value.asTime(), mutableRow, i)
+      } else if (value.isDateTime) {
+        getters(i).apply(value.asDateTime(), mutableRow, i)
+      } else if (value.isLong) {
+        getters(i).apply(value.asLong(), mutableRow, i)
+      } else if (value.isBoolean) {
+        getters(i).apply(value.asBoolean(), mutableRow, i)
+      } else if (value.isDouble) {
+        getters(i).apply(value.asDouble(), mutableRow, i)
+      } else if (value.isGeography) {
+        getters(i).apply(value.asGeography(), mutableRow, i)
+      } else if (value.isDuration) {
+        getters(i).apply(value.asDuration(), mutableRow, i)
+      }
+    }
+    mutableRow
+  }
+
+  override def close(): Unit = {
+    graphProvider.close()
+  }
+}
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala
index b702be89..2b48794b 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala
@@ -21,3 +21,9 @@ class NebulaEdgePartition(index: Int, nebulaOptions: NebulaOptions, schema: Stru
   override def createPartitionReader(): InputPartitionReader[InternalRow] =
     new NebulaEdgePartitionReader(index, nebulaOptions, schema)
 }
+
+class NebulaNgqlEdgePartition(nebulaOptions: NebulaOptions, schema: StructType)
+    extends InputPartition[InternalRow] {
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
+    new NebulaNgqlEdgePartitionReader(nebulaOptions, schema)
+}
\ No newline at end of file
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
index ff2a43f2..0ca8ee54 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
@@ -131,3 +131,16 @@ class NebulaDataSourceEdgeReader(nebulaOptions: NebulaOptions)
     partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava
   }
 }
+
+/**
+  * DataSourceReader for Nebula Edge by ngql
+  */
+class NebulaDataSourceNgqlEdgeReader(nebulaOptions: NebulaOptions)
+    extends NebulaSourceReader(nebulaOptions) {
+
+  override def planInputPartitions(): util.List[InputPartition[InternalRow]] = {
+    val partitions = new util.ArrayList[InputPartition[InternalRow]]()
+    partitions.add(new NebulaNgqlEdgePartition(nebulaOptions, getSchema))
+    partitions
+  }
+}
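
`planInputPartitions` above deliberately plans exactly one partition: the statement is executed once against graphd and the whole `ResultSet` is parsed by a single reader. If downstream processing needs parallelism, repartition after the load — a minimal sketch (the partition count is illustrative):

```scala
// the ngql reader yields a single Spark partition; spread rows before heavy work
val edges = spark.read.nebula(config, nebulaReadConfig).loadEdgesToDfByNgql()
val parallel = edges.repartition(32) // choose a count that fits the workload
println(parallel.rdd.getNumPartitions) // 32
```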
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
index e840a596..21c59f77 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
@@ -173,6 +173,45 @@ package object connector {
       dfReader.load()
     }
 
+    /**
+      * Reading edges from Nebula Graph by ngql
+      * @return DataFrame
+      */
+    def loadEdgesToDfByNgql(): DataFrame = {
+      assert(connectionConfig != null && readConfig != null,
+             "nebula config is not set, please call nebula() before loadEdgesToDfByNgql")
+
+      val dfReader = reader
+        .format(classOf[NebulaDataSource].getName)
+        .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+        .option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
+        .option(NebulaOptions.LABEL, readConfig.getLabel)
+        .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
+        .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn)
+        .option(NebulaOptions.LIMIT, readConfig.getLimit)
+        .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
+        .option(NebulaOptions.NGQL, readConfig.getNgql)
+        .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+        .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress)
+        .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+        .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+        .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+        .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+        .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL)
+
+      if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) {
+        dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+        SSLSignType.withName(connectionConfig.getSignType) match {
+          case SSLSignType.CA =>
+            dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+          case SSLSignType.SELF =>
+            dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+        }
+      }
+
+      dfReader.load()
+    }
+
     /**
       * read nebula vertex edge to graphx's vertex
       * use hash() for String type vertex id.
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgeReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgeReader.scala
new file mode 100644
index 00000000..3e6a1aad
--- /dev/null
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgeReader.scala
@@ -0,0 +1,143 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.Value
+import com.vesoft.nebula.client.graph.data.{Relationship, ResultSet, ValueWrapper}
+import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils}
+import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter
+import com.vesoft.nebula.connector.nebula.GraphProvider
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.JavaConversions.asScalaBuffer
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+/**
+  * create reader by ngql
+  */
+class NebulaNgqlEdgeReader extends Iterator[InternalRow] {
+
+  private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
+
+  private var nebulaOptions: NebulaOptions = _
+  private var graphProvider: GraphProvider = _
+  private var schema: StructType = _
+  private var resultSet: ResultSet = _
+  private var edgeIterator: Iterator[ListBuffer[ValueWrapper]] = _
+
+  def this(nebulaOptions: NebulaOptions, schema: StructType) {
+    this()
+    this.schema = schema
+    this.nebulaOptions = nebulaOptions
+    this.graphProvider = new GraphProvider(
+      nebulaOptions.getGraphAddress,
+      nebulaOptions.timeout,
+      nebulaOptions.enableGraphSSL,
+      nebulaOptions.sslSignType,
+      nebulaOptions.caSignParam,
+      nebulaOptions.selfSignParam
+    )
+    // TODO: throw an explicit exception when the session fails to build
+    graphProvider.switchSpace(nebulaOptions.user, nebulaOptions.passwd, nebulaOptions.spaceName)
+    resultSet = graphProvider.submit(nebulaOptions.ngql)
+    // the ResultSet is fully materialized on the client, so the session can be closed eagerly
+    close()
+    edgeIterator = query()
+  }
+
+  def query(): Iterator[ListBuffer[ValueWrapper]] = {
+    val edges: ListBuffer[ListBuffer[ValueWrapper]] = new ListBuffer[ListBuffer[ValueWrapper]]
+    val properties = nebulaOptions.getReturnCols
+    for (i <- 0 until resultSet.rowsSize()) {
+      val rowValues = resultSet.rowValues(i).values()
+      for (j <- 0 until rowValues.size()) {
+        val value = rowValues.get(j)
+        val valueType = value.getValue.getSetField
+        if (valueType == Value.EVAL) {
+          val relationship = value.asRelationship()
+          if (checkLabel(relationship)) {
+            edges.append(convertToEdge(relationship, properties))
+          }
+        } else if (valueType == Value.LVAL) {
+          val list: mutable.Buffer[ValueWrapper] = value.asList()
+          edges.appendAll(
+            list.toStream
+              .filter(e => checkLabel(e.asRelationship()))
+              .map(e => convertToEdge(e.asRelationship(), properties))
+          )
+        } else {
+          LOG.error(s"unexpected value type ${valueType} when reading edges")
+          throw new RuntimeException("convert value type failed")
+        }
+      }
+    }
+    edges.iterator
+  }
+
+  def checkLabel(relationship: Relationship): Boolean = {
+    this.nebulaOptions.label.equals(relationship.edgeName())
+  }
+
+  def convertToEdge(relationship: Relationship,
+                    properties: List[String]): ListBuffer[ValueWrapper] = {
+    val edge: ListBuffer[ValueWrapper] = new ListBuffer[ValueWrapper]
+    edge.append(relationship.srcId())
+    edge.append(relationship.dstId())
+    // field id 3 of the Value union is ival; it carries the edge ranking here
+    edge.append(new ValueWrapper(new Value(3, relationship.ranking()), "utf-8"))
+    if (properties == null || properties.isEmpty) {
+      return edge
+    }
+    for (i <- properties.indices) {
+      edge.append(relationship.properties().get(properties(i)))
+    }
+    edge
+  }
+
+  override def hasNext: Boolean = {
+    edgeIterator.hasNext
+  }
+
+  override def next(): InternalRow = {
+    val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema)
+    val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+
+    val edge = edgeIterator.next()
+    for (i <- getters.indices) {
+      val value: ValueWrapper = edge(i)
+      if (value.isNull) {
+        mutableRow.setNullAt(i)
+      } else if (value.isString) {
+        getters(i).apply(value.asString(), mutableRow, i)
+      } else if (value.isDate) {
+        getters(i).apply(value.asDate(), mutableRow, i)
+      } else if (value.isTime) {
+        getters(i).apply(value.asTime(), mutableRow, i)
+      } else if (value.isDateTime) {
+        getters(i).apply(value.asDateTime(), mutableRow, i)
+      } else if (value.isLong) {
+        getters(i).apply(value.asLong(), mutableRow, i)
+      } else if (value.isBoolean) {
+        getters(i).apply(value.asBoolean(), mutableRow, i)
+      } else if (value.isDouble) {
+        getters(i).apply(value.asDouble(), mutableRow, i)
+      } else if (value.isGeography) {
+        getters(i).apply(value.asGeography(), mutableRow, i)
+      } else if (value.isDuration) {
+        getters(i).apply(value.asDuration(), mutableRow, i)
+      }
+    }
+    mutableRow
+  }
+
+  def close(): Unit = {
+    graphProvider.close()
+  }
+}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlRDD.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlRDD.scala
new file mode 100644
index 00000000..20baabb1
--- /dev/null
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlRDD.scala
@@ -0,0 +1,44 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+
+class NebulaNgqlRDD(val sqlContext: SQLContext,
+                    var nebulaOptions: NebulaOptions,
+                    schema: StructType)
+    extends RDD[InternalRow](sqlContext.sparkContext, Nil) {
+
+  /**
+    * read edge data from the ngql query's resultSet
+    *
+    * @param split
+    * @param context
+    * @return Iterator[InternalRow]
+    */
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    new NebulaNgqlEdgeReader(nebulaOptions, schema)
+  }
+
+  override def getPartitions: Array[Partition] = {
+    val partitions = new Array[Partition](1)
+    partitions(0) = NebulaNgqlPartition(0)
+    partitions
+  }
+
+}
+
+/**
+  * An identifier for a partition in a NebulaNgqlRDD.
+  */
+case class NebulaNgqlPartition(indexNum: Int) extends Partition {
+  override def index: Int = indexNum
+}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
index f499da41..53733061 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
@@ -94,6 +94,10 @@ case class NebulaRelation(override val sqlContext: SQLContext, nebulaOptions: Ne
   }
 
   override def buildScan(): RDD[Row] = {
-    new NebulaRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]]
+    if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) {
+      new NebulaNgqlRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]]
+    } else {
+      new NebulaRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]]
+    }
   }
 }
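
In the Spark 2.2 module the dispatch lives in `NebulaRelation.buildScan` rather than in a `DataSourceReader`, but the entry point is the same `loadEdgesToDfByNgql`. A hedged sketch contrasting the two read paths, assuming the connector's existing `loadEdgesToDf` scan entry point (`scanReadConfig` and `ngqlReadConfig` are hypothetical names; counts only agree when the ngql statement matches every edge the storage scan returns):

```scala
// storage-scan path: partitioned scan through storaged
val scanned = spark.read.nebula(config, scanReadConfig).loadEdgesToDf()
// ngql path: single-partition query through graphd
val queried = spark.read.nebula(config, ngqlReadConfig).loadEdgesToDfByNgql()
println(s"scan count=${scanned.count()}, ngql count=${queried.count()}")
```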