Skip to content

Commit

Permalink
add version selection (#305)
Browse files Browse the repository at this point in the history
* add version selection

* add version whitelist

* add version whitelist

* format

* format code

* format code

* update test_ssl_connection

* add handshakekey to sessionPool

* add import

* update test sessionpool

* format code

* add test code

* add condition for time

---------

Co-authored-by: Anqi <[email protected]>
  • Loading branch information
javaGitHub2022 and Nicole00 authored Jan 10, 2024
1 parent 3be007f commit d58b274
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 79 deletions.
4 changes: 3 additions & 1 deletion example/SessinPoolExample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@

from FormatResp import print_resp


from nebula3.common.ttypes import ErrorCode
from nebula3.Config import SessionPoolConfig
from nebula3.gclient.net import Connection
from nebula3.gclient.net.SessionPool import SessionPool


if __name__ == "__main__":
ip = "127.0.0.1"
port = 9669

try:
config = SessionPoolConfig()

Expand All @@ -36,6 +37,7 @@
time.sleep(10)

# init session pool

session_pool = SessionPool("root", "nebula", "session_pool_test", [(ip, port)])
assert session_pool.init(config)

Expand Down
3 changes: 3 additions & 0 deletions nebula3/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Config(object):
# the interval to check idle time connection, unit second, -1 means no check
interval_check = -1

handshakeKey = None


class SSL_config(object):
"""configs used to Initialize a TSSLSocket.
Expand Down Expand Up @@ -89,3 +91,4 @@ class SessionPoolConfig(object):
max_size = 30
min_size = 1
interval_check = -1
handshakeKey = None
25 changes: 17 additions & 8 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,34 @@ def __init__(self):
self._port = None
self._timeout = 0
self._ssl_conf = None
self.handshakeKey = None

def open(self, ip, port, timeout):
def open(self, ip, port, timeout, handshakeKey=None):
"""open the connection
:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:param handshakeKey: the client version
:return: void
"""
self.open_SSL(ip, port, timeout, None)
self.open_SSL(ip, port, timeout, handshakeKey, None)

def open_SSL(self, ip, port, timeout, ssl_config=None):
def open_SSL(self, ip, port, timeout, handshakeKey=None, ssl_config=None):
"""open the SSL connection
:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:param handshakeKey: the client version
:ssl_config: configs for SSL
:return: void
"""
self._ip = ip
self._port = port
self._timeout = timeout
self._ssl_conf = ssl_config
self.handshakeKey = handshakeKey
try:
if ssl_config is not None:
s = TSSLSocket.TSSLSocket(
Expand All @@ -89,7 +93,10 @@ def open_SSL(self, ip, port, timeout, ssl_config=None):
header_transport.open()

self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
verifyClientVersionReq = VerifyClientVersionReq()
if handshakeKey is not None:
verifyClientVersionReq.version = handshakeKey
resp = self._connection.verifyClientVersion(verifyClientVersionReq)
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
Expand All @@ -103,9 +110,11 @@ def _reopen(self):
"""
self.close()
if self._ssl_conf is not None:
self.open_SSL(self._ip, self._port, self._timeout, self._ssl_conf)
self.open_SSL(
self._ip, self._port, self._timeout, self.handshakeKey, self._ssl_conf
)
else:
self.open(self._ip, self._port, self._timeout)
self.open(self._ip, self._port, self._timeout, self.handshakeKey)

def authenticate(self, user_name, password):
"""authenticate to graphd
Expand Down Expand Up @@ -216,15 +225,15 @@ def close(self):
self._connection._iprot.trans.close()
except Exception as e:
logger.error(
'Close connection to {}:{} failed:{}'.format(self._ip, self._port, e)
"Close connection to {}:{} failed:{}".format(self._ip, self._port, e)
)

def ping(self):
"""check the connection if ok
:return: True or False
"""
try:
resp = self._connection.execute(0, 'YIELD 1;')
resp = self._connection.execute(0, "YIELD 1;")
return True
except Exception:
return False
Expand Down
49 changes: 30 additions & 19 deletions nebula3/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def init(self, addresses, configs, ssl_conf=None):
:return: if all addresses are ok, return True else return False.
"""
if self._close:
logger.error('The pool has init or closed.')
raise RuntimeError('The pool has init or closed.')
logger.error("The pool has init or closed.")
raise RuntimeError("The pool has init or closed.")
self._configs = configs
self._ssl_configs = ssl_conf
for address in addresses:
Expand All @@ -73,7 +73,7 @@ def init(self, addresses, configs, ssl_conf=None):
ok_num = self.get_ok_servers_num()
if ok_num < len(self._addresses):
raise RuntimeError(
'The services status exception: {}'.format(self._get_services_status())
"The services status exception: {}".format(self._get_services_status())
)

conns_per_address = int(self._configs.min_connection_pool_size / ok_num)
Expand All @@ -82,7 +82,11 @@ def init(self, addresses, configs, ssl_conf=None):
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
addr[0],
addr[1],
self._configs.timeout,
configs.handshakeKey,
self._ssl_configs,
)
self._connections[addr].append(connection)
return True
Expand Down Expand Up @@ -135,13 +139,13 @@ def get_connection(self):
"""
with self._lock:
if self._close:
logger.error('The pool is closed')
logger.error("The pool is closed")
raise NotValidConnectionException()

try:
ok_num = self.get_ok_servers_num()
if ok_num == 0:
logger.error('No available server')
logger.error("No available server")
return None
max_con_per_address = int(
self._configs.max_connection_pool_size / ok_num
Expand All @@ -159,7 +163,7 @@ def get_connection(self):
# ping to check the connection is valid
if connection.ping():
connection.is_used = True
logger.info('Get connection to {}'.format(addr))
logger.info("Get connection to {}".format(addr))
return connection
else:
invalid_connections.append(connection)
Expand All @@ -180,22 +184,23 @@ def get_connection(self):
addr[0],
addr[1],
self._configs.timeout,
self._configs.handshakeKey,
self._ssl_configs,
)
connection.is_used = True
self._connections[addr].append(connection)
logger.info('Get connection to {}'.format(addr))
logger.info("Get connection to {}".format(addr))
return connection
else:
for connection in list(self._connections[addr]):
if not connection.is_used:
self._connections[addr].remove(connection)
try_count = try_count + 1

logger.error('No available connection')
logger.error("No available connection")
return None
except Exception as ex:
logger.error('Get connection failed: {}'.format(ex))
logger.error("Get connection failed: {}".format(ex))
return None

def ping(self, address):
Expand All @@ -206,12 +211,18 @@ def ping(self, address):
"""
try:
conn = Connection()
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(
address[0],
address[1],
1000,
self._configs.handshakeKey,
self._ssl_configs,
)
conn.close()
return True
except Exception as ex:
logger.warning(
'Connect {}:{} failed: {}'.format(address[0], address[1], ex)
"Connect {}:{} failed: {}".format(address[0], address[1], ex)
)
return False

