Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ssl for metaClient and storageClient #379

Merged
merged 4 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/maven.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ jobs:
cp ../../client/src/test/resources/docker-compose.yaml .
docker-compose up -d
sleep 10
docker-compose ps
popd
popd

- name: Install nebula-graph with CA SSL
run: |
pushd tmp
mkdir ca
pushd ca
cp -r ../../client/src/test/resources/ssl .
cp ../../client/src/test/resources/docker-compose-casigned.yaml .
docker-compose -f docker-compose-casigned.yaml up -d
sleep 30
docker-compose -f docker-compose-casigned.yaml ps
popd
popd

- name: Install nebula-graph with Self SSL
run: |
pushd tmp
mkdir self
pushd self
cp -r ../../client/src/test/resources/ssl .
cp ../../client/src/test/resources/docker-compose-selfsigned.yaml .
docker-compose -f docker-compose-selfsigned.yaml up -d
sleep 30
docker-compose -f docker-compose-selfsigned.yaml ps
popd
popd

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ public boolean ping(HostAddress addr) {
connection.close();
return true;
} catch (IOErrorException | ClientServerIncompatibleException e) {
LOGGER.error("ping failed", e);
return false;
}
}
Expand Down
42 changes: 40 additions & 2 deletions client/src/main/java/com/vesoft/nebula/client/meta/MetaClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import com.google.common.base.Charsets;
import com.vesoft.nebula.ErrorCode;
import com.vesoft.nebula.HostAddr;
import com.vesoft.nebula.client.graph.data.CASignedSSLParam;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.SelfSignedSSLParam;
import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import com.vesoft.nebula.client.meta.exception.ExecuteFailedException;
import com.vesoft.nebula.meta.EdgeItem;
import com.vesoft.nebula.meta.GetEdgeReq;
Expand Down Expand Up @@ -43,12 +47,15 @@
import com.vesoft.nebula.meta.TagItem;
import com.vesoft.nebula.meta.VerifyClientVersionReq;
import com.vesoft.nebula.meta.VerifyClientVersionResp;
import com.vesoft.nebula.util.SslUtil;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -63,6 +70,9 @@ public class MetaClient extends AbstractMetaClient {
private static final int DEFAULT_EXECUTION_RETRY_SIZE = 3;
private static final int RETRY_TIMES = 1;

private boolean enableSSL = false;
private SSLParam sslParam = null;

private MetaService.Client client;
private final List<HostAddress> addresses;

Expand All @@ -88,6 +98,17 @@ public MetaClient(List<HostAddress> addresses, int timeout, int connectionRetry,
this.addresses = addresses;
}

public MetaClient(List<HostAddress> addresses, int timeout, int connectionRetry,
int executionRetry, boolean enableSSL, SSLParam sslParam) {
super(addresses, timeout, connectionRetry, executionRetry);
this.addresses = addresses;
this.enableSSL = enableSSL;
this.sslParam = sslParam;
if (enableSSL && sslParam == null) {
throw new IllegalArgumentException("SSL is enabled, but SSLParam is null.");
}
}

public void connect()
throws TException, ClientServerIncompatibleException {
doConnect();
Expand All @@ -106,8 +127,25 @@ private void doConnect()

private void getClient(String host, int port)
throws TTransportException, ClientServerIncompatibleException {
transport = new TSocket(host, port, timeout, timeout);
transport.open();
if (enableSSL) {
SSLSocketFactory sslSocketFactory;
if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) {
sslSocketFactory = SslUtil.getSSLSocketFactoryWithCA((CASignedSSLParam) sslParam);
} else {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam);
}
try {
transport = new TSocket(sslSocketFactory.createSocket(host, port), timeout,
timeout);
} catch (IOException e) {
throw new TTransportException(IOErrorException.E_UNKNOWN, e);
}
} else {
transport = new TSocket(host, port, timeout, timeout);
transport.open();
}

protocol = new TCompactProtocol(transport);
client = new MetaService.Client(protocol);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.google.common.collect.Maps;
import com.vesoft.nebula.HostAddr;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException;
import com.vesoft.nebula.client.meta.exception.ExecuteFailedException;
import com.vesoft.nebula.meta.EdgeItem;
Expand Down Expand Up @@ -47,6 +48,10 @@ private class SpaceInfo {
private MetaClient metaClient;
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

private static final int DEFAULT_TIMEOUT_MS = 1000;
private static final int DEFAULT_CONNECTION_RETRY_SIZE = 3;
private static final int DEFAULT_EXECUTION_RETRY_SIZE = 3;

/**
* init the meta info cache
*/
Expand All @@ -57,6 +62,18 @@ public MetaManager(List<HostAddress> address)
fillMetaInfo();
}

/**
* init the meta info cache with more config
*/
public MetaManager(List<HostAddress> address, int timeout, int connectionRetry,
int executionRetry, boolean enableSSL, SSLParam sslParam)
throws TException, ClientServerIncompatibleException {
metaClient = new MetaClient(address, timeout, connectionRetry, executionRetry, enableSSL,
sslParam);
metaClient.connect();
fillMetaInfo();
}

/**
* close meta client
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
import com.facebook.thrift.protocol.TProtocol;
import com.facebook.thrift.transport.TSocket;
import com.facebook.thrift.transport.TTransport;
import com.facebook.thrift.transport.TTransportException;
import com.vesoft.nebula.client.graph.data.CASignedSSLParam;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.SelfSignedSSLParam;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import com.vesoft.nebula.storage.GraphStorageService;
import com.vesoft.nebula.storage.ScanEdgeRequest;
import com.vesoft.nebula.storage.ScanEdgeResponse;
import com.vesoft.nebula.storage.ScanVertexRequest;
import com.vesoft.nebula.storage.ScanVertexResponse;
import com.vesoft.nebula.util.SslUtil;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import javax.net.ssl.SSLSocketFactory;

public class GraphStorageConnection {
protected TTransport transport = null;
Expand All @@ -29,15 +37,37 @@ public class GraphStorageConnection {
protected GraphStorageConnection() {
}

protected GraphStorageConnection open(HostAddress address, int timeout) throws Exception {
protected GraphStorageConnection open(HostAddress address, int timeout, boolean enableSSL,
SSLParam sslParam) throws Exception {
this.address = address;
int newTimeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.transport = new TSocket(
InetAddress.getByName(address.getHost()).getHostAddress(),
address.getPort(),
newTimeout,
newTimeout);
this.transport.open();
if (enableSSL) {
SSLSocketFactory sslSocketFactory;
if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) {
sslSocketFactory = SslUtil.getSSLSocketFactoryWithCA((CASignedSSLParam) sslParam);
} else {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam);
}
try {
transport =
new TSocket(
sslSocketFactory.createSocket(
InetAddress.getByName(address.getHost()).getHostAddress(),
address.getPort()),
newTimeout,
newTimeout);
} catch (IOException e) {
throw new TTransportException(IOErrorException.E_UNKNOWN, e);
}
} else {
this.transport = new TSocket(
InetAddress.getByName(address.getHost()).getHostAddress(),
address.getPort(),
newTimeout,
newTimeout);
this.transport.open();
}
this.protocol = new TCompactProtocol(transport);
client = new GraphStorageService.Client(protocol);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import com.vesoft.nebula.HostAddr;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.meta.MetaManager;
import com.vesoft.nebula.client.storage.scan.PartScanInfo;
import com.vesoft.nebula.client.storage.scan.ScanEdgeResultIterator;
Expand All @@ -34,6 +35,11 @@ public class StorageClient {
private MetaManager metaManager;
private final List<HostAddress> addresses;
private int timeout = 10000; // ms
private int connectionRetry = 3;
private int executionRetry = 1;

private boolean enableSSL = false;
private SSLParam sslParam = null;

/**
* Get a Nebula Storage client that executes the scan query to get NebulaGraph's data with
Expand Down Expand Up @@ -70,16 +76,35 @@ public StorageClient(List<HostAddress> addresses, int timeout) {
this.timeout = timeout;
}

/**
* Get a Nebula Storage client that executes the scan query to get NebulaGraph's data with
* multi servers' hosts, timeout and ssl config.
*/
public StorageClient(List<HostAddress> addresses, int timeout, int connectionRetry,
int executionRetry, boolean enableSSL, SSLParam sslParam) {
this(addresses, timeout);
this.connectionRetry = connectionRetry;
this.executionRetry = executionRetry;
this.enableSSL = enableSSL;
this.sslParam = sslParam;
if (enableSSL && sslParam == null) {
throw new IllegalArgumentException("SSL is enabled, but SSLParam is nul.");
}
}

/**
* Connect to Nebula Storage server.
*
* @return true if connect successfully.
*/
public boolean connect() throws Exception {
connection.open(addresses.get(0), timeout);
connection.open(addresses.get(0), timeout, enableSSL, sslParam);
StoragePoolConfig config = new StoragePoolConfig();
config.setEnableSSL(enableSSL);
config.setSslParam(sslParam);
pool = new StorageConnPool(config);
metaManager = new MetaManager(addresses);
metaManager = new MetaManager(addresses, timeout, connectionRetry, executionRetry,
enableSSL, sslParam);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ public boolean validateObject(HostAddress hostAndPort,
public void activateObject(HostAddress address,
PooledObject<GraphStorageConnection> pooledObject)
throws Exception {
pooledObject.getObject().open(address, config.getTimeout());
pooledObject.getObject().open(
address,
config.getTimeout(),
config.isEnableSSL(),
config.getSslParam());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package com.vesoft.nebula.client.storage;

import com.vesoft.nebula.client.graph.data.SSLParam;

public class StoragePoolConfig {
// The min connections in pool for all addresses
private int minConnsSize = 0;
Expand All @@ -28,6 +30,10 @@ public class StoragePoolConfig {
// the max total connection in pool for each key
private int maxTotalPerKey = 10;

private boolean enableSSL = false;

private SSLParam sslParam = null;

public int getMinConnsSize() {
return minConnsSize;
}
Expand Down Expand Up @@ -75,4 +81,20 @@ public int getMaxTotalPerKey() {
public void setMaxTotalPerKey(int maxTotalPerKey) {
this.maxTotalPerKey = maxTotalPerKey;
}

public boolean isEnableSSL() {
return enableSSL;
}

public void setEnableSSL(boolean enableSSL) {
this.enableSSL = enableSSL;
}

public SSLParam getSslParam() {
return sslParam;
}

public void setSslParam(SSLParam sslParam) {
this.sslParam = sslParam;
}
}
Loading