Skip to content

Commit

Permalink
add ssl support for graph daemon (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
Klay authored Oct 13, 2021
1 parent a19546e commit 34626a7
Show file tree
Hide file tree
Showing 35 changed files with 1,214 additions and 16 deletions.
6 changes: 6 additions & 0 deletions client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
<commons-pool2.version>2.2</commons-pool2.version>
<servlet.version>3.0.1</servlet.version>
<fastjson.version>1.2.78</fastjson.version>
<bouncycastle.version>1.69</bouncycastle.version>
</properties>

<build>
Expand Down Expand Up @@ -239,5 +240,10 @@
<artifactId>fastjson</artifactId>
<version>${fastjson.version}</version>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,37 @@ public TSocket(Socket socket) throws TTransportException {
}
}

/**
* Constructor that takes an already created socket that comes alone with timeout
* and connectionTimeout.
*
* @param socket Already created socket object
* @param timeout Socket timeout
* @param connectionTimeout Socket connection timeout
* @throws TTransportException if there is an error setting up the streams
*/
public TSocket(Socket socket, int timeout, int connectionTimeout) throws TTransportException {
socket_ = socket;
try {
socket_.setSoLinger(false, 0);
socket_.setTcpNoDelay(true);
socket_.setSoTimeout(timeout);
connectionTimeout_ = connectionTimeout;
} catch (SocketException sx) {
LOGGER.warn("Could not configure socket.", sx);
}

if (isOpen()) {
try {
inputStream_ = new BufferedInputStream(socket_.getInputStream());
outputStream_ = new BufferedOutputStream(socket_.getOutputStream());
} catch (IOException iox) {
close();
throw new TTransportException(TTransportException.NOT_OPEN, iox);
}
}
}

