Skip to content

Commit

Permalink
support ssl for spark connector (#18)
Browse files Browse the repository at this point in the history
* support ssl for connector

* support ssl for writer and reader

* add test for ssl config

* add ssl example

* update copyright
  • Loading branch information
Nicole00 authored Nov 23, 2021
1 parent 3124ea9 commit 43d1158
Show file tree
Hide file tree
Showing 18 changed files with 310 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package com.vesoft.nebula.examples.connector

import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
import com.vesoft.nebula.connector.ssl.SSLSignType
import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -132,4 +133,35 @@ object NebulaSparkReaderExample {
LOG.info("edge rdd count: {}", edgeRDD.count())
}

/**
* read Nebula vertex with SSL
*/
def readVertexWithSSL(spark: SparkSession): Unit = {
LOG.info("start to read nebula vertices with ssl")
val config =
NebulaConnectionConfig
.builder()
.withMetaAddress("127.0.0.1:9559")
.withEnableMetaSSL(true)
.withEnableStorageSSL(true)
.withSSLSignType(SSLSignType.CA)
.withCaSSLSignParam("example/src/main/resources/ssl/casigned.pem",
"example/src/main/resources/ssl/casigned.crt",
"example/src/main/resources/ssl/casigned.key")
.withConenctionRetry(2)
.build()
val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
.builder()
.withSpace("test")
.withLabel("person")
.withNoColumn(false)
.withReturnCols(List("birthday"))
.withLimit(10)
.withPartitionNum(10)
.build()
val vertex = spark.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
vertex.printSchema()
vertex.show(20)
println("vertex count: " + vertex.count())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ package com.vesoft.nebula.examples.connector
import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.connector.{
NebulaConnectionConfig,
SSLSignType,
WriteMode,
WriteNebulaEdgeConfig,
WriteNebulaVertexConfig
}
import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter
import com.vesoft.nebula.connector.ssl.SSLSignType
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -63,6 +63,7 @@ object NebulaSparkWriterExample {
.withMetaAddress("127.0.0.1:9559")
.withGraphAddress("127.0.0.1:9669")
.withConenctionRetry(2)
.withEnableMetaSSL(true)
.withEnableGraphSSL(true)
.withSSLSignType(SSLSignType.CA)
.withCaSSLSignParam("example/src/main/resources/ssl/casigned.pem",
Expand All @@ -77,6 +78,7 @@ object NebulaSparkWriterExample {
.withMetaAddress("127.0.0.1:9559")
.withGraphAddress("127.0.0.1:9669")
.withConenctionRetry(2)
.withEnableMetaSSL(true)
.withEnableGraphSSL(true)
.withSSLSignType(SSLSignType.SELF)
.withSelfSSLSignParam("example/src/main/resources/ssl/selfsigned.pem",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

package com.vesoft.nebula.connector

import com.vesoft.nebula.client.graph.data.{CASignedSSLParam, SelfSignedSSLParam}
import com.vesoft.nebula.connector.NebulaConnectionConfig.ConfigBuilder
import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.ListBuffer
Expand All @@ -20,8 +19,8 @@ class NebulaConnectionConfig(metaAddress: String,
enableGraphSSL: Boolean,
enableStorageSSL: Boolean,
signType: SSLSignType.Value,
caSignParam: CASignedSSLParam,
selfSignParam: SelfSignedSSLParam)
caSignParam: CASSLSignParams,
selfSignParam: SelfSSLSignParams)
extends Serializable {
def getMetaAddress = metaAddress
def getGraphAddress = graphAddress
Expand All @@ -31,9 +30,13 @@ class NebulaConnectionConfig(metaAddress: String,
def getEnableMetaSSL = enableMetaSSL
def getEnableGraphSSL = enableGraphSSL
def getEnableStorageSSL = enableStorageSSL
def getSignType = signType
def getCaSignParam = caSignParam
def getSelfSignParam = selfSignParam
def getSignType = signType.toString
def getCaSignParam: String = {
caSignParam.caCrtFilePath + "," + caSignParam.crtFilePath + "," + caSignParam.keyFilePath
}
def getSelfSignParam: String = {
selfSignParam.crtFilePath + "," + selfSignParam.keyFilePath + "," + selfSignParam.password
}
}

object NebulaConnectionConfig {
Expand All @@ -46,12 +49,12 @@ object NebulaConnectionConfig {
protected var connectionRetry: Int = 1
protected var executeRetry: Int = 1

protected var enableMetaSSL: Boolean = false
protected var enableGraphSSL: Boolean = false
protected var enableStorageSSL: Boolean = false
protected var sslSignType: SSLSignType.Value = _
protected var caSignParam: CASignedSSLParam = null
protected var selfSignParam: SelfSignedSSLParam = null
protected var enableMetaSSL: Boolean = false
protected var enableGraphSSL: Boolean = false
protected var enableStorageSSL: Boolean = false
protected var sslSignType: SSLSignType.Value = _
protected var caSignParam: CASSLSignParams = null
protected var selfSignParam: SelfSSLSignParams = null

def withMetaAddress(metaAddress: String): ConfigBuilder = {
this.metaAddress = metaAddress
Expand Down Expand Up @@ -91,8 +94,7 @@ object NebulaConnectionConfig {
* set enableMetaSSL, enableMetaSSL is optional
*/
def withEnableMetaSSL(enableMetaSSL: Boolean): ConfigBuilder = {
LOG.warn("metaSSL is not supported yet.")
this.enableMetaSSL = false
this.enableMetaSSL = enableMetaSSL
this
}

Expand All @@ -108,8 +110,7 @@ object NebulaConnectionConfig {
* set enableStorageSSL, enableStorageSSL is optional
*/
def withEnableStorageSSL(enableStorageSSL: Boolean): ConfigBuilder = {
LOG.warn("storageSSL is not supported yet.")
this.enableStorageSSL = false
this.enableStorageSSL = enableStorageSSL
this
}

Expand All @@ -127,8 +128,7 @@ object NebulaConnectionConfig {
def withCaSSLSignParam(caCrtFilePath: String,
crtFilePath: String,
keyFilePath: String): ConfigBuilder = {
val caSignParam = new CASignedSSLParam(caCrtFilePath, crtFilePath, keyFilePath)
this.caSignParam = caSignParam
this.caSignParam = CASSLSignParams(caCrtFilePath, crtFilePath, keyFilePath)
this
}

Expand All @@ -138,8 +138,7 @@ object NebulaConnectionConfig {
def withSelfSSLSignParam(crtFilePath: String,
keyFilePath: String,
password: String): ConfigBuilder = {
val selfSignParam = new SelfSignedSSLParam(crtFilePath, keyFilePath, password)
this.selfSignParam = selfSignParam
this.selfSignParam = SelfSSLSignParams(crtFilePath, keyFilePath, password)
this
}

Expand All @@ -156,22 +155,21 @@ object NebulaConnectionConfig {
// check ssl param
if (enableMetaSSL || enableGraphSSL || enableStorageSSL) {
assert(
(enableStorageSSL && enableMetaSSL && enableGraphSSL)
|| (!enableStorageSSL && !enableMetaSSL && enableGraphSSL),
"ssl priority order: storage > meta > graph " +
"please make sure graph ssl is enable when storage and meta ssl is enable."
!enableStorageSSL || enableStorageSSL && enableMetaSSL,
"ssl priority order: storage > meta = graph " +
"please make sure meta ssl is enabled when storage ssl is enabled."
)
sslSignType match {
case SSLSignType.CA =>
assert(
caSignParam != null && caSignParam.getCaCrtFilePath != null
&& caSignParam.getCrtFilePath != null && caSignParam.getKeyFilePath != null,
caSignParam != null && caSignParam.caCrtFilePath != null
&& caSignParam.crtFilePath != null && caSignParam.keyFilePath != null,
"ssl sign type is CA, param can not be null"
)
case SSLSignType.SELF =>
assert(
selfSignParam != null && selfSignParam.getCrtFilePath != null
&& selfSignParam.getKeyFilePath != null && selfSignParam.getPassword != null,
selfSignParam != null && selfSignParam.crtFilePath != null
&& selfSignParam.keyFilePath != null && selfSignParam.password != null,
"ssl sign type is SELF, param can not be null"
)
case _ => assert(false, "SSLSignType config is null")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ object WriteMode extends Enumeration {
val UPDATE = Value("update")
val DELETE = Value("delete")
}

object SSLSignType extends Enumeration {

type signType = Value
val CA = Value("ca")
val SELF = Value("self")
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ package com.vesoft.nebula.connector
import java.util.Properties

import com.google.common.net.HostAndPort
import com.vesoft.nebula.client.graph.data.{CASignedSSLParam, SelfSignedSSLParam}
import com.vesoft.nebula.connector.connector.Address
import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams}
import org.apache.commons.lang.StringUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
Expand Down Expand Up @@ -67,19 +67,21 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
parameters.getOrElse(ENABLE_GRAPH_SSL, DEFAULT_ENABLE_GRAPH_SSL).toString.toBoolean
val enableMetaSSL: Boolean =
parameters.getOrElse(ENABLE_META_SSL, DEFAULT_ENABLE_META_SSL).toString.toBoolean
var sslSignType: String = null
var caSignParam: CASignedSSLParam = null
var selfSignParam: SelfSignedSSLParam = null
val enableStorageSSL: Boolean =
parameters.getOrElse(ENABLE_STORAGE_SSL, DEFAULT_ENABLE_STORAGE_SSL).toString.toBoolean
var sslSignType: String = _
var caSignParam: CASSLSignParams = _
var selfSignParam: SelfSSLSignParams = _
if (enableGraphSSL || enableMetaSSL) {
sslSignType = parameters.get(SSL_SIGN_TYPE).get
SSLSignType.withName(sslSignType) match {
case SSLSignType.CA => {
val params = parameters.get(CA_SIGN_PARAM).get.split(",")
caSignParam = new CASignedSSLParam(params(0), params(1), params(2))
caSignParam = new CASSLSignParams(params(0), params(1), params(2))
}
case SSLSignType.SELF => {
val params = parameters.get(SELF_SIGN_PARAM).get.split(",")
selfSignParam = new SelfSignedSSLParam(params(0), params(1), params(2))
selfSignParam = new SelfSSLSignParams(params(0), params(1), params(2))
}
}
}
Expand Down Expand Up @@ -213,17 +215,18 @@ object NebulaOptions {
val LABEL: String = "label"

/** connection config */
val TIMEOUT: String = "timeout"
val CONNECTION_RETRY: String = "connectionRetry"
val EXECUTION_RETRY: String = "executionRetry"
val RATE_TIME_OUT: String = "reteTimeOut"
val USER_NAME: String = "user"
val PASSWD: String = "passwd"
val ENABLE_GRAPH_SSL: String = "enableGraphSSL"
val ENABLE_META_SSL: String = "enableMetaSSL"
val SSL_SIGN_TYPE: String = "sslSignType"
val CA_SIGN_PARAM: String = "caSignParam"
val SELF_SIGN_PARAM: String = "selfSignParam"
val TIMEOUT: String = "timeout"
val CONNECTION_RETRY: String = "connectionRetry"
val EXECUTION_RETRY: String = "executionRetry"
val RATE_TIME_OUT: String = "reteTimeOut"
val USER_NAME: String = "user"
val PASSWD: String = "passwd"
val ENABLE_GRAPH_SSL: String = "enableGraphSSL"
val ENABLE_META_SSL: String = "enableMetaSSL"
val ENABLE_STORAGE_SSL: String = "enableStorageSSL"
val SSL_SIGN_TYPE: String = "sslSignType"
val CA_SIGN_PARAM: String = "caSignParam"
val SELF_SIGN_PARAM: String = "selfSignParam"

/** read config */
val RETURN_COLS: String = "returnCols"
Expand Down Expand Up @@ -254,8 +257,9 @@ object NebulaOptions {
val DEFAULT_USER_NAME: String = "root"
val DEFAULT_PASSWD: String = "nebula"

val DEFAULT_ENABLE_GRAPH_SSL: Boolean = false
val DEFAULT_ENABLE_META_SSL: Boolean = false
val DEFAULT_ENABLE_GRAPH_SSL: Boolean = false
val DEFAULT_ENABLE_META_SSL: Boolean = false
val DEFAULT_ENABLE_STORAGE_SSL: Boolean = false

val DEFAULT_LIMIT: Int = 1000

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import com.vesoft.nebula.client.graph.data.{
SelfSignedSSLParam
}
import com.vesoft.nebula.client.graph.net.{NebulaPool, Session}
import com.vesoft.nebula.connector.SSLSignType
import com.vesoft.nebula.connector.connector.Address
import com.vesoft.nebula.connector.exception.GraphConnectException
import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams}
import org.apache.log4j.Logger

import scala.collection.JavaConverters._
Expand All @@ -25,10 +25,11 @@ import scala.collection.mutable.ListBuffer
* GraphProvider for Nebula Graph Service
*/
class GraphProvider(addresses: List[Address],
timeout: Int,
enableSSL: Boolean = false,
sslSignType: String = null,
caSignParam: CASignedSSLParam = null,
selfSignParam: SelfSignedSSLParam = null)
caSignParam: CASSLSignParams = null,
selfSignParam: SelfSSLSignParams = null)
extends AutoCloseable
with Serializable {
private[this] lazy val LOG = Logger.getLogger(this.getClass)
Expand All @@ -41,13 +42,22 @@ class GraphProvider(addresses: List[Address],
address.append(new HostAddress(addr._1, addr._2))
}
nebulaPoolConfig.setMaxConnSize(1)
nebulaPoolConfig.setTimeout(timeout)

if (enableSSL) {
nebulaPoolConfig.setEnableSsl(enableSSL)
SSLSignType.withName(sslSignType) match {
case SSLSignType.CA => nebulaPoolConfig.setSslParam(caSignParam)
case SSLSignType.SELF => nebulaPoolConfig.setSslParam(selfSignParam)
case _ => throw new IllegalArgumentException("ssl sign type is not supported")
case SSLSignType.CA =>
nebulaPoolConfig.setSslParam(
new CASignedSSLParam(caSignParam.caCrtFilePath,
caSignParam.crtFilePath,
caSignParam.keyFilePath))
case SSLSignType.SELF =>
nebulaPoolConfig.setSslParam(
new SelfSignedSSLParam(selfSignParam.crtFilePath,
selfSignParam.keyFilePath,
selfSignParam.password))
case _ => throw new IllegalArgumentException("ssl sign type is not supported")
}
}
pool.init(address.asJava, nebulaPoolConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

package com.vesoft.nebula.connector.nebula

import com.vesoft.nebula.client.graph.data.HostAddress
import com.vesoft.nebula.client.graph.data.{
CASignedSSLParam,
HostAddress,
SSLParam,
SelfSignedSSLParam
}
import com.vesoft.nebula.client.meta.MetaClient
import com.vesoft.nebula.connector.connector.Address
import com.vesoft.nebula.connector.DataTypeEnum
import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams}
import com.vesoft.nebula.meta.{PropertyType, Schema}

import scala.collection.JavaConverters._
Expand All @@ -17,11 +23,32 @@ import scala.collection.mutable
class MetaProvider(addresses: List[Address],
timeout: Int,
connectionRetry: Int,
executionRetry: Int)
executionRetry: Int,
enableSSL: Boolean,
sslSignType: String = null,
caSignParam: CASSLSignParams,
selfSignParam: SelfSSLSignParams)
extends AutoCloseable {

val metaAddress = addresses.map(address => new HostAddress(address._1, address._2)).asJava
val client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry)
val metaAddress = addresses.map(address => new HostAddress(address._1, address._2)).asJava
var client: MetaClient = null
var sslParam: SSLParam = null
if (enableSSL) {
SSLSignType.withName(sslSignType) match {
case SSLSignType.CA =>
sslParam = new CASignedSSLParam(caSignParam.caCrtFilePath,
caSignParam.crtFilePath,
caSignParam.keyFilePath)
case SSLSignType.SELF =>
sslParam = new SelfSignedSSLParam(selfSignParam.crtFilePath,
selfSignParam.keyFilePath,
selfSignParam.password)
case _ => throw new IllegalArgumentException("ssl sign type is not supported")
}
client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry, true, sslParam)
} else {
client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry)
}
client.connect()

def getPartitionNumber(space: String): Int = {
Expand Down
Loading

0 comments on commit 43d1158

Please sign in to comment.