# hello_cost.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus,
# and prints the `cost` value carried by each operation's result.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 10, 8

#################################################################################
# 1. connect to Milvus
# Add a new connection alias `default` for the Milvus server at `localhost:19530`.
# Actually the "default" alias is built into PyMilvus.
# If the address of Milvus is the same as `localhost:19530`, you can omit all
# parameters and call the method as: `connections.connect()`.
#
# Note: the `using` parameter of the following methods defaults to "default".
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

collection_name = "hello_cost"
has = utility.has_collection(collection_name)
print(f"Does collection {collection_name} exist in Milvus: {has}")

#################################################################################
# 2. create collection
# We're going to create a collection with 3 fields.
# +-+------------+------------+------------------+------------------------------+
# | | field name | field type | other attributes |       field description      |
# +-+------------+------------+------------------+------------------------------+
# |1|    "pk"    |   VarChar  |  is_primary=True |      "primary field"         |
# | |            |            |   auto_id=False  |                              |
# +-+------------+------------+------------------+------------------------------+
# |2|  "random"  |    Double  |                  |      "a double field"        |
# +-+------------+------------+------------------+------------------------------+
# |3|"embeddings"| FloatVector|     dim=8        |  "float vector with dim 8"   |
# +-+------------+------------+------------------+------------------------------+
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, f"{collection_name} is the simplest demo to introduce the APIs")

print(fmt.format(f"Create collection `{collection_name}`"))
# NOTE: the Python variable keeps the name `hello_milvus`, but the collection it
# refers to is the one named by `collection_name` ("hello_cost").
hello_milvus = Collection(collection_name, schema, consistency_level="Strong")

################################################################################
# 3. insert data
# We are going to insert `num_entities` (10) rows of data into the collection.
# Data to be inserted must be organized in fields.
#
# The insert() method returns:
# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
# - or the existing primary key field from the entities if auto_id=False in the schema.

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
entities = [
    # provide the pk field because `auto_id` is set to False
    [str(i) for i in range(num_entities)],
    rng.random(num_entities).tolist(),  # field random, only supports list
    rng.random((num_entities, dim)),    # field embeddings, supports numpy.ndarray and list
]

insert_result = hello_milvus.insert(entities)
# The mutation result exposes the cost through its `cost` attribute.
# OUTPUT:
# insert result: (insert count: 10, delete count: 0, upsert count: 0, timestamp: 449296288881311748, success count: 10, err count: 0, cost: 1);
# insert cost: 1
print(f"insert result: {insert_result};\ninsert cost: {insert_result.cost}")

hello_milvus.flush()
print(f"Number of entities in Milvus: {hello_milvus.num_entities}")  # check the num_entities

################################################################################
# 4. create index
# We are going to create an IVF_FLAT index for the collection.
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

hello_milvus.create_index("embeddings", index)

################################################################################
# 5. search, query, and hybrid search
# After data were inserted into Milvus and indexed, you can perform:
# - search based on vector similarity
# - query based on scalar filtering(boolean, int, etc.)
# - hybrid search based on vector similarity and scalar filtering.
#

# Before conducting a search or a query, you need to load the data of the collection into memory.
print(fmt.format("Start loading"))
hello_milvus.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
print(fmt.format("Start searching based on vector similarity"))
# Use the last two inserted embedding rows as the query vectors.
vectors_to_search = entities[-1][-2:]
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10},
}

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["random"])
end_time = time.time()

# Search results expose the cost as the `cost` attribute.
# OUTPUT:
# search result: data: ['["id: 8, distance: 0.0, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 0.49515748023986816, entity: {\'random\': 0.6378742006852851}", "id: 2, distance: 0.5305156707763672, entity: {\'random\': 0.1321158395732429}"]', '["id: 9, distance: 0.0, entity: {\'random\': 0.4494463384561439}", "id: 8, distance: 0.558194100856781, entity: {\'random\': 0.9007387227368949}", "id: 2, distance: 0.7718868255615234, entity: {\'random\': 0.1321158395732429}"]'], cost: 21;
# search cost: 21
print(f"search result: {result};\nsearch cost: {result.cost}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`"))

start_time = time.time()
result = hello_milvus.query(expr="random > 0.5", output_fields=["random", "embeddings"])
end_time = time.time()

