From ef9e8d54d6275f0d9df1a531e0a577f168350599 Mon Sep 17 00:00:00 2001
From: Yichen Wang <18348405+Aiee@users.noreply.github.com>
Date: Tue, 13 Dec 2022 17:39:07 +0800
Subject: [PATCH] [v2.6.1 patch] Refactor connection pool and add more tests (#251)

* Add more tests

* Tune tests and refactor

Lower concurrency for CI so the GitHub Actions workflow can pass
---
 nebula2/gclient/net/ConnectionPool.py | 100 +++++++++++--------
 tests/docker-compose.yaml             |  18 ++--
 tests/test_pool.py                    | 132 ++++++++++++++++++++++----
 tests/test_session.py                 |   8 +-
 4 files changed, 186 insertions(+), 72 deletions(-)

diff --git a/nebula2/gclient/net/ConnectionPool.py b/nebula2/gclient/net/ConnectionPool.py
index 9379eb0c..0950c232 100644
--- a/nebula2/gclient/net/ConnectionPool.py
+++ b/nebula2/gclient/net/ConnectionPool.py
@@ -12,10 +12,7 @@
 from collections import deque
 from threading import RLock, Timer
 
-from nebula2.Exception import (
-    NotValidConnectionException,
-    InValidHostname
-)
+from nebula2.Exception import NotValidConnectionException, InValidHostname
 
 from nebula2.gclient.net.Session import Session
 from nebula2.gclient.net.Connection import Connection
@@ -65,7 +62,7 @@ def init(self, addresses, configs, ssl_conf=None):
             self._addresses.append(ip_port)
             self._addresses_status[ip_port] = self.S_BAD
             self._connections[ip_port] = deque()
-
+        self._ssl_configs = ssl_conf
         self.update_servers_status()
 
         # detect the services
@@ -74,25 +71,19 @@ def init(self, addresses, configs, ssl_conf=None):
         # init min connections
         ok_num = self.get_ok_servers_num()
         if ok_num < len(self._addresses):
-            raise RuntimeError('The services status exception: {}'.format(
-                self._get_services_status()))
-
-        conns_per_address = int(
-            self._configs.min_connection_pool_size / ok_num)
-
-        if ssl_conf is None:
-            for addr in self._addresses:
-                for i in range(0, conns_per_address):
-                    connection = Connection()
-                    connection.open(addr[0], addr[1], self._configs.timeout)
-                    self._connections[addr].append(connection)
-        else:
-            for addr in self._addresses:
-                for i in range(0, conns_per_address):
-                    connection = Connection()
-                    connection.open_SSL(
-                        addr[0], addr[1], self._configs.timeout, self._ssl_configs)
-                    self._connections[addr].append(connection)
+            raise RuntimeError(
+                'The services status exception: {}'.format(self._get_services_status())
+            )
+
+        conns_per_address = int(self._configs.min_connection_pool_size / ok_num)
+
+        for addr in self._addresses:
+            for i in range(0, conns_per_address):
+                connection = Connection()
+                connection.open_SSL(
+                    addr[0], addr[1], self._configs.timeout, self._ssl_configs
+                )
+                self._connections[addr].append(connection)
         return True
 
     def get_session(self, user_name, password, retry_connect=True):
@@ -151,25 +142,45 @@ def get_connection(self):
                 if ok_num == 0:
                     logging.error('No available server')
                     return None
-                max_con_per_address = int(self._configs.max_connection_pool_size / ok_num)
+                max_con_per_address = int(
+                    self._configs.max_connection_pool_size / ok_num
+                )
                 try_count = 0
                 while try_count <= len(self._addresses):
                     self._pos = (self._pos + 1) % len(self._addresses)
                     addr = self._addresses[self._pos]
                     if self._addresses_status[addr] == self.S_OK:
+                        invalid_connections = list()
+
+                        # iterate all connections to find an available connection
                         for connection in self._connections[addr]:
                             if not connection.is_used:
+                                # ping to check the connection is valid
                                 if connection.ping():
                                     connection.is_used = True
                                     logging.info('Get connection to {}'.format(addr))
                                     return connection
-                                # remove unusable connection
-
self._connections[addr].remove(connection) + else: + invalid_connections.append(connection) + # remove invalid connections + for connection in invalid_connections: + self._connections[addr].remove(connection) + + # check if the server is still alive + if not self.ping(addr): + self._addresses_status[addr] = self.S_BAD + continue + + # create new connection if the number of connections is less than max_con_per_address if len(self._connections[addr]) < max_con_per_address: connection = Connection() connection.open_SSL( - addr[0], addr[1], self._configs.timeout, self._ssl_configs) + addr[0], + addr[1], + self._configs.timeout, + self._ssl_configs, + ) connection.is_used = True self._connections[addr].append(connection) logging.info('Get connection to {}'.format(addr)) @@ -179,6 +190,8 @@ def get_connection(self): if not connection.is_used: self._connections[addr].remove(connection) try_count = try_count + 1 + + logging.error('No available connection') return None except Exception as ex: logging.error('Get connection failed: {}'.format(ex)) @@ -192,14 +205,13 @@ def ping(self, address): """ try: conn = Connection() - if self._ssl_configs is None: - conn.open(address[0], address[1], 1000) - else: - conn.open_SSL(address[0], address[1], 1000, self._ssl_configs) + conn.open_SSL(address[0], address[1], 1000, self._ssl_configs) conn.close() return True except Exception as ex: - logging.warning('Connect {}:{} failed: {}'.format(address[0], address[1], ex)) + logging.warning( + 'Connect {}:{} failed: {}'.format(address[0], address[1], ex) + ) return False def close(self): @@ -211,7 +223,7 @@ def close(self): for addr in self._connections.keys(): for connection in self._connections[addr]: if connection.is_used: - logging.error('The connection using by someone, but now want to close it') + logging.warning('Closing a connection that is in use') connection.close() self._close = True @@ -260,8 +272,7 @@ def _get_services_status(self): return ', '.join(msg_list) def update_servers_status(self): - """update the servers' status - """ + """update the servers' status""" for address in self._addresses: if self.ping(address): self._addresses_status[address] = self.S_OK @@ -277,11 +288,22 @@ def _remove_idle_unusable_connection(self): for connection in list(conns): if not connection.is_used: if not connection.ping(): - logging.debug('Remove the not unusable connection to {}'.format(connection.get_address())) + logging.debug( + 'Remove the unusable connection to {}'.format( + connection.get_address() + ) + ) conns.remove(connection) continue - if self._configs.idle_time != 0 and connection.idle_time() > self._configs.idle_time: - logging.debug('Remove the idle connection to {}'.format(connection.get_address())) + if ( + self._configs.idle_time != 0 + and connection.idle_time() > self._configs.idle_time + ): + logging.debug( + 'Remove the idle connection to {}'.format( + connection.get_address() + ) + ) conns.remove(connection) def _period_detect(self): diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 505ca0ce..73a7f24f 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -14,7 +14,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} @@ -55,7 +55,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} @@ -96,7 +96,7 @@ 
services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} @@ -137,7 +137,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} @@ -183,7 +183,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} @@ -229,7 +229,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} @@ -273,7 +273,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} @@ -316,7 +316,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} @@ -359,7 +359,7 @@ services: - --log_dir=/logs - --v=0 - --minloglevel=0 - - --heartbeat_interval_secs=2 + - --heartbeat_interval_secs=1 - --timezone_name=+08:00 # ssl - --ca_path=${ca_path} diff --git a/tests/test_pool.py b/tests/test_pool.py index a45ed701..3da6b32f 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -18,14 +18,14 @@ from unittest import TestCase -from nebula2.gclient.net import ConnectionPool +from nebula2.gclient.net import ConnectionPool, Connection from nebula2.Config import Config from nebula2.Exception import ( NotValidConnectionException, InValidHostname, - IOErrorException + IOErrorException, ) @@ -152,7 +152,9 @@ def test_timeout(self): assert pool.init([('127.0.0.1', 9669)], config) session = pool.get_session('root', 'nebula') try: - resp = session.execute('USE nba;GO 1000 STEPS FROM \"Tim Duncan\" OVER like') + resp = session.execute( + 'USE nba;GO 1000 STEPS FROM \"Tim Duncan\" OVER like' + ) assert False except IOErrorException as e: assert True @@ -168,14 +170,15 @@ def test_multi_thread(): # Test multi thread addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)] configs = Config() - configs.max_connection_pool_size = 4 + thread_num = 50 + configs.max_connection_pool_size = thread_num pool = ConnectionPool() assert pool.init(addresses, configs) global success_flag success_flag = True - def main_test(): + def pool_multi_thread_test(): session = None global success_flag try: @@ -185,8 +188,10 @@ def main_test(): return space_name = 'space_' + threading.current_thread().getName() - session.execute('DROP SPACE %s' % space_name) - resp = session.execute('CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' % space_name) + session.execute('DROP SPACE IF EXISTS %s' % space_name) + resp = session.execute( + 'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' % space_name + ) if not resp.is_succeeded(): raise RuntimeError('CREATE SPACE failed: {}'.format(resp.error_msg())) @@ -203,21 +208,108 @@ def main_test(): if session is not None: session.release() - thread1 = threading.Thread(target=main_test, name='thread1') - thread2 = threading.Thread(target=main_test, name='thread2') - thread3 = threading.Thread(target=main_test, name='thread3') - thread4 = threading.Thread(target=main_test, name='thread4') + threads = [] + for num in range(0, thread_num): + thread = threading.Thread( + target=pool_multi_thread_test, name='test_pool_thread' 
+ str(num) + ) + thread.start() + threads.append(thread) - thread1.start() - thread2.start() - thread3.start() - thread4.start() - - thread1.join() - thread2.join() - thread3.join() - thread4.join() + for t in threads: + t.join() + assert success_flag pool.close() + + +def test_session_context_multi_thread(): + # Test multi thread + addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)] + configs = Config() + thread_num = 50 + configs.max_connection_pool_size = thread_num + pool = ConnectionPool() + assert pool.init(addresses, configs) + + global success_flag + success_flag = True + + def pool_session_context_multi_thread_test(): + session = None + global success_flag + try: + with pool.session_context('root', 'nebula') as session: + if session is None: + success_flag = False + return + space_name = 'space_' + threading.current_thread().getName() + + session.execute('DROP SPACE IF EXISTS %s' % space_name) + resp = session.execute( + 'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' + % space_name + ) + if not resp.is_succeeded(): + raise RuntimeError( + 'CREATE SPACE failed: {}'.format(resp.error_msg()) + ) + + time.sleep(3) + resp = session.execute('USE %s' % space_name) + if not resp.is_succeeded(): + raise RuntimeError('USE SPACE failed:{}'.format(resp.error_msg())) + + except Exception as x: + print(x) + success_flag = False + return + + threads = [] + for num in range(0, thread_num): + thread = threading.Thread( + target=pool_session_context_multi_thread_test, + name='test_session_context_thread' + str(num), + ) + thread.start() + threads.append(thread) + + for t in threads: + t.join() assert success_flag + pool.close() + + +def test_remove_invalid_connection(): + addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670), ('127.0.0.1', 9671)] + configs = Config() + configs.min_connection_pool_size = 30 + configs.max_connection_pool_size = 45 + pool = ConnectionPool() + + try: + assert pool.init(addresses, configs) + + # turn down one server('127.0.0.1', 9669) so the connection to it is invalid + os.system('docker stop tests_graphd0_1') + time.sleep(3) + + # get connection from the pool, we should be able to still get 30 connections even though one server is down + for i in range(0, 30): + conn = pool.get_connection() + assert conn is not None + + # total connection should still be 30 + assert pool.connects() == 30 + + # the number of connections to the down server should be 0 + assert len(pool._connections[addresses[0]]) == 0 + + # the number of connections to the 2nd('127.0.0.1', 9670) and 3rd server('127.0.0.1', 9671) should be 15 + assert len(pool._connections[addresses[1]]) == 15 + assert len(pool._connections[addresses[2]]) == 15 + + finally: + os.system('docker start tests_graphd0_1') + time.sleep(3) diff --git a/tests/test_session.py b/tests/test_session.py index 6796f08b..ebd31a4c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -48,8 +48,8 @@ def test_2_reconnect(self): session.execute('CREATE SPACE IF NOT EXISTS test_session(vid_type=FIXED_STRING(8)); USE test_session;') for i in range(0, 5): if i == 3: - os.system('docker stop nebula-docker-compose_graphd0_1') - os.system('docker stop nebula-docker-compose_graphd1_1') + os.system('docker stop tests_graphd0_1') + os.system('docker stop tests_graphd1_1') time.sleep(3) # the session update later, the expect test # resp = session.execute('SHOW TAGS') @@ -63,8 +63,8 @@ def test_2_reconnect(self): except Exception as e: assert False, e finally: - os.system('docker start nebula-docker-compose_graphd0_1') - 
os.system('docker start nebula-docker-compose_graphd1_1') + os.system('docker start tests_graphd0_1') + os.system('docker start tests_graphd1_1') time.sleep(5) def test_3_session_context(self):
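
For context, here is a minimal usage sketch of the refactored pool, mirroring what the new tests above exercise. The graphd addresses, the root/nebula credentials, and the SHOW HOSTS query are assumptions for a local docker-compose deployment like the one used by these tests; they are not part of the patch itself.

```python
# Sketch only: drives the refactored ConnectionPool the same way the new
# tests do. Addresses and credentials are assumed local-cluster defaults.
from nebula2.gclient.net import ConnectionPool
from nebula2.Config import Config

config = Config()
# Upper bound on connections; get_connection() splits it across healthy graphd
# addresses and will not grow one server's deque past its share.
config.max_connection_pool_size = 10

pool = ConnectionPool()
assert pool.init([('127.0.0.1', 9669), ('127.0.0.1', 9670)], config)

# session_context() releases the session (and hands its connection back to the
# pool) when the block exits, which the new multi-thread tests rely on.
with pool.session_context('root', 'nebula') as session:
    resp = session.execute('SHOW HOSTS')
    assert resp.is_succeeded(), resp.error_msg()

pool.close()
```

Compared with v2.6.0, get_connection() now collects dead connections into invalid_connections and removes them after the scan instead of mutating the deque while iterating, re-pings the server and marks it S_BAD before trying to open new connections, and always routes through open_SSL() with the stored ssl config (None when SSL is not configured), which is what test_remove_invalid_connection verifies.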