diff --git a/pymilvus/milvus_client/__init__.py b/pymilvus/milvus_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymilvus/milvus_client/defaults.py b/pymilvus/milvus_client/defaults.py new file mode 100644 index 000000000..446e938d6 --- /dev/null +++ b/pymilvus/milvus_client/defaults.py @@ -0,0 +1,12 @@ +"""Default MilvusClient args.""" + +DEFAULT_SEARCH_PARAMS = { + "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, + "AUTOINDEX": {"metric_type": "L2", "params": {}}, +} diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py new file mode 100644 index 000000000..1853f5f7b --- /dev/null +++ b/pymilvus/milvus_client/milvus_client.py @@ -0,0 +1,816 @@ +"""MilvusClient for dealing with simple workflows.""" +import logging +import threading +from typing import Optional, Union, List, Dict +from uuid import uuid4 + +from tqdm import tqdm +from pymilvus.client.types import LoadState +from pymilvus.exceptions import MilvusException +from pymilvus.milvus_client.defaults import DEFAULT_SEARCH_PARAMS +from pymilvus.orm import utility +from pymilvus.orm.collection import Collection, CollectionSchema, FieldSchema +from pymilvus.orm.connections import connections +from pymilvus.orm.types import DataType, infer_dtype_bydata + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + + +class MilvusClient: + """The Milvus Client""" + def __init__( + self, + collection_name: str = "ClientCollection", + pk_field: str = None, + vector_field: str = None, + uri: str = None, + shard_num: int = None, + partitions: List[str] = None, + consistency_level: str = "Bounded", + replica_number: int = 1, + index_params: dict = None, + timeout: Optional[int] = None, + drop_old: bool = False, + ): + """A client for the common Milvus use case. + + This client attempts to hide away the complexity of using Pymilvus. In a lot ofcases what + the user wants is a simple wrapper that supports adding data, deleting data, and searching. + This wrapper can autoinfer the schema from a previous collection or newly inserted data, + can update the paritions, can query, and can delete by pk. + + Args: + pk_field (str, optional): Which entry in data is considered the primary key. If None, + an auto-id will be created. Will be overwritten if loading from a previous + collection. Defaults to None. + vector_field (str, optional): Which entry in the data is considered the vector field. + Will get overwritten if loading from previous collection. Defaults to None. + uri (str, optional): The connection address to use to connect to the + instance. Defaults to "http://localhost:19530". You can also set this address + as an env variable + shard_num (int, optional): The amount of shards to use for the collection. Unless + dealing with huge scale, recommended to keep at default. Defaults to None and allows + server to set. + partitions (List[str], optional): Which paritions to create for the collection. + Defaults to None. + consistency_level (str, optional): Which consistency level to use for the Client. + The options are "Strong", "Bounded", "Eventually", "Session". Defaults to "Bounded". + replica_number (int, optional): The amount of in memomory replicas to use. + Defaults to 1. + index_params (dict, optional): What index parameteres to use for the Collection. + If none, will use a default one. Defaults to None. + timeout (Optional[int], optional): What timeout to use for function calls. Defaults + to None. + drop_old (bool, optional): If a collection with the same name already exists, drop it. + Defaults to False. + """ + self.uri = uri + self.collection_name = collection_name + self.shard_num = shard_num + self.partitions = partitions + self.consistency_level = consistency_level + self.replica_number = replica_number + self.index_params = index_params + self.timeout = timeout + self.pk_field = pk_field + self.vector_field = vector_field + + # TODO: Figure out thread safety + # self.concurrent_counter = 0 + self.concurrent_lock = threading.RLock() + self.dim = None + self.default_search_params = None + self.collection = None + self.fields = None + + self.alias = self._create_connection() + self.is_self_hosted = bool( + utility.get_server_type(using=self.alias) == "milvus" + ) + if drop_old: + self.delete_collection() + self._init(None) + + def insert_data( + self, + data: List[Dict[str, any]], + timeout: int = None, + batch_size: int = 100, + partition: str = None, + progress_bar: bool = False, + ) -> List[Union[str, int]]: + """Insert data into the collection. + + If the Milvus Client was initiated without an existing Collection, the first dict passed + in will be used to initiate the collection. + + Args: + data (List[Dict[str, any]]): A list of dicts to pass in. If list not provided, will + cast to list. + timeout (int, optional): The timeout to use, will override init timeout. Defaults + to None. + batch_size (int, optional): The batch size to perform inputs with. Defaults to 100. + partition (str, optional): Which partition to insert into. Defaults to None. + progress_bar (bool, optional): Whether to display a progress bar for the input. + Defaults to False. + + Raises: + DataNotMatchException: If the data has misssing fields an exception will be thrown. + MilvusException: General Milvus error on insert. + + Returns: + List[Union[str, int]]: A list of primary keys that were inserted. + """ + # If no data provided, we cannot input anything + if len(data) == 0: + return [] + + if batch_size < 1: + logger.error( + "Invalid batch size provided for insert." + ) + + raise ValueError("Invalid batch size provided for insert.") + + # If the collection hasnt been initialized, initialize it + with self.concurrent_lock: + if self.collection is None: + self._init(data[0]) + + # Dont include the primary key if auto_id is true and they included it in data + ignore_pk = self.pk_field if self.collection.schema.auto_id else None + insert_dict = {} + pks = [] + + for k in data: + for key, value in k.items(): + if key in self.fields: + insert_dict.setdefault(key, []).append(value) + + # Insert the data in batches + for i in tqdm(range(0, len(data), batch_size), disable=not progress_bar): + # Convert dict to list of lists batch for insertion + try: + insert_batch = [insert_dict[key][i : i + batch_size] for key in self.fields if key != ignore_pk] + except KeyError as ex: + logger.error( + "Malformed data, one of the inserts does not contain all fields required." + ) + raise ex + # Insert into the collection. + try: + res = self.collection.insert( + insert_batch, + timeout=timeout or self.timeout, + partition_name=partition, + ) + pks.extend(res.primary_keys) + except MilvusException as ex: + logger.error( + "Failed to insert batch starting at entity: %s/%s", str(i), str(len(data)) + ) + raise ex + return pks + + def upsert_data( + self, + data: List[Dict[str, any]], + timeout: int = None, + batch_size: int = 100, + partition: str = None, + progress_bar: bool = False, + ) -> List[Union[str, int]]: + """WARNING: SLOW AND NOT ATOMIC. Will be updated for 2.3 release. + + Upsert the data into the collection. + + If the Milvus Client was initiated without an existing Collection, the first dict passed + in will be used to initiate the collection. + + Args: + data (List[Dict[str, any]]): A list of dicts to upsert. + timeout (int, optional): The timeout to use, will override init timeout. Defaults + to None. + batch_size (int, optional): The batch size to perform inputs with. Defaults to 100. + partition (str, optional): Which partition to insert into. Defaults to None. + progress_bar (bool, optional): Whether to display a progress bar for the input. + Defaults to False. + Returns: + List[Union[str, int]]: A list of primary keys that were inserted. + """ + # If the collection exists we need to first delete the values + if self.collection is not None: + pks = [x[self.pk_field] for x in data] + self.delete_by_pk(pks, timeout) + + ret = self.insert_data( + data=data, + timeout=timeout, + batch_size=batch_size, + partition=partition, + progress_bar=progress_bar + ) + + return ret + + def search_data( + self, + data: Union[List[list], list], + search_params: dict = None, + filter_expression: str = None, + top_k: int = 10, + partitions: List[str] = None, + timeout: int = None, + ) -> List[dict]: + """Search for a query vector/vectors. + + In order for the search to process, a collection needs to have been either provided + at init or data needs to have been inserted. + + Args: + data (Union[List[list], list]): The vector/vectors to search. + search_params (dict, optional): The search params to use for the search. Will default + to the default set for the client. + filter_expression (str, optional): A filter to use for the search. Defaults to None. + top_k (int, optional): How many results to return per search. Defaults to 10. + partitions (List[str], optional): Which partitions to search within. Defaults to + searching through all. + timeout (int, optional): Timeout to use, overides the client level assigned at init. + Defaults to None. + + Raises: + ValueError: The collection being searched doesnt exist. Need to insert data first. + + Returns: + List[dict]: A list of dicts containing the score and the result data. Embeddings are + not included in the result data. + """ + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter += 1 + + if self.collection is None: + logger.error("Collection does not exist: %s", self.collection_name) + raise ValueError( + "Missing collection. Make sure data inserted or intialized on existing collection." + ) + + if not isinstance(data[0], list): + data = [data] + + return_fields = list(self.fields.keys()) + return_fields.remove(self.vector_field) + + res = self.collection.search( + data, + self.vector_field, + expr=filter_expression, + param=search_params or self.default_search_params, + limit=top_k, + partition_names=partitions, + output_fields=return_fields, + timeout=timeout or self.timeout, + ) + + ret = [] + for hits in res: + for hit in hits: + ret_dict = {x: hit.entity.get(x) for x in return_fields} + ret.append({"score": hit.score, "data": ret_dict}) + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter -= 1 + return ret + + def query_data( + self, + filter_expression: str, + partitions: List[str] = None, + timeout: int = None, + ) -> List[dict]: + """Query for entries in the Collection. + + Args: + filter_expression (str): The filter to use for the query. + partitions (List[str], optional): Which partitions to perform query. Defaults to None. + timeout (int, optional): Timeout to use, overides the client level assigned at init. + Defaults to None. + + Raises: + ValueError: Missing collection. + + Returns: + List[dict]: A list of result dicts, embeddings are not included. + """ + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter += 1 + + if self.collection is None: + logger.error("Collection does not exist: %s", self.collection_name) + raise ValueError( + "Missing collection. Make sure data inserted or intialized on existing collection." + ) + + return_fields = list(self.fields.keys()) + return_fields.remove(self.vector_field) + + res = self.collection.query( + expr=filter_expression, + partition_names=partitions, + output_fields=return_fields, + timeout=timeout or self.timeout, + ) + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter -= 1 + + return res + + def get_embeddings_by_pk( + self, + pks: List[Union[str, int]], + timeout: int = None, + ) -> None: + """Grab the inserted embeddings using the primary key from the Collection. + + Due to current implementations, grabbing a large amount of vectors is slow. + + Args: + filter_expression (str): The filter to use for the query. + timeout (int, optional): Timeout to use, overides the client level assigned at + init. Defaults to None. + + Raises: + ValueError: Missing collection. + + Returns: + List[dict]: A list of result dicts with keys {pk_field, vector_field} + """ + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter += 1 + + if self.collection is None: + logger.error("Collection does not exist: %s", self.collection_name) + raise ValueError( + "Missing collection. Make sure data inserted or intialized on existing collection." + ) + + # Varchar pks need double quotes around the values + if self.fields[self.pk_field] == DataType.VARCHAR: + ids = ['"' + str(entry) + '"' for entry in pks] + expr = f""""{self.pk_field}" in [{','.join(ids)}]""" + else: + ids = [str(entry) for entry in pks] + expr = f"{self.pk_field} in [{','.join(ids)}]" + + res = self.collection.query( + expr=expr, + output_fields=[self.vector_field], + timeout=timeout or self.timeout, + ) + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter -= 1 + + return res + + def delete_by_pk( + self, + pks: list, + timeout: int = None, + ) -> None: + """Delete entries in the collection by their pk. + + Delete all the entries based on the pk. If unsure of pk you can first query the collection + to grab the corresponding data. Then you can delete using the pk_field. + + Args: + pks (list): _description_ + timeout (int, optional): Timeout to use, overides the client level assigned at init. + Defaults to None. + """ + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter += 1 + + if self.collection is None: + logger.error("Collection does not exist: %s", self.collection_name) + + if len(pks) == 0: + return + + if self.fields[self.pk_field] == DataType.VARCHAR: + ids = ['"' + str(entry) + '"' for entry in pks] + expr = f""""{self.pk_field}" in [{','.join(ids)}]""" + else: + ids = [str(entry) for entry in pks] + expr = f"{self.pk_field} in [{','.join(ids)}]" + + self.collection.delete(expr=expr, timout=timeout or self.timeout) + + # TODO: Figure out thread safety + # with self.concurrent_lock: + # self.concurrent_counter -= 1 + + def delete_collection( + self, + ) -> None: + """Delete the collection.""" + # TODO: Figure out thread safety + with self.concurrent_lock: + if self.collection is not None: + self.collection.drop() + self.collection = None + + def close( + self, + drop_collection: bool = True, + ): + if drop_collection == True: + self.delete_collection() + + connections.remove_connection(self.alias) + + def add_partitions(self, input_partitions: List[str]): + """Add partitions to the collection. + + Add a list of partition names to the collection. If the collection is loaded + it will first be unloaded, then the partitions will be added, and then reloaded. + + Args: + input_partitions (List[str]): The list of partition names to be added. + + Raises: + MilvusException: Unable to add the partition. + """ + # TODO: Figure out thread safety + with self.concurrent_lock: + if self.collection is not None and self.is_self_hosted: + # Calculate which partitions need to be added + input_partitions = set(input_partitions) + current_partitions = { + partition.name for partition in self.collection.partitions + } + new_partitions = input_partitions.difference(current_partitions) + # If partitions need to be added, add them + if len(new_partitions) != 0: + # Try to unload the collection, if exception already released + reload = False + try: + self.collection.release(timeout=self.timeout) + reload = True + except MilvusException: + pass + + try: + for part in new_partitions: + self.collection.create_partition(part) + logger.info( + "Successfully added partitions to collection: %s partitions: %s", + self.collection_name, + ",".join(part for part in list(new_partitions)), + ) + except MilvusException as ex: + logger.debug( + "Failed to add partitions to: %s", self.collection_name + ) + raise ex + # If the collection started out loaded, reload it. + if reload: + self._load() + else: + logger.debug( + "No parititons to add for collection: %s", self.collection_name + ) + else: + logger.debug( + "Collection either on Zilliz or non existant for collection: %s", + self.collection_name, + ) + + def delete_partitions(self, remove_partitions: List[str]): + """Remove partitions from the collection. + + Remove a list of partition names from the collection. If the collection is loaded + it will first be unloaded, then the partitions will be removed, and then reloaded. + + Args: + remove_partitions (List[str]): The list of partition names to be removed. + + Raises: + MilvusException: Unable to remove the partition. + """ + with self.concurrent_lock: + if self.collection is not None and self.is_self_hosted: + # Calculate which partitions need to be removed + remove_partitions = set(remove_partitions) + current_partitions = { + partition.name for partition in self.collection.partitions + } + removal_partitions = remove_partitions.intersection(current_partitions) + # If partitions need to be added, add them + if len(removal_partitions) != 0: + # Try to unload the collection, if exception raised it is most likely already + # released + reload = False + try: + self.collection.release(timeout=self.timeout) + reload = True + except MilvusException: + pass + try: + for part in removal_partitions: + self.collection.drop_partition(part) + logger.info( + "Successfully deleted partitions from collection: %s partitions: %s", + self.collection_name, + ",".join(part for part in list(removal_partitions)), + ) + except MilvusException as ex: + logger.debug( + "Failed to delete partitions from: %s", self.collection_name + ) + raise ex + # If the collection started out loaded, reload it. + if reload: + self._load() + else: + logger.debug( + "No parititons to delete for collection: %s", + self.collection_name, + ) + + def _create_connection(self) -> str: + """Create the connection to the Milvus server.""" + # TODO: Implement reuse with new uri style + alias = uuid4().hex + try: + connections.connect(alias=alias, uri = self.uri) + logger.debug("Created new connection using: %s", alias) + return alias + except MilvusException as ex: + logger.error("Failed to create new connection using: %s", alias) + raise ex + + def _init(self, input_data: Optional[dict]): + """Create/connect to the colletion""" + # If no input data and collection exists, use that + if input_data is None and utility.has_collection( + self.collection_name, using=self.alias + ): + self.collection = Collection(self.collection_name, using=self.alias) + # Grab the field information from the existing collection + self._extract_fields() + # If data is supplied we can create a new collection + elif input_data is not None: + self._create_collection(input_data) + # Nothin to init from + else: + logger.debug( + "No information to perform init from for collection %s", + self.collection_name, + ) + return + self._create_index() + # Partitions only allowed on Milvus at the moment + if self.is_self_hosted and self.partitions is not None: + self.add_partitions(self.partitions) + self._create_default_search_params() + self._load() + + def _create_collection(self, data: dict) -> None: + """Create the collection by autoinferring the schema.""" + # TODO: Assuming ordered dict for 3.7 + fields = {} + + # Figure out each datatype of the input. + for key, value in data.items(): + # Infer the corresponding datatype of the metadata + dtype = infer_dtype_bydata(value) + + # Datatype isnt compatible + if dtype in (DataType.UNKNOWN, DataType.NONE): + logger.error( + "Failed to parse schema for collection %s, unrecognized dtype for key: %s", + self.collection_name, + key, + ) + raise ValueError(f"Unrecognized datatype for {key}.") + + # Create an entry under the field name + fields[key] = {} + fields[key]["name"] = key + fields[key]["dtype"] = dtype + + # Area for attaching kwargs for certain datatypes + if dtype == DataType.VARCHAR: + fields[key]["max_length"] = 65_535 + + if self.vector_field is None: + logger.error( + "Missing vector_field, cannot infer schema for collection: %s", + self.collection_name, + ) + raise ValueError("Missing vector_field, cannot autoinfer schema.") + + try: + self.dim = len(data[self.vector_field]) + # Attach dim kwarg to vector field + fields[self.vector_field]["dim"] = self.dim + except KeyError as ex: + logger.error( + "Missing vector_field: %s in data for collection: %s", + self.vector_field, + self.collection_name, + ) + raise ex + + if self.pk_field is None: + # Generate a unique auto-id field + self.pk_field = "internal_pk_" + uuid4().hex[:4] + # Create a new field for pk + fields[self.pk_field] = {} + fields[self.pk_field]["name"] = self.pk_field + fields[self.pk_field]["dtype"] = DataType.INT64 + fields[self.pk_field]["auto_id"] = True + fields[self.pk_field]["is_primary"] = True + logger.debug( + "Missing pk_field, creating auto-id pk for collection: %s", + self.collection_name, + ) + else: + # If pk_field given, assume it was iterated + try: + fields[self.pk_field]["auto_id"] = False + fields[self.pk_field]["is_primary"] = True + except KeyError as ex: + logger.error( + "Missing pk_field: %s in data for collection: %s", + self.pk_field, + self.collection_name, + ) + raise ex + try: + # Create the fieldschemas + fieldschemas = [] + # TODO: Assuming ordered dicts for 3.7 + self.fields = {} + for field_dict in fields.values(): + fieldschemas.append(FieldSchema(**field_dict)) + self.fields[field_dict["name"]] = field_dict["dtype"] + # Create the schema for the collection + schema = CollectionSchema(fieldschemas) + # Create the collection + self.collection = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + shards_num=self.shard_num, + using=self.alias, + ) + logger.error("Successfully created collection: %s", self.collection_name) + except MilvusException as ex: + logger.error("Failed to create collection: %s", self.collection_name) + raise ex + + def _extract_fields(self) -> None: + """Grab the existing fields from the Collection""" + self.fields = {} + schema = self.collection.schema + for field in schema.fields: + field_dict = field.to_dict() + if field_dict.get("is_primary", None) is not None: + logger.debug("Updating pk_field with one from collection.") + self.pk_field = field_dict["name"] + if field_dict["type"] in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + logger.debug("Updating vector_field with one from collection.") + self.vector_field = field_dict["name"] + self.fields[field_dict["name"]] = field_dict["type"] + + logger.info( + "Successfully extracted fields from for collection: %s, total fields: %s, " + "pk_field: %s, vector_field: %s", + self.collection_name, + len(self.fields), + self.pk_field, + self.vector_field, + ) + + def _create_index(self) -> None: + """Create a index on the collection""" + if self._get_index() is None: + # If no index params, use a default HNSW based one + if self.index_params is None: + # TODO: Once segment normalization we can default to IP + metric_type = ( + "L2" + if self.fields[self.vector_field] == DataType.FLOAT_VECTOR + else "JACCARD" + ) + # TODO: Once AUTOINDEX type is supported by Milvus we can default to HNSW always + index_type = "HNSW" if self.is_self_hosted else "AUTOINDEX" + params = {"M": 8, "efConstruction": 64} if self.is_self_hosted else {} + self.index_params = { + "metric_type": metric_type, + "index_type": index_type, + "params": params, + } + try: + self.collection.create_index( + self.vector_field, + index_params=self.index_params, + using=self.alias, + timeout=self.timeout, + ) + logger.info( + "Successfully created an index on collection: %s", + self.collection_name, + ) + + except MilvusException as ex: + logger.error( + "Failed to create an index on collection: %s", self.collection_name + ) + raise ex + else: + logger.debug( + "Index exists already for collection: %s", self.collection_name + ) + + def _get_index(self): + """Return the index dict if index exists.""" + for index in self.collection.indexes: + if index.field_name == self.vector_field: + return index + return None + + def _create_default_search_params(self) -> None: + """Generate search params based on the current index type""" + index = self._get_index().to_dict() + if index is not None: + index_type = index["index_param"]["index_type"] + metric_type = index["index_param"]["metric_type"] + self.default_search_params = DEFAULT_SEARCH_PARAMS[index_type] + self.default_search_params["metric_type"] = metric_type + + def _load(self): + """Loads the collection.""" + if self._get_index() is not None: + # Check if the collection is loaded or in progress of loading + if ( + utility.load_state( + self.collection_name, using=self.alias, timeout=self.timeout + ) + != LoadState.NotLoad + ): + # IF the collection is loaded/loading, check the replica count + if len(self.collection.get_replicas().groups) == self.replica_number: + logger.debug("Collection already loaded.") + return + + # If the replica count is incorrect, release the collection + try: + self.collection.release(timeout=self.timeout) + logger.debug( + "Successfully released collection due to incorrect replica: %s", + self.collection_name, + ) + except MilvusException as ex: + logger.error( + "Failed to release collection with incorrect num_replicas: %s", + self.collection_name, + ) + raise ex + # Try to load in the collection with correct replica count + try: + self.collection.load(replica_number=self.replica_number) + logger.info( + "Successfully loaded collection with correct replica_count: %s", + self.collection_name, + ) + except MilvusException: + logger.error( + "Failed to load collection with num_replicas greater than one: %s, " + / "attempting num_replicas==1", + self.collection_name, + ) + # If load fails, try to load in with only 1 replica (standalone) + try: + self.collection.load(replica_number=1) + logger.info( + "Successfully loaded collection with num_replicas==1: %s", + self.collection_name, + ) + # If both loads fail, raise exception + except MilvusException as ex: + logger.error("Failed to load collection: %s", self.collection_name) + raise ex diff --git a/pymilvus/milvus_client/milvus_client_tests.py b/pymilvus/milvus_client/milvus_client_tests.py new file mode 100644 index 000000000..2a505d57b --- /dev/null +++ b/pymilvus/milvus_client/milvus_client_tests.py @@ -0,0 +1,148 @@ +"""Test the MilvusClient""" +import logging +import sys +from uuid import uuid4 +import numpy as np + +from pymilvus import FieldSchema, DataType, CollectionSchema, connections, utility, Collection +from pymilvus.milvus_client.milvus_client import MilvusClient + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.DEBUG) +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +handler.setFormatter(formatter) +logger.addHandler(handler) + + +""" +Tests to Run: + +Construct non existant collection +Construct existant collection + +Insert data existant collection +Insert data nonexistant collection + +Insert non matching data existant collection +Insert non matching data nonexistant collection + +Insert insert data into auto_id with pk field +insert data into auto_id without pk field + +""" + +MILVUS_URI = "http://localhost:19530" +COLLECTION_NAME = "test" + +def valid_data(seed: int): + datas = [] + count = 10 + for cur in range(count): + float_num = seed + (cur / 10) + int_num = (seed * 10) + cur + temp = { + "varchar": str(float_num)[:5], + "float": np.float32(float_num), + "int": int_num, + "float_vector": [float_num] * 3 + } + datas.append(temp) + + return datas + +def create_existing_collection(uri, collection_name): + alias = uuid4().hex + connections.connect(uri=uri, alias=alias) + if utility.has_collection(collection_name=collection_name, using=alias): + utility.drop_collection(collection_name=collection_name, using=alias) + fields = [ + FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim = 3), + FieldSchema(name="int", dtype=DataType.INT64, is_primary = True, auto_id = True), + FieldSchema(name="float", dtype=DataType.FLOAT), + FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length= 65_535) + ] + schema = CollectionSchema(fields) + + ret = { + "col": Collection(collection_name, schema, using=alias), + "fields": ["float_vector", "int", "float", "varchar"], + "primary_field": "int", + "vector_field": "float_vector" + } + + return ret + + +class TestMilvusClient: + @staticmethod + def test_construct_from_existing_collection(): + info = create_existing_collection(MILVUS_URI, COLLECTION_NAME) + client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI) + assert list(client.fields.keys()) == info["fields"] + assert client.pk_field == info["primary_field"] + assert client.vector_field == info["vector_field"] + info["col"].drop() + + @staticmethod + def test_construct_from_nonexistant_collection(): + client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI) + assert client.fields == None + assert client.pk_field == None + assert client.vector_field == None + + @staticmethod + def test_insert_in_existing_collection_valid(): + info = create_existing_collection(MILVUS_URI, COLLECTION_NAME) + client = MilvusClient(collection_name=COLLECTION_NAME, uri=MILVUS_URI) + client.insert_data(valid_data(1)) + info["col"].drop() + + + +if __name__ == "__main__": + # TestMilvusClient.test_construct_from_existing_collection() + # TestMilvusClient.test_construct_from_nonexistant_collection() + TestMilvusClient.test_insert_in_existing_collection_valid() + +# import sys + +# # Test the insert +# outs = s.insert_data(test_data, partition="lol") +# pprint(outs) +# rets = s.search_data([0, 0, 0, 0, 0, 0, 0]) +# pprint(rets) + +# # Test the searches +# rets = s.search_data([0, 0, 0, 0, 0, 0, 0]) +# pprint(rets) + +# rets = s.search_data([0, 0, 0, 0, 0, 0, 0], partitions=["lol"]) +# pprint(rets) + +# rets = s.search_data([0, 0, 0, 0, 0, 0, 0], partitions=["default"]) +# pprint(rets) + +# # Test the query +# rets = s.query_data(s.pk_field + " in [1]") +# print(rets) + +# rets = s.get_embeddings_by_pk([1]) +# print(rets) + +# # pprint(s.collection.partitions) + +# # ret = s.search_data([0, 0, 0, 0, 0, 0, 0]) +# # pprint(ret) + +# # ret = s.query_data("""char in ["bar"]""") +# # pprint(ret) + +# # s.delete_by_pk([1]) + +# # ret = s.query_data("""char in ["bar"]""") +# # pprint(ret)