# Query results carry the cost in the `extra` dict under the "cost" key.
# OUTPUT:
# query result: data: ["{'random': 0.6378742006852851, 'embeddings': [0.18477614, 0.42930314, 0.40345728, 0.3957196, 0.6963897, 0.24356908, 0.42512414, 0.5724385], 'pk': '0'}", "{'random': 0.744296470467782, 'embeddings': [0.8349225, 0.6614872, 0.98359716, 0.15854438, 0.30939594, 0.23553558, 0.1950739, 0.80361205], 'pk': '4'}", "{'random': 0.6025374094941409, 'embeddings': [0.36677808, 0.218786, 0.25240582, 0.82230526, 0.21011819, 0.16813536, 0.8129038, 0.74800706], 'pk': '7'}", "{'random': 0.9007387227368949, 'embeddings': [0.27464902, 0.07500089, 0.57728964, 0.6654878, 0.8698446, 0.3814792, 0.8825416, 0.58730817], 'pk': '8'}"], extra_info: {'cost': '21'};
# query cost: 21
print(f"query result: {result};\nquery cost: {result.extra['cost']}")
print(search_latency_fmt.format(end_time - start_time))


# -----------------------------------------------------------------------------
# hybrid search
print(fmt.format("Start hybrid searching with `random > 0.5`"))

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"])
end_time = time.time()

# OUTPUT:
# search result: data: ['["id: 8, distance: 0.0, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 0.49515748023986816, entity: {\'random\': 0.6378742006852851}", "id: 7, distance: 0.670731246471405, entity: {\'random\': 0.6025374094941409}"]', '["id: 8, distance: 0.558194100856781, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 1.0780366659164429, entity: {\'random\': 0.6378742006852851}", "id: 7, distance: 1.1083570718765259, entity: {\'random\': 0.6025374094941409}"]'], cost: 21;
# search cost: 21
print(f"search result: {result};\nsearch cost: {result.cost}")
print(search_latency_fmt.format(end_time - start_time))

###############################################################################
# 6. delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys

# The expression targets the first two primary keys returned by the insert.
expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")

delete_result = hello_milvus.delete(expr)
# OUTPUT:
# delete result: (insert count: 0, delete count: 2, upsert count: 0, timestamp: 0, success count: 0, err count: 0, cost: 2);
# delete cost: 2
print(f"delete result: {delete_result};\ndelete cost: {delete_result.cost}")

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


###############################################################################
# 7. drop collection
# Finally, drop the collection named by `collection_name` ("hello_cost").
print(fmt.format(f"Drop collection `{collection_name}`"))
utility.drop_collection(collection_name)
# simple_cost.py walks through insert / query / upsert / delete / search with
# MilvusClient and prints the `cost` value returned alongside each result.
import os
import time
import numpy as np
from pymilvus import (
    MilvusClient,
)

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_client_cost"

# SECURITY: the endpoint and token are read from the environment instead of being
# hard-coded, so that no credential is committed with the example. Any token that
# was previously committed in this file must be treated as leaked and rotated.
#   MILVUS_URI   - server address, defaults to a local instance
#   MILVUS_TOKEN - e.g. "user:password"; empty means no authentication
milvus_client = MilvusClient(
    uri=os.environ.get("MILVUS_URI", "http://localhost:19530"),
    token=os.environ.get("MILVUS_TOKEN", ""),
)

# Start from a clean slate: drop a leftover collection from an earlier run.
has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)
milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2")

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