Expand All @@ -224,7 +235,7 @@ def close(self):
for addr in self._connections.keys():
for connection in self._connections[addr]:
if connection.is_used:
logger.warning('Closing a connection that is in use')
logger.warning("Closing a connection that is in use")
connection.close()
self._close = True

Expand Down Expand Up @@ -266,11 +277,11 @@ def get_ok_servers_num(self):
def _get_services_status(self):
msg_list = []
for addr in self._addresses_status.keys():
status = 'OK'
status = "OK"
if self._addresses_status[addr] != self.S_OK:
status = 'BAD'
msg_list.append('[services: {}, status: {}]'.format(addr, status))
return ', '.join(msg_list)
status = "BAD"
msg_list.append("[services: {}, status: {}]".format(addr, status))
return ", ".join(msg_list)

def update_servers_status(self):
"""update the servers' status"""
Expand All @@ -290,7 +301,7 @@ def _remove_idle_unusable_connection(self):
if not connection.is_used:
if not connection.ping():
logger.debug(
'Remove the unusable connection to {}'.format(
"Remove the unusable connection to {}".format(
connection.get_address()
)
)
Expand All @@ -301,7 +312,7 @@ def _remove_idle_unusable_connection(self):
and connection.idle_time() > self._configs.idle_time
):
logger.debug(
'Remove the idle connection to {}'.format(
"Remove the idle connection to {}".format(
connection.get_address()
)
)
Expand Down
Loading

0 comments on commit d58b274

Please sign in to comment.