Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ssl support for graph daemon #364

Merged
merged 6 commits into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
<commons-pool2.version>2.2</commons-pool2.version>
<servlet.version>3.0.1</servlet.version>
<fastjson.version>1.2.78</fastjson.version>
<bouncycastle.version>1.69</bouncycastle.version>
</properties>

<build>
Expand Down Expand Up @@ -239,5 +240,10 @@
<artifactId>fastjson</artifactId>
<version>${fastjson.version}</version>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,37 @@ public TSocket(Socket socket) throws TTransportException {
}
}

/**
* Constructor that takes an already created socket that comes alone with timeout
* and connectionTimeout.
*
* @param socket Already created socket object
* @param timeout Socket timeout
* @param connectionTimeout Socket connection timeout
* @throws TTransportException if there is an error setting up the streams
*/
public TSocket(Socket socket, int timeout, int connectionTimeout) throws TTransportException {
socket_ = socket;
try {
socket_.setSoLinger(false, 0);
socket_.setTcpNoDelay(true);
socket_.setSoTimeout(timeout);
connectionTimeout_ = connectionTimeout;
} catch (SocketException sx) {
LOGGER.warn("Could not configure socket.", sx);
}

if (isOpen()) {
try {
inputStream_ = new BufferedInputStream(socket_.getInputStream());
outputStream_ = new BufferedOutputStream(socket_.getOutputStream());
} catch (IOException iox) {
close();
throw new TTransportException(TTransportException.NOT_OPEN, iox);
}
}
}

