Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto reconnection when channel state changed to idle #1846

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from urllib import parse

import grpc
Expand Down Expand Up @@ -87,6 +87,16 @@ def __init__(
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name", None))
self._setup_grpc_channel()
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
self.callbacks.append(callback)
self._channel.subscribe(callback, try_to_connect=True)

def deregister_state_change_callbacks(self):
for callback in self.callbacks:
self._channel.unsubscribe(callback)
self.callbacks = []

def __get_address(self, uri: str, host: str, port: str) -> str:
if host != "" and port != "" and is_legal_host(host) and is_legal_port(port):
Expand Down Expand Up @@ -141,6 +151,7 @@ def _wait_for_channel_ready(self, timeout: Union[float] = 10):
raise e from e

def close(self):
self.deregister_state_change_callbacks()
self._channel.close()

def reset_db_name(self, db_name: str):
Expand Down
64 changes: 64 additions & 0 deletions pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# the License.

import copy
import logging
import threading
import time
from typing import Callable, Tuple, Union
from urllib import parse

Expand All @@ -24,6 +26,8 @@
)
from pymilvus.settings import Config

logger = logging.getLogger(__name__)

VIRTUAL_PORT = 443


Expand Down Expand Up @@ -58,6 +62,53 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)


class ReconnectHandler:
def __init__(self, conns: object, connection_name: str, kwargs: object) -> None:
self.connection_name = connection_name
self.conns = conns
self._kwargs = kwargs
self.is_idle_state = False
self.reconnect_lock = threading.Lock()

def check_state_and_reconnect_later(self):
check_after_seconds = 3
logger.debug(f"state is idle, schedule reconnect in {check_after_seconds} seconds")
time.sleep(check_after_seconds)
if not self.is_idle_state:
logger.debug("idle state changed, skip reconnect")
return
with self.reconnect_lock:
logger.info("reconnect on idle state")
self.is_idle_state = False
try:
logger.debug("try disconnecting old connection...")
self.conns.disconnect(self.connection_name)
except Exception:
logger.warning("disconnect failed: {e}")
finally:
reconnected = False
while not reconnected:
try:
logger.debug("try reconnecting...")
self.conns.connect(self.connection_name, **self._kwargs)
reconnected = True
except Exception as e:
logger.warning(
f"reconnect failed: {e}, try again after {check_after_seconds} seconds"
)
time.sleep(check_after_seconds)
logger.info("reconnected")

def reconnect_on_idle(self, state: object):
logger.debug(f"state change to: {state}")
with self.reconnect_lock:
if state.value[1] != "idle":
self.is_idle_state = False
return
self.is_idle_state = True
threading.Thread(target=self.check_state_and_reconnect_later).start()


class Connections(metaclass=SingleInstanceMetaClass):
"""Class for managing all connections of milvus. Used as a singleton in this module."""

Expand Down Expand Up @@ -270,6 +321,8 @@ def connect(
Optional. Serving as the key for identification and authentication purposes.
Whenever a token is furnished, we shall supplement the corresponding header
to each RPC call.
* *keep_alive* (``bool``) --
Optional. Default is false. If set to true, client will keep an alive connection.
* *db_name* (``str``) --
Optional. default database name of this connection
* *client_key_path* (``str``) --
Expand All @@ -293,13 +346,24 @@ def connect(
>>> connections.connect("test", host="localhost", port="19530")
"""

# kwargs_copy is used for auto reconnect
kwargs_copy = copy.deepcopy(kwargs)
kwargs_copy["user"] = user
kwargs_copy["password"] = password
kwargs_copy["db_name"] = db_name
kwargs_copy["token"] = token

def connect_milvus(**kwargs):
gh = GrpcHandler(**kwargs)

t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
if kwargs.get("keep_alive", False):
gh.register_state_change_callback(
ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
)
kwargs.pop("password")
kwargs.pop("token", None)
kwargs.pop("secure", None)
Expand Down
Loading