# Fixed seed so repeated runs insert the same vectors.
rng = np.random.default_rng(seed=19530)
rows = [
    {"id": 1, "vector": rng.random((1, dim))[0], "a": 100},
    {"id": 2, "vector": rng.random((1, dim))[0], "b": 200},
    {"id": 3, "vector": rng.random((1, dim))[0], "c": 300},
    {"id": 4, "vector": rng.random((1, dim))[0], "d": 400},
    {"id": 5, "vector": rng.random((1, dim))[0], "e": 500},
    {"id": 6, "vector": rng.random((1, dim))[0], "f": 600},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
print(fmt.format("Inserting entities done"))
# Mutation calls (insert / upsert / delete) return a dict with a "cost" key.
# OUTPUT:
# insert result: {'insert_count': 6, 'ids': [1, 2, 3, 4, 5, 6], 'cost': '1'};
# insert cost: 1
print(f"insert result: {insert_result};\ninsert cost: {insert_result['cost']}")

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(collection_name, ids=[2])
# Query and search results carry the cost in their `extra` dict.
# OUTPUT:
# query result: data: ["{'id': 2, 'vector': [0.9007387, 0.44944635, 0.18477614, 0.42930314, 0.40345728, 0.3957196, 0.6963897, 0.24356908], 'b': 200}"], extra_info: {'cost': '21'}
# query cost: 21
print(f"query result: {query_results}\nquery cost: {query_results.extra['cost']}")

upsert_ret = milvus_client.upsert(collection_name, {"id": 2 , "vector": rng.random((1, dim))[0], "g": 100})
# OUTPUT:
# upsert result: {'upsert_count': 1, 'cost': '2'}
# upsert cost: 2
print(f"upsert result: {upsert_ret}\nupsert cost: {upsert_ret['cost']}")

# Query the same primary key again to show the row after the upsert.
print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(collection_name, ids=[2])
print(f"query result: {query_results}\nquery cost: {query_results.extra['cost']}")

print(f"start to delete by specifying filter in collection {collection_name}")
delete_result = milvus_client.delete(collection_name, ids=[6])
# OUTPUT:
# delete result: {'delete_count': 1, 'cost': '1'}
# delete cost: 1
print(f"delete result: {delete_result}\ndelete cost: {delete_result['cost']}")

# Re-seed so the query vector equals the first vector generated for the rows above.
rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format(f"Start search with retrieve serveral fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"])
print(f"search result: {result}\nsearch cost: {result.extra['cost']}")

milvus_client.drop_collection(collection_name)
@property
def cost(self):
    """Return the value held in ``self._cost``.

    Per the original note, the unit of this cost is vcu, similar to a token.
    """
    return self._cost

def __str__(self):
    """Render the counters, timestamp and cost as one parenthesised line.

    The output is ``(label: value, label: value, ...)`` with the labels in the
    fixed order listed below.
    """
    labelled = (
        ("insert count", self._insert_cnt),
        ("delete count", self._delete_cnt),
        ("upsert count", self._upsert_cnt),
        ("timestamp", self._timestamp),
        ("success count", self.succ_count),
        ("err count", self.err_count),
        ("cost", self._cost),
    )
    joined = ", ".join(f"{label}: {value}" for label, value in labelled)
    return f"({joined})"
class SearchFuture(Future):
    """Future for search requests."""

    def on_response(self, response: milvus_pb2.SearchResults):
        """Validate the response status and wrap the results.

        The status is first passed to ``check_status`` and then handed to
        ``SearchResult`` together with ``response.results``.
        """
        status = response.status
        check_status(status)
        return SearchResult(response.results, status=status)
class ExtraList(list):
    """
    A list that can hold extra information.

    Attributes:
        extra (dict): The extra information of the list.

    Example:
        ExtraList([1, 2, 3], extra={"total": 3})
    """

    def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # A missing or empty `extra` is normalised to a fresh empty dict.
        self.extra = extra or {}

    def __str__(self) -> str:
        """Only print at most 10 query results"""
        # Show the ellipsis only when items were actually cut off; testing the
        # bare length would print "..." for every non-empty list.
        return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, extra_info: {self.extra}"

    __repr__ = __str__


def get_cost_from_status(status: Optional["common_pb2.Status"] = None):
    """Extract the cost from a response status.

    Args:
        status: object whose ``extra_info`` mapping may hold a "report_value" entry.

    Returns:
        int: the "report_value" entry converted to int, or 0 when the status,
        its ``extra_info`` or the entry itself is missing or empty.
    """
    if not status or not status.extra_info:
        return 0
    # `.get(...) or 0` also covers a non-empty mapping that lacks the entry.
    return int(status.extra_info.get("report_value") or 0)


def get_cost_extra(status: Optional["common_pb2.Status"] = None):
    """Build the ``{"cost": <int>}`` extra dict from a response status."""
    return {"cost": get_cost_from_status(status)}


# Construct the extra dict; the cost unit is vcu, similar to a token.
def construct_cost_extra(cost: int):
    """Wrap an already extracted cost in the ``{"cost": cost}`` extra dict."""
    return {"cost": cost}
@property
def cost(self):
    """Return the cost of the wrapped mutation result, or 0 when none is set.

    Per the original note, the unit of this cost is vcu, similar to a token.
    """
    wrapped = self._mr
    if not wrapped:
        return 0
    return wrapped.cost