/**
* Creates a new unconnected socket that will connect to the given host on the given port.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package com.vesoft.nebula.client.graph;

import com.vesoft.nebula.client.graph.data.SSLParam;

public class NebulaPoolConfig {
// The min connections in pool for all addresses
private int minConnsSize = 0;
Expand All @@ -27,6 +29,28 @@ public class NebulaPoolConfig {
// the wait time to get idle connection, unit ms
private int waitTime = 0;

// set to true to turn on ssl encrypted traffic
private boolean enableSsl = false;

// ssl param is required if ssl is turned on
private SSLParam sslParam = null;

public boolean isEnableSsl() {
return enableSsl;
}

public void setEnableSsl(boolean enableSsl) {
this.enableSsl = enableSsl;
}

public SSLParam getSslParam() {
return sslParam;
}

public void setSslParam(SSLParam sslParam) {
this.sslParam = sslParam;
}

public int getMinConnSize() {
return minConnsSize;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public class CASignedSSLParam extends SSLParam {
private String caCrtFilePath;
private String crtFilePath;
private String keyFilePath;

public CASignedSSLParam(String caCrtFilePath, String crtFilePath, String keyFilePath) {
super(SignMode.CA_SIGNED);
this.caCrtFilePath = caCrtFilePath;
this.crtFilePath = crtFilePath;
this.keyFilePath = keyFilePath;
}

public String getCaCrtFilePath() {
return caCrtFilePath;
}

public String getCrtFilePath() {
return crtFilePath;
}

public String getKeyFilePath() {
return keyFilePath;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public abstract class SSLParam {
public enum SignMode {
NONE,
SELF_SIGNED,
CA_SIGNED
}

private SignMode signMode;

public SSLParam(SignMode signMode) {
this.signMode = signMode;
}

public SignMode getSignMode() {
return signMode;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

package com.vesoft.nebula.client.graph.data;

public class SelfSignedSSLParam extends SSLParam {
private String crtFilePath;
private String keyFilePath;
private String password;

public SelfSignedSSLParam(String crtFilePath, String keyFilePath, String password) {
super(SignMode.SELF_SIGNED);
this.crtFilePath = crtFilePath;
this.keyFilePath = keyFilePath;
this.password = password;
}

public String getCrtFilePath() {
return crtFilePath;
}

public String getKeyFilePath() {
return keyFilePath;
}

public String getPassword() {
return password;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

public class ConnObjectPool extends BasePooledObjectFactory<SyncConnection> {
private final NebulaPoolConfig config;
private LoadBalancer loadBalancer;
private final LoadBalancer loadBalancer;
private static final int retryTime = 3;

public ConnObjectPool(LoadBalancer loadBalancer, NebulaPoolConfig config) {
Expand All @@ -28,7 +28,15 @@ public SyncConnection create() throws IOErrorException {
SyncConnection conn = new SyncConnection();
while (retry-- > 0) {
try {
conn.open(address, config.getTimeout());
if (config.isEnableSsl()) {
if (config.getSslParam() == null) {
throw new IllegalArgumentException("SSL Param is required when enableSsl "
+ "is set to true");
}
conn.open(address, config.getTimeout(), config.getSslParam());
} else {
conn.open(address, config.getTimeout());
}
return conn;
} catch (IOErrorException e) {
if (retry == 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.vesoft.nebula.client.graph.net;

import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.exception.IOErrorException;

public abstract class Connection {
Expand All @@ -10,6 +11,9 @@ public HostAddress getServerAddress() {
return this.serverAddr;
}

public abstract void open(HostAddress address, int timeout, SSLParam sslParam)
throws IOErrorException;

public abstract void open(HostAddress address, int timeout) throws IOErrorException;

public abstract void reopen() throws IOErrorException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ public boolean init(List<HostAddress> addresses, NebulaPoolConfig config)
checkConfig(config);
this.waitTime = config.getWaitTime();
List<HostAddress> newAddrs = hostToIp(addresses);
this.loadBalancer = new RoundRobinLoadBalancer(newAddrs, config.getTimeout());
this.loadBalancer = config.isEnableSsl()
? new RoundRobinLoadBalancer(newAddrs, config.getTimeout(), config.getSslParam())
: new RoundRobinLoadBalancer(newAddrs, config.getTimeout());
ConnObjectPool objectPool = new ConnObjectPool(this.loadBalancer, config);
this.objectPool = new GenericObjectPool<>(objectPool);
GenericObjectPoolConfig objConfig = new GenericObjectPoolConfig();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.vesoft.nebula.client.graph.net;

import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -21,6 +22,8 @@ public class RoundRobinLoadBalancer implements LoadBalancer {
private final AtomicInteger pos = new AtomicInteger(0);
private final int delayTime = 60; // unit seconds
private final ScheduledExecutorService schedule = Executors.newScheduledThreadPool(1);
private SSLParam sslParam;
private boolean enabledSsl;

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout) {
this.timeout = timeout;
Expand All @@ -31,6 +34,12 @@ public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout) {
schedule.scheduleAtFixedRate(this::scheduleTask, 0, delayTime, TimeUnit.SECONDS);
}

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout, SSLParam sslParam) {
this(addresses,timeout);
this.sslParam = sslParam;
this.enabledSsl = true;
}

public void close() {
schedule.shutdownNow();
}
Expand Down Expand Up @@ -63,7 +72,11 @@ public void updateServersStatus() {
public boolean ping(HostAddress addr) {
try {
Connection connection = new SyncConnection();
connection.open(addr, this.timeout);
if (enabledSsl) {
connection.open(addr, this.timeout, sslParam);
} else {
connection.open(addr, this.timeout);
}
connection.close();
return true;
} catch (IOErrorException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,64 @@
import com.facebook.thrift.transport.TTransportException;
import com.facebook.thrift.utils.StandardCharsets;
import com.vesoft.nebula.ErrorCode;
import com.vesoft.nebula.client.graph.data.CASignedSSLParam;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.SelfSignedSSLParam;
import com.vesoft.nebula.client.graph.exception.AuthFailedException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import com.vesoft.nebula.graph.AuthResponse;
import com.vesoft.nebula.graph.ExecutionResponse;
import com.vesoft.nebula.graph.GraphService;
import com.vesoft.nebula.util.SslUtil;
import java.io.IOException;
import javax.net.ssl.SSLSocketFactory;

public class SyncConnection extends Connection {
protected TTransport transport = null;
protected TProtocol protocol = null;
private GraphService.Client client = null;
private int timeout = 0;
private SSLParam sslParam = null;
private boolean enabledSsl = false;

@Override
public void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException {
try {
SSLSocketFactory sslSocketFactory;

this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.enabledSsl = true;
this.sslParam = sslParam;
if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithCA((CASignedSSLParam) sslParam);
} else {
sslSocketFactory =
SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam);
}
if (sslSocketFactory == null) {
throw new IOErrorException(IOErrorException.E_UNKNOWN,
"SSL Socket Factory Creation failed");
}
this.transport = new TSocket(
sslSocketFactory.createSocket(address.getHost(),
address.getPort()), this.timeout, this.timeout);
this.protocol = new TCompactProtocol(transport);
client = new GraphService.Client(protocol);
} catch (TException e) {
throw new IOErrorException(IOErrorException.E_UNKNOWN, e.getMessage());
} catch (IOException e) {
e.printStackTrace();
}
}

@Override
public void open(HostAddress address, int timeout) throws IOErrorException {
this.serverAddr = address;
try {
this.enabledSsl = false;
this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.transport = new TSocket(
address.getHost(), address.getPort(), this.timeout, this.timeout);
Expand All @@ -56,7 +97,11 @@ public void open(HostAddress address, int timeout) throws IOErrorException {
@Override
public void reopen() throws IOErrorException {
close();
open(serverAddr, timeout);
if (enabledSsl) {
open(serverAddr, timeout, sslParam);
} else {
open(serverAddr, timeout);
}
}

public AuthResult authenticate(String user, String password)
Expand Down Expand Up @@ -143,6 +188,7 @@ public boolean ping() {
execute(0, "YIELD 1;");
return true;
} catch (IOErrorException e) {
e.printStackTrace();
return false;
klay-ke marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand Down
Loading