diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 4e29c9dd6..df03a4040 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -289,17 +289,4 @@ def traverse_info(fields_info: Any, entities: List): def get_server_type(host: str): - if host is None or not isinstance(host, str): - return MILVUS - splits = host.split(".") - len_of_splits = len(splits) - if ( - len_of_splits >= 2 - and ( - splits[len_of_splits - 2].lower() == "zilliz" - or splits[len_of_splits - 2].lower() == "zillizcloud" - ) - and splits[len_of_splits - 1].lower() == "com" - ): - return ZILLIZ - return MILVUS + return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 783ee823b..7826de8f2 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -17,7 +17,6 @@ from pymilvus.client.check import is_legal_address, is_legal_host, is_legal_port from pymilvus.client.grpc_handler import GrpcHandler -from pymilvus.client.utils import ZILLIZ, get_server_type from pymilvus.exceptions import ( ConnectionConfigException, ConnectionNotExistException, @@ -113,7 +112,7 @@ def __verify_host_port(self, host: str, port: Union[int, str]): def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): illegal_uri_msg = ( - "Illegal uri: [{}], expected form 'http[s]://[user:password@]example.com:12345'" + "Illegal uri: [{}], expected form 'http[s]://[user:password@]example.com[:12345]'" ) try: parsed_uri = parse.urlparse(uri) @@ -126,7 +125,8 @@ def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST - port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT + default_port = "443" if parsed_uri.scheme == "https" else Config.DEFAULT_PORT + port = parsed_uri.port if parsed_uri.port is not None else default_port addr = f"{host}:{port}" self.__verify_host_port(host, port) @@ -302,7 +302,6 @@ def connect_milvus(**kwargs): gh._wait_for_channel_ready(timeout=timeout) kwargs.pop("password") kwargs.pop("token", None) - kwargs.pop("db_name", None) kwargs.pop("secure", None) kwargs.pop("db_name", "") @@ -315,14 +314,6 @@ def with_config(config: Tuple) -> bool: if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - # Set port if server type is zilliz cloud serverless - uri = kwargs.get("uri") - if uri is not None: - server_type = get_server_type(uri) - parsed_uri = parse.urlparse(uri) - if server_type == ZILLIZ and parsed_uri.port is None: - kwargs["uri"] = uri + ":" + str(VIRTUAL_PORT) - config = ( kwargs.pop("address", ""), kwargs.pop("uri", ""), @@ -335,14 +326,15 @@ def with_config(config: Tuple) -> bool: # 1st Priority: connection from params if with_config(config): - in_addr, parsed_uri = self.__get_full_address(*config) - kwargs["address"] = in_addr + addr, parsed_uri = self.__get_full_address(*config) + kwargs["address"] = addr - if self.has_connection(alias) and self._alias[alias].get("address") != in_addr: + if self.has_connection(alias) and self._alias[alias].get("address") != addr: raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) # uri might take extra info if parsed_uri is not None: + # get db_name from uri user = parsed_uri.username or user password = parsed_uri.password or password diff --git a/tests/test_connections.py b/tests/test_connections.py index cb862d51e..9e2a01203 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -48,7 +48,9 @@ def no_host_or_port(self, request): {"uri": "http://127.0.0.1:19530"}, {"uri": "http://example.com:80"}, {"uri": "http://example.com:80/database1"}, - {"uri": "https://127.0.0.1:19530/databse2"}, + {"uri": "https://127.0.0.1:19530/database2"}, + {"uri": "https://127.0.0.1/database3"}, + {"uri": "http://127.0.0.1/database4"}, ]) def uri(self, request): return request.param diff --git a/tests/test_utils.py b/tests/test_utils.py index f1c3d0187..ad852b0a2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,8 +5,8 @@ class TestUtils: def test_get_server_type(self): urls_and_wants = [ ('in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com', 'zilliz'), - ('something.notzillizcloud.com', 'milvus'), - ('something.zillizcloud.not.com', 'milvus') + ('something.abc.com', 'milvus'), + ('something.zillizcloud.cn', 'zilliz') ] for (url, want) in urls_and_wants: assert utils.get_server_type(url) == want