diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 4127e1104..baab23e88 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -4,7 +4,7 @@ import socket import time from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from urllib import parse import grpc @@ -87,6 +87,16 @@ def __init__( self._set_authorization(**kwargs) self._setup_db_interceptor(kwargs.get("db_name", None)) self._setup_grpc_channel() + self.callbacks = [] + + def register_state_change_callback(self, callback: Callable): + self.callbacks.append(callback) + self._channel.subscribe(callback, try_to_connect=True) + + def deregister_state_change_callbacks(self): + for callback in self.callbacks: + self._channel.unsubscribe(callback) + self.callbacks = [] def __get_address(self, uri: str, host: str, port: str) -> str: if host != "" and port != "" and is_legal_host(host) and is_legal_port(port): @@ -141,6 +151,7 @@ def _wait_for_channel_ready(self, timeout: Union[float] = 10): raise e from e def close(self): + self.deregister_state_change_callbacks() self._channel.close() def reset_db_name(self, db_name: str): diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 7826de8f2..c29863758 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -11,7 +11,9 @@ # the License. import copy +import logging import threading +import time from typing import Callable, Tuple, Union from urllib import parse @@ -24,6 +26,8 @@ ) from pymilvus.settings import Config +logger = logging.getLogger(__name__) + VIRTUAL_PORT = 443 @@ -58,6 +62,53 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, **kwargs) +class ReconnectHandler: + def __init__(self, conns: object, connection_name: str, kwargs: object) -> None: + self.connection_name = connection_name + self.conns = conns + self._kwargs = kwargs + self.is_idle_state = False + self.reconnect_lock = threading.Lock() + + def check_state_and_reconnect_later(self): + check_after_seconds = 3 + logger.debug(f"state is idle, schedule reconnect in {check_after_seconds} seconds") + time.sleep(check_after_seconds) + if not self.is_idle_state: + logger.debug("idle state changed, skip reconnect") + return + with self.reconnect_lock: + logger.info("reconnect on idle state") + self.is_idle_state = False + try: + logger.debug("try disconnecting old connection...") + self.conns.disconnect(self.connection_name) + except Exception: + logger.warning("disconnect failed: {e}") + finally: + reconnected = False + while not reconnected: + try: + logger.debug("try reconnecting...") + self.conns.connect(self.connection_name, **self._kwargs) + reconnected = True + except Exception as e: + logger.warning( + f"reconnect failed: {e}, try again after {check_after_seconds} seconds" + ) + time.sleep(check_after_seconds) + logger.info("reconnected") + + def reconnect_on_idle(self, state: object): + logger.debug(f"state change to: {state}") + with self.reconnect_lock: + if state.value[1] != "idle": + self.is_idle_state = False + return + self.is_idle_state = True + threading.Thread(target=self.check_state_and_reconnect_later).start() + + class Connections(metaclass=SingleInstanceMetaClass): """Class for managing all connections of milvus. Used as a singleton in this module.""" @@ -270,6 +321,8 @@ def connect( Optional. Serving as the key for identification and authentication purposes. Whenever a token is furnished, we shall supplement the corresponding header to each RPC call. + * *keep_alive* (``bool``) -- + Optional. Default is false. If set to true, client will keep an alive connection. * *db_name* (``str``) -- Optional. default database name of this connection * *client_key_path* (``str``) -- @@ -293,6 +346,13 @@ def connect( >>> connections.connect("test", host="localhost", port="19530") """ + # kwargs_copy is used for auto reconnect + kwargs_copy = copy.deepcopy(kwargs) + kwargs_copy["user"] = user + kwargs_copy["password"] = password + kwargs_copy["db_name"] = db_name + kwargs_copy["token"] = token + def connect_milvus(**kwargs): gh = GrpcHandler(**kwargs) @@ -300,6 +360,10 @@ def connect_milvus(**kwargs): timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT gh._wait_for_channel_ready(timeout=timeout) + if kwargs.get("keep_alive", False): + gh.register_state_change_callback( + ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle + ) kwargs.pop("password") kwargs.pop("token", None) kwargs.pop("secure", None)