/**
* Creates a new unconnected socket that will connect to the given host on the given port.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package com.vesoft.nebula.client.graph;

import com.vesoft.nebula.client.graph.data.SSLParam;

public class NebulaPoolConfig {
// The min connections in pool for all addresses
private int minConnsSize = 0;
Expand All @@ -27,6 +29,28 @@ public class NebulaPoolConfig {
// the wait time to get idle connection, unit ms
private int waitTime = 0;

// set to true to turn on ssl encrypted traffic
private boolean enableSsl = false;

// ssl param is required if ssl is turned on
private SSLParam sslParam = null;

public boolean isEnableSsl() {
return enableSsl;
}

public void setEnableSsl(boolean enableSsl) {
this.enableSsl = enableSsl;
}

public SSLParam getSslParam() {
return sslParam;
}

public void setSslParam(SSLParam sslParam) {
this.sslParam = sslParam;
}

public int getMinConnSize() {
return minConnsSize;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public class CASignedSSLParam extends SSLParam {
private String caCrtFilePath;
private String crtFilePath;
private String keyFilePath;

public CASignedSSLParam(String caCrtFilePath, String crtFilePath, String keyFilePath) {
super(SignMode.CA_SIGNED);
this.caCrtFilePath = caCrtFilePath;
this.crtFilePath = crtFilePath;
this.keyFilePath = keyFilePath;
}

public String getCaCrtFilePath() {
return caCrtFilePath;
}

public String getCrtFilePath() {
return crtFilePath;
}

public String getKeyFilePath() {
return keyFilePath;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public abstract class SSLParam {
public enum SignMode {
NONE,
SELF_SIGNED,
CA_SIGNED
}

private SignMode signMode;

public SSLParam(SignMode signMode) {
this.signMode = signMode;
}

public SignMode getSignMode() {
return signMode;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public class SelfSignedSSLParam extends SSLParam {
private String crtFilePath;
private String keyFilePath;
private String password;

public SelfSignedSSLParam(String crtFilePath, String keyFilePath, String password) {
super(SignMode.SELF_SIGNED);
this.crtFilePath = crtFilePath;
this.keyFilePath = keyFilePath;
this.password = password;
}

public String getCrtFilePath() {
return crtFilePath;
}

public String getKeyFilePath() {
return keyFilePath;
}

public String getPassword() {
return password;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

public class ConnObjectPool extends BasePooledObjectFactory<SyncConnection> {
private final NebulaPoolConfig config;
private LoadBalancer loadBalancer;
private final LoadBalancer loadBalancer;
private static final int retryTime = 3;

public ConnObjectPool(LoadBalancer loadBalancer, NebulaPoolConfig config) {
Expand All @@ -28,7 +28,15 @@ public SyncConnection create() throws IOErrorException {
SyncConnection conn = new SyncConnection();
while (retry-- > 0) {
try {
conn.open(address, config.getTimeout());
if (config.isEnableSsl()) {
if (config.getSslParam() == null) {
throw new IllegalArgumentException("SSL Param is required when enableSsl "
+ "is set to true");
}
conn.open(address, config.getTimeout(), config.getSslParam());
} else {
conn.open(address, config.getTimeout());
}
return conn;
} catch (IOErrorException e) {
if (retry == 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.vesoft.nebula.client.graph.net;

import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.exception.IOErrorException;

public abstract class Connection {
Expand All @@ -10,6 +11,9 @@ public HostAddress getServerAddress() {
return this.serverAddr;
}

public abstract void open(HostAddress address, int timeout, SSLParam sslParam)
throws IOErrorException;

public abstract void open(HostAddress address, int timeout) throws IOErrorException;

public abstract void reopen() throws IOErrorException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ public boolean init(List<HostAddress> addresses, NebulaPoolConfig config)
checkConfig(config);
this.waitTime = config.getWaitTime();
List<HostAddress> newAddrs = hostToIp(addresses);
this.loadBalancer = new RoundRobinLoadBalancer(newAddrs, config.getTimeout());
this.loadBalancer = config.isEnableSsl()
? new RoundRobinLoadBalancer(newAddrs, config.getTimeout(), config.getSslParam())
: new RoundRobinLoadBalancer(newAddrs, config.getTimeout());
ConnObjectPool objectPool = new ConnObjectPool(this.loadBalancer, config);
this.objectPool = new GenericObjectPool<>(objectPool);
GenericObjectPoolConfig objConfig = new GenericObjectPoolConfig();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.vesoft.nebula.client.graph.net;

import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -21,6 +22,8 @@ public class RoundRobinLoadBalancer implements LoadBalancer {
private final AtomicInteger pos = new AtomicInteger(0);
private final int delayTime = 60; // unit seconds
private final ScheduledExecutorService schedule = Executors.newScheduledThreadPool(1);
private SSLParam sslParam;
private boolean enabledSsl;

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout) {
this.timeout = timeout;
Expand All @@ -31,6 +34,12 @@ public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout) {
schedule.scheduleAtFixedRate(this::scheduleTask, 0, delayTime, TimeUnit.SECONDS);
}

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout, SSLParam sslParam) {
this(addresses,timeout);
this.sslParam = sslParam;
this.enabledSsl = true;
}

public void close() {
schedule.shutdownNow();
}
Expand Down Expand Up @@ -63,7 +72,11 @@ public void updateServersStatus() {
public boolean ping(HostAddress addr) {
try {
Connection connection = new SyncConnection();
connection.open(addr, this.timeout);
if (enabledSsl) {
connection.open(addr, this.timeout, sslParam);
} else {
connection.open(addr, this.timeout);
}
connection.close();
return true;
} catch (IOErrorException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,64 @@
import com.facebook.thrift.transport.TTransportException;
import com.facebook.thrift.utils.StandardCharsets;
import com.vesoft.nebula.ErrorCode;
import com.vesoft.nebula.client.graph.data.CASignedSSLParam;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.SelfSignedSSLParam;
import com.vesoft.nebula.client.graph.exception.AuthFailedException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import com.vesoft.nebula.graph.AuthResponse;
import com.vesoft.nebula.graph.ExecutionResponse;
import com.vesoft.nebula.graph.GraphService;
import com.vesoft.nebula.util.SslUtil;
import java.io.IOException;
import javax.net.ssl.SSLSocketFactory;

public class SyncConnection extends Connection {
protected TTransport transport = null;
protected TProtocol protocol = null;
private GraphService.Client client = null;
private int timeout = 0;
private SSLParam sslParam = null;
private boolean enabledSsl = false;

@Override
public void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException {
try {
SSLSocketFactory sslSocketFactory;

this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.enabledSsl = true;
this.sslParam = sslParam;
if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithCA((CASignedSSLParam) sslParam);
} else {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam);
}
if (sslSocketFactory == null) {
throw new IOErrorException(IOErrorException.E_UNKNOWN,
"SSL Socket Factory Creation failed");
}
this.transport = new TSocket(
sslSocketFactory.createSocket(address.getHost(),
address.getPort()), this.timeout, this.timeout);
this.protocol = new TCompactProtocol(transport);
client = new GraphService.Client(protocol);
} catch (TException e) {
throw new IOErrorException(IOErrorException.E_UNKNOWN, e.getMessage());
} catch (IOException e) {
e.printStackTrace();
}
}

@Override
public void open(HostAddress address, int timeout) throws IOErrorException {
this.serverAddr = address;
try {
this.enabledSsl = false;
this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.transport = new TSocket(
address.getHost(), address.getPort(), this.timeout, this.timeout);
Expand All @@ -56,7 +97,11 @@ public void open(HostAddress address, int timeout) throws IOErrorException {
@Override
public void reopen() throws IOErrorException {
close();
open(serverAddr, timeout);
if (enabledSsl) {
open(serverAddr, timeout, sslParam);
} else {
open(serverAddr, timeout);
}
}

public AuthResult authenticate(String user, String password)
Expand Down Expand Up @@ -143,6 +188,7 @@ public boolean ping() {
execute(0, "YIELD 1;");
return true;
} catch (IOErrorException e) {
e.printStackTrace();
return false;
}
}
Expand Down
Loading

0 comments on commit 34626a7

Please sign in to comment.