From c2c489c22e00b0965263980ebbae6f715d6bd937 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Tue, 30 Apr 2024 13:39:13 +0800 Subject: [PATCH] Add a base class for functions execute_* (#338) * Add a base class for functions execute_* ; fix typing ; pdm run fmt ; * fix typing * decode bytes * Use bytes instead of string --- nebula3/gclient/net/Connection.py | 7 +++-- nebula3/gclient/net/ConnectionPool.py | 41 ++++++++++++++------------- nebula3/gclient/net/Session.py | 39 ++++++++++++++----------- nebula3/gclient/net/SessionPool.py | 18 ++++++------ nebula3/gclient/net/base.py | 23 +++++++++++++++ 5 files changed, 82 insertions(+), 46 deletions(-) create mode 100644 nebula3/gclient/net/base.py diff --git a/nebula3/gclient/net/Connection.py b/nebula3/gclient/net/Connection.py index ad9071a6..d4b9b547 100644 --- a/nebula3/gclient/net/Connection.py +++ b/nebula3/gclient/net/Connection.py @@ -225,11 +225,14 @@ def execute_json_with_parameter(self, session_id, stmt, params): :param session_id: the session id get from result of authenticate interface :param stmt: the ngql :param params: parameter map - :return: string json representing the execution result + :return: json bytes representing the execution result """ try: resp = self._connection.executeJsonWithParameter(session_id, stmt, params) - return resp + if not isinstance(resp, bytes): + raise TypeError("response is not bytes") + else: + return resp except Exception as te: if isinstance(te, TTransportException): if te.message.find("timed out") > 0: diff --git a/nebula3/gclient/net/ConnectionPool.py b/nebula3/gclient/net/ConnectionPool.py index 798a1315..863300cc 100644 --- a/nebula3/gclient/net/ConnectionPool.py +++ b/nebula3/gclient/net/ConnectionPool.py @@ -17,6 +17,7 @@ from nebula3.gclient.net.Connection import Connection from nebula3.Config import Config from nebula3.logger import logger +from typing import Dict, List, Tuple class ConnectionPool(object): @@ -25,13 +26,13 @@ class ConnectionPool(object): def __init__(self): # all addresses of servers - self._addresses = list() + self._addresses: List[Tuple[str, int]] = list() # server's status self._addresses_status = dict() # all connections - self._connections = dict() + self._connections: Dict[Tuple[str, int], List[Connection]] = dict() self._configs = None self._ssl_configs = None self._lock = RLock() @@ -50,14 +51,14 @@ def init(self, addresses, configs=None, ssl_conf=None): :return: if all addresses are ok, return True else return False. """ if self._close: - logger.error('The pool has init or closed.') - raise RuntimeError('The pool has init or closed.') + logger.error("The pool has init or closed.") + raise RuntimeError("The pool has init or closed.") if configs is None: self._configs = Config() else: assert isinstance( configs, Config - ), 'wrong type of Config, try this: `from nebula3.Config import Config`' + ), "wrong type of Config, try this: `from nebula3.Config import Config`" self._configs = configs self._ssl_configs = ssl_conf for address in addresses: @@ -80,7 +81,7 @@ def init(self, addresses, configs=None, ssl_conf=None): ok_num = self.get_ok_servers_num() if ok_num < len(self._addresses): raise RuntimeError( - 'The services status exception: {}'.format(self._get_services_status()) + "The services status exception: {}".format(self._get_services_status()) ) conns_per_address = int(self._configs.min_connection_pool_size / ok_num) @@ -147,13 +148,13 @@ def get_connection(self): """ with self._lock: if self._close: - logger.error('The pool is closed') + logger.error("The pool is closed") raise NotValidConnectionException() try: ok_num = self.get_ok_servers_num() if ok_num == 0: - logger.error('No available server') + logger.error("No available server") return None max_con_per_address = int( self._configs.max_connection_pool_size / ok_num @@ -171,7 +172,7 @@ def get_connection(self): # ping to check the connection is valid if connection.ping(): connection.is_used = True - logger.info('Get connection to {}'.format(addr)) + logger.info("Get connection to {}".format(addr)) return connection else: invalid_connections.append(connection) @@ -198,7 +199,7 @@ def get_connection(self): ) connection.is_used = True self._connections[addr].append(connection) - logger.info('Get connection to {}'.format(addr)) + logger.info("Get connection to {}".format(addr)) return connection else: for connection in list(self._connections[addr]): @@ -206,10 +207,10 @@ def get_connection(self): self._connections[addr].remove(connection) try_count = try_count + 1 - logger.error('No available connection') + logger.error("No available connection") return None except Exception as ex: - logger.error('Get connection failed: {}'.format(ex)) + logger.error("Get connection failed: {}".format(ex)) return None def ping(self, address): @@ -235,7 +236,7 @@ def ping(self, address): return True except Exception as ex: logger.warning( - 'Connect {}:{} failed: {}'.format(address[0], address[1], ex) + "Connect {}:{} failed: {}".format(address[0], address[1], ex) ) return False @@ -248,7 +249,7 @@ def close(self): for addr in self._connections.keys(): for connection in self._connections[addr]: if connection.is_used: - logger.warning('Closing a connection that is in use') + logger.warning("Closing a connection that is in use") connection.close() self._close = True @@ -290,11 +291,11 @@ def get_ok_servers_num(self): def _get_services_status(self): msg_list = [] for addr in self._addresses_status.keys(): - status = 'OK' + status = "OK" if self._addresses_status[addr] != self.S_OK: - status = 'BAD' - msg_list.append('[services: {}, status: {}]'.format(addr, status)) - return ', '.join(msg_list) + status = "BAD" + msg_list.append("[services: {}, status: {}]".format(addr, status)) + return ", ".join(msg_list) def update_servers_status(self): """update the servers' status""" @@ -314,7 +315,7 @@ def _remove_idle_unusable_connection(self): if not connection.is_used: if not connection.ping(): logger.debug( - 'Remove the unusable connection to {}'.format( + "Remove the unusable connection to {}".format( connection.get_address() ) ) @@ -325,7 +326,7 @@ def _remove_idle_unusable_connection(self): and connection.idle_time() > self._configs.idle_time ): logger.debug( - 'Remove the idle connection to {}'.format( + "Remove the idle connection to {}".format( connection.get_address() ) ) diff --git a/nebula3/gclient/net/Session.py b/nebula3/gclient/net/Session.py index 1faf4b58..069c38f8 100644 --- a/nebula3/gclient/net/Session.py +++ b/nebula3/gclient/net/Session.py @@ -8,6 +8,8 @@ import json import time +from typing import TYPE_CHECKING + from nebula3.Exception import ( IOErrorException, NotValidConnectionException, @@ -15,15 +17,20 @@ from nebula3.common.ttypes import ErrorCode from nebula3.data.ResultSet import ResultSet from nebula3.gclient.net.AuthResult import AuthResult +from nebula3.gclient.net.base import BaseExecutor from nebula3.logger import logger +if TYPE_CHECKING: + from nebula3.gclient.net.ConnectionPool import ConnectionPool + from nebula3.gclient.net.Connection import Connection + -class Session(object): +class Session(BaseExecutor, object): def __init__( self, - connection, + connection: "Connection", auth_result: AuthResult, - pool, + pool: "ConnectionPool", retry_connect=True, execution_retry_count=0, retry_interval_seconds=1, @@ -50,6 +57,14 @@ def __init__( # the time stamp when the session was added to the idle list of the session pool self._idle_time_start = 0 + def execute(self, stmt): + """execute statement + + :param stmt: the ngql + :return: ResultSet + """ + return super().execute(stmt) + def execute_parameter(self, stmt, params): """execute statement :param stmt: the ngql @@ -105,16 +120,8 @@ def execute_parameter(self, stmt, params): except Exception: raise - def execute(self, stmt): - """execute statement - - :param stmt: the ngql - :return: ResultSet - """ - return self.execute_parameter(stmt, None) - def execute_json(self, stmt): - """execute statement and return the result as a JSON string + """execute statement and return the result as a JSON bytes Date and Datetime will be returned in UTC JSON struct: { @@ -172,12 +179,12 @@ def execute_json(self, stmt): ] } :param stmt: the ngql - :return: JSON string + :return: JSON bytes """ - return self.execute_json_with_parameter(stmt, None) + return super().execute_json(stmt) def execute_json_with_parameter(self, stmt, params): - """execute statement and return the result as a JSON string + """execute statement and return the result as a JSON bytes Date and Datetime will be returned in UTC JSON struct: { @@ -236,7 +243,7 @@ def execute_json_with_parameter(self, stmt, params): } :param stmt: the ngql :param params: parameter map - :return: JSON string + :return: JSON bytes """ if self._connection is None: raise RuntimeError("The session has been released") diff --git a/nebula3/gclient/net/SessionPool.py b/nebula3/gclient/net/SessionPool.py index 2e69bada..45ef690f 100644 --- a/nebula3/gclient/net/SessionPool.py +++ b/nebula3/gclient/net/SessionPool.py @@ -9,6 +9,7 @@ import socket from threading import RLock, Timer +from typing import List import time from nebula3.common.ttypes import ErrorCode @@ -20,11 +21,12 @@ from nebula3.gclient.net.Session import Session from nebula3.gclient.net.Connection import Connection +from nebula3.gclient.net.base import BaseExecutor from nebula3.logger import logger from nebula3.Config import SessionPoolConfig -class SessionPool(object): +class SessionPool(BaseExecutor, object): S_OK = 0 S_BAD = 1 @@ -53,9 +55,9 @@ def __init__(self, username, password, space_name, addresses): self._addresses_status[ip_port] = self.S_BAD # sessions that are currently in use - self._active_sessions = list() + self._active_sessions: List[Session] = list() # sessions that are currently available - self._idle_sessions = list() + self._idle_sessions: List[Session] = list() self._configs = SessionPoolConfig() self._ssl_configs = None @@ -84,7 +86,7 @@ def init(self, configs=None): if configs is not None: assert isinstance( configs, SessionPoolConfig - ), 'wrong type of SessionPoolConfig, try this: `from nebula3.Config import SessionPoolConfig`' + ), "wrong type of SessionPoolConfig, try this: `from nebula3.Config import SessionPoolConfig`" self._configs = configs else: self._configs = SessionPoolConfig() @@ -165,7 +167,7 @@ def execute(self, stmt): :param stmt: the query string :return: ResultSet """ - return self.execute_parameter(stmt, None) + return super().execute(stmt) def execute_parameter(self, stmt, params): """execute statement @@ -213,7 +215,7 @@ def execute_parameter(self, stmt, params): raise e def execute_json(self, stmt): - """execute statement and return the result as a JSON string + """execute statement and return the result as a JSON bytes Date and Datetime will be returned in UTC JSON struct: { @@ -271,9 +273,9 @@ def execute_json(self, stmt): ] } :param stmt: the ngql - :return: JSON string + :return: JSON bytes """ - return self.execute_json_with_parameter(stmt, None) + return super().execute_json(stmt) def execute_json_with_parameter(self, stmt, params): session = self._get_idle_session() diff --git a/nebula3/gclient/net/base.py b/nebula3/gclient/net/base.py new file mode 100644 index 00000000..b80320b0 --- /dev/null +++ b/nebula3/gclient/net/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from typing import Dict, Any, Optional +from nebula3.data.ResultSet import ResultSet + + +class BaseExecutor: + @abstractmethod + def execute_parameter( + self, stmt: str, params: Optional[Dict[str, Any]] + ) -> ResultSet: + pass + + @abstractmethod + def execute_json_with_parameter( + self, stmt: str, params: Optional[Dict[str, Any]] + ) -> bytes: + pass + + def execute(self, stmt: str) -> ResultSet: + return self.execute_parameter(stmt, None) + + def execute_json(self, stmt: str) -> bytes: + return self.execute_json_with_parameter(stmt, None)