From 65c4fea203d4c9d27728c101b3736d42335d5357 Mon Sep 17 00:00:00 2001
From: Buqian Zheng
Date: Mon, 9 Sep 2024 20:36:47 +0800
Subject: [PATCH] add function impl for pymilvus, specifically BM25 function

Signed-off-by: Buqian Zheng
---
 examples/hello_bm25.py         | 217 +++++++++++++++++++++++++++++
 examples/hello_hybrid_bm25.py  | 179 +++++++++++++++++++++++++++
 examples/milvus_client/bm25.py |  90 ++++++++++++++
 pymilvus/__init__.py           |   5 +-
 pymilvus/client/abstract.py    |  57 ++++++++-
 pymilvus/client/check.py       |   2 +-
 pymilvus/client/prepare.py     |  53 ++++++--
 pymilvus/client/types.py       |   4 +
 pymilvus/exceptions.py         |  18 +++
 pymilvus/orm/constants.py      |   2 +
 pymilvus/orm/prepare.py        |  13 +-
 pymilvus/orm/schema.py         | 181 +++++++++++++++++++++++++--
 pymilvus/orm/types.py          |   2 +-
 13 files changed, 790 insertions(+), 33 deletions(-)
 create mode 100644 examples/hello_bm25.py
 create mode 100644 examples/hello_hybrid_bm25.py
 create mode 100644 examples/milvus_client/bm25.py

diff --git a/examples/hello_bm25.py b/examples/hello_bm25.py
new file mode 100644
index 000000000..00a14106b
--- /dev/null
+++ b/examples/hello_bm25.py
@@ -0,0 +1,217 @@
+# hello_bm25.py demonstrates how to insert only raw text data into Milvus and
+# perform sparse vector based ANN search on it using the BM25 algorithm.
+# 1. connect to Milvus
+# 2. create collection
+# 3. insert data
+# 4. create index
+# 5. search, query, and filtering search on entities
+# 6. delete entities by PK
+# 7. upsert entities by PK
+# 8. drop collection
+import time
+
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema, CollectionSchema, Function, DataType, FunctionType,
+    Collection,
+)
+
+fmt = "\n=== {:30} ===\n"
+search_latency_fmt = "search latency = {:.4f}s"
+
+#################################################################################
+# 1. connect to Milvus
+# Add a new connection alias `default` for Milvus server in `localhost:19530`
+print(fmt.format("start connecting to Milvus"))
+connections.connect("default", host="localhost", port="19530")
+
+has = utility.has_collection("hello_bm25")
+print(f"Does collection hello_bm25 exist in Milvus: {has}")
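+# (Illustrative note) connections.connect above could equivalently take a
+# single URI instead of separate host/port:
+# connections.connect("default", uri="http://localhost:19530")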
+#################################################################################
+# 2. create collection
+# We're going to create a collection with 3 fields and a function. The first two
+# fields hold user-provided data; the third is populated by the function.
+# +-+------------+------------+------------------+------------------------------+
+# | | field name | field type | other attributes |       field description      |
+# +-+------------+------------+------------------+------------------------------+
+# |1| "id"       |   INT64    | is_primary=True  |        "primary field"       |
+# | |            |            | auto_id=True     |                              |
+# +-+------------+------------+------------------+------------------------------+
+# |2| "document" |  VarChar   |                  |     "raw text document"      |
+# +-+------------+------------+------------------+------------------------------+
+#
+# Function 'bm25' converts the raw text in the 'document' field to a sparse
+# vector representation and stores it in the 'sparse' field.
+# +-+------------+-------------------+-----------+------------------------------+
+# | | field name |     field type    | other attr|       field description      |
+# +-+------------+-------------------+-----------+------------------------------+
+# |3| "sparse"   |SPARSE_FLOAT_VECTOR|           | "BM25 sparse representation" |
+# +-+------------+-------------------+-----------+------------------------------+
+#
+fields = [
+    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+    FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
+    FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000),
+]
+
+bm25_function = Function(
+    name="bm25",
+    function_type=FunctionType.BM25,
+    inputs=["document"],
+    outputs=["sparse"],
+    params={"bm25_k1": 1.2, "bm25_b": 0.75},
+)
+
+schema = CollectionSchema(fields, "hello_bm25 demo")
+schema.add_function(bm25_function)
+
+print(fmt.format("Create collection `hello_bm25`"))
+hello_bm25 = Collection("hello_bm25", schema, consistency_level="Strong")
+
+################################################################################
+# 3. insert data
+# We are going to insert 6 rows of data into `hello_bm25`.
+# Data to be inserted must be organized in fields.
+#
+# The insert() method returns:
+# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
+# - or the existing primary key field from the entities if auto_id=False in the schema.
+
+print(fmt.format("Start inserting entities"))
+
+num_entities = 6
+
+entities = [
+    [f"This is a test document {i + hello_bm25.num_entities}" for i in range(num_entities)],
+]
+
+insert_result = hello_bm25.insert(entities)
+ids = insert_result.primary_keys
+
+time.sleep(3)
+
+hello_bm25.flush()
+print(f"Number of entities in Milvus: {hello_bm25.num_entities}")  # check the num_entities
+
+################################################################################
+# 4. create index
+# We are going to create a SPARSE_INVERTED_INDEX index for the hello_bm25 collection.
+# Because the 'sparse' field is produced by the BM25 function, the index must use
+# the BM25 metric type.
+print(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
+index = {
+    "index_type": "SPARSE_INVERTED_INDEX",
+    "metric_type": "BM25",
+    "params": {"bm25_k1": 1.2, "bm25_b": 0.75},
+}
+
+hello_bm25.create_index("sparse", index)
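+# The schema now knows about the function: the "functions" entry of to_dict()
+# and the is_function_output flag on the 'sparse' field are both introduced by
+# this patch, so a quick sanity check (illustrative, not required):
+print(f"schema functions: {schema.to_dict()['functions']}")
+print(f"function output fields: {[f.name for f in schema.fields if f.is_function_output]}")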
+################################################################################
+# 5. search, query, and scalar filtering search
+# After data is inserted into Milvus and indexed, you can perform:
+# - BM25 text relevance search using sparse vector ANN search
+# - query based on scalar filtering (boolean, int, etc.)
+# - scalar filtering search.
+#
+
+# Before conducting a search or a query, you need to load the data in
+# `hello_bm25` into memory.
+print(fmt.format("Start loading"))
+hello_bm25.load()
+
+# -----------------------------------------------------------------------------
+print(fmt.format("Start searching based on BM25 text relevance using sparse vector ANN search"))
+texts_to_search = entities[-1][-2:]
+print(fmt.format(f"texts_to_search: {texts_to_search}"))
+search_params = {
+    "metric_type": "BM25",
+    "params": {},
+}
+
+start_time = time.time()
+result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, output_fields=["document"], consistency_level="Strong")
+end_time = time.time()
+
+for hits, text in zip(result, texts_to_search):
+    print(f"result of text: {text}")
+    for hit in hits:
+        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
+print(search_latency_fmt.format(end_time - start_time))
+
+# -----------------------------------------------------------------------------
+# query based on scalar filtering (boolean, int, etc.)
+filter_id = ids[num_entities // 2 - 1]
+print(fmt.format(f"Start querying with `id > {filter_id}`"))
+
+start_time = time.time()
+result = hello_bm25.query(expr=f"id > {filter_id}", output_fields=["document"])
+end_time = time.time()
+
+print(f"query result:\n-{result[0]}")
+print(search_latency_fmt.format(end_time - start_time))
+
+# -----------------------------------------------------------------------------
+# pagination
+r1 = hello_bm25.query(expr=f"id > {filter_id}", limit=3, output_fields=["document"])
+r2 = hello_bm25.query(expr=f"id > {filter_id}", offset=1, limit=2, output_fields=["document"])
+print(f"query pagination(limit=3):\n\t{r1}")
+print(f"query pagination(offset=1, limit=2):\n\t{r2}")
+
+
+# -----------------------------------------------------------------------------
+# scalar filtering search
+print(fmt.format(f"Start filtered searching with `id > {filter_id}`"))
+
+start_time = time.time()
+result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, expr=f"id > {filter_id}", output_fields=["document"])
+end_time = time.time()
+
+for hits, text in zip(result, texts_to_search):
+    print(f"result of text: {text}")
+    for hit in hits:
+        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
+print(search_latency_fmt.format(end_time - start_time))
+
+###############################################################################
+# 6. delete entities by PK
+# You can delete entities by their PK values using boolean expressions.
+
+expr = f'id in [{ids[0]}, {ids[1]}]'
+print(fmt.format(f"Start deleting with expr `{expr}`"))
+
+result = hello_bm25.query(expr=expr, output_fields=["document"])
+print(f"query before delete by expr=`{expr}` -> result: \n- {result[0]}\n- {result[1]}\n")
+
+hello_bm25.delete(expr)
+
+result = hello_bm25.query(expr=expr, output_fields=["document"])
+print(f"query after delete by expr=`{expr}` -> result: {result}\n")
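+# As an extra sanity check (illustrative; count(*) queries are a standard
+# Milvus feature, independent of this patch), verify the remaining row count:
+count_res = hello_bm25.query(expr="", output_fields=["count(*)"])
+print(f"entity count after delete: {count_res[0]}")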
+###############################################################################
+# 7. upsert by PK
+# You can upsert data to replace an existing entity with the same PK.
+target_id = ids[2]
+print(fmt.format(f"Start upsert operation for id {target_id}"))
+
+# Query before upsert
+result_before = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
+print(f"Query before upsert (id={target_id}):\n{result_before}")
+
+# Prepare data for upsert
+upsert_data = [
+    [target_id],
+    ["This is an upserted document for testing purposes."]
+]
+
+# Perform upsert operation
+hello_bm25.upsert(upsert_data)
+
+# Query after upsert
+result_after = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
+print(f"Query after upsert (id={target_id}):\n{result_after}")
+
+
+###############################################################################
+# 8. drop collection
+# Finally, drop the hello_bm25 collection
+print(fmt.format("Drop collection `hello_bm25`"))
+utility.drop_collection("hello_bm25")
diff --git a/examples/hello_hybrid_bm25.py b/examples/hello_hybrid_bm25.py
new file mode 100644
index 000000000..b8deba164
--- /dev/null
+++ b/examples/hello_hybrid_bm25.py
@@ -0,0 +1,179 @@
+# A demo showing hybrid search with Milvus, combining dense vector based
+# semantic search and BM25 based full text search.
+#
+# You can optionally choose to use the BGE-M3 model to embed the text as dense
+# vectors, or simply use randomly generated vectors as an example.
+#
+# You can also use the BGE CrossEncoder model to rerank the search results.
+#
+# Note that the full text search feature is only available in Milvus 2.5.0 or
+# later. Make sure you follow https://milvus.io/docs/install_standalone-docker.md
+# to set up the latest version of Milvus in your local environment.
+
+# To connect to the Milvus server, you need the Python client library pymilvus.
+# To use the BGE-M3 model, you need to install the optional `model` module of
+# pymilvus. You can get both by simply running the following commands:
+#
+# pip install pymilvus
+# pip install pymilvus[model]
+
+# If true, use BGE-M3 model to generate dense vectors.
+# If false, use random numbers to compose dense vectors.
+use_bge_m3 = False
+# If true, the search result will be reranked using BGE CrossEncoder model.
+use_reranker = False
+
+# The overall steps are as follows:
+# 1. embed the text as dense vectors (sparse vectors are generated inside
+#    Milvus by the BM25 function)
+# 2. setup a Milvus collection to store the dense and sparse vectors
+# 3. insert the data to Milvus
+# 4. search and inspect the result!
+import random
+import string
+import numpy as np
+
+from pymilvus import (
+    utility,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+    AnnSearchRequest,
+    RRFRanker,
+    connections,
+    Function,
+    FunctionType,
+)
+
+# 1. prepare a small corpus to search
+docs = [
+    "Artificial intelligence was founded as an academic discipline in 1956.",
+    "Alan Turing was the first person to conduct substantial research in AI.",
+    "Born in Maida Vale, London, Turing was raised in southern England.",
+]
+# add some randomly generated texts
+docs.extend(
+    [
+        " ".join(
+            "".join(random.choice(string.ascii_lowercase) for _ in range(random.randint(1, 8)))
+            for _ in range(10)
+        )
+        for _ in range(1000)
+    ]
+)
+query = "Who started AI research?"
+
+
+def random_embedding(texts):
+    rng = np.random.default_rng()
+    return {
+        "dense": rng.random((len(texts), 768)),
+    }
+
+
+dense_dim = 768
+ef = random_embedding
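+# The only contract the rest of this script relies on is that ef(texts)
+# returns a dict whose "dense" entry has shape (len(texts), dense_dim).
+# A quick illustrative check:
+assert random_embedding(["a", "b"])["dense"].shape == (2, 768)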
+if use_bge_m3:
+    # The BGE-M3 model is included in the optional `model` module of pymilvus;
+    # to install it, simply run "pip install pymilvus[model]".
+    from pymilvus.model.hybrid import BGEM3EmbeddingFunction
+
+    ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
+    dense_dim = ef.dim["dense"]
+
+docs_embeddings = ef(docs)
+query_embeddings = ef([query])
+
+# 2. setup Milvus collection and index
+connections.connect("default", host="localhost", port="19530")
+
+# Specify the data schema for the new Collection.
+fields = [
+    # Use auto generated id as primary key
+    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
+    # Store the original text so it can be retrieved based on semantic distance
+    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512),
+    # We need a sparse vector field to perform full text search with BM25,
+    # but you don't need to provide data for it when inserting data.
+    FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
+    FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim),
+]
+functions = [
+    Function(
+        name="bm25",
+        function_type=FunctionType.BM25,
+        inputs=["text"],
+        outputs=["sparse_vector"],
+        params={"bm25_k1": 1.2, "bm25_b": 0.75},
+    )
+]
+schema = CollectionSchema(fields, "", functions=functions)
+col_name = "hybrid_bm25_demo"
+# Now we can create the new collection with the above name and schema.
+col = Collection(col_name, schema, consistency_level="Strong")
+
+# We need to create indices for the vector fields. The indices will be loaded
+# into memory for efficient search. Since the sparse field is produced by the
+# BM25 function, its index must use the BM25 metric.
+sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25"}
+col.create_index("sparse_vector", sparse_index)
+dense_index = {"index_type": "FLAT", "metric_type": "IP"}
+col.create_index("dense_vector", dense_index)
+col.load()
+
+# 3. insert text and dense vector representations into the collection; the
+# sparse vectors are computed inside Milvus by the BM25 function.
+entities = [docs, docs_embeddings["dense"]]
+col.insert(entities)
+col.flush()
+
+# 4. search and inspect the result!
+k = 2  # we want to get the top 2 docs closest to the query
+
+# Prepare the search requests for both full text search and dense vector search
+full_text_search_params = {"metric_type": "BM25"}
+# provide the raw query text for full text search, using the sparse vector
+# field as the ANNS field
+full_text_search_req = AnnSearchRequest([query], "sparse_vector", full_text_search_params, limit=k)
+dense_search_params = {"metric_type": "IP"}
+dense_req = AnnSearchRequest(
+    query_embeddings["dense"], "dense_vector", dense_search_params, limit=k
+)
+
+# Search topK docs based on dense and sparse vectors and rerank with RRF.
+res = col.hybrid_search(
+    [full_text_search_req, dense_req], rerank=RRFRanker(), limit=k, output_fields=["text"]
+)
+
+# Currently Milvus supports only one query per hybrid search request, so we
+# inspect res[0] directly. A future release will accept batched hybrid search
+# queries in a single call.
+res = res[0]
+
+if use_reranker:
+    result_texts = [hit.fields["text"] for hit in res]
+    from pymilvus.model.reranker import BGERerankFunction
+
+    bge_rf = BGERerankFunction(device="cpu")
+    # rerank the results using BGE CrossEncoder model
+    results = bge_rf(query, result_texts, top_k=2)
+    for hit in results:
+        print(f"text: {hit.text} distance {hit.score}")
+else:
+    for hit in res:
+        print(f'text: {hit.fields["text"]} distance {hit.distance}')
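+# RRFRanker is one choice of fusion strategy; pymilvus also provides
+# WeightedRanker if you want explicit per-request weights. An illustrative
+# sketch (the 0.7/0.3 weights are arbitrary):
+#
+# from pymilvus import WeightedRanker
+# res = col.hybrid_search(
+#     [full_text_search_req, dense_req], rerank=WeightedRanker(0.7, 0.3),
+#     limit=k, output_fields=["text"],
+# )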
+# If you used both BGE-M3 and the reranker, you should see the following:
+# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.9306981017573297
+# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.03217001154515051
+#
+# If you used only BGE-M3, you should see the following:
+# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.032786883413791656
+# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.016129031777381897
+
+# In this simple example the reranker yields the same result as the
+# embedding-based hybrid search, but in more complex scenarios the reranker
+# can provide more accurate results.
+
+# If you used random vectors, the result will be different each time you run the script.
+
+# Drop the collection to clean up the data.
+utility.drop_collection(col_name)
diff --git a/examples/milvus_client/bm25.py b/examples/milvus_client/bm25.py
new file mode 100644
index 000000000..49301a3f4
--- /dev/null
+++ b/examples/milvus_client/bm25.py
@@ -0,0 +1,90 @@
+from pymilvus import (
+    MilvusClient,
+    Function,
+    FunctionType,
+    DataType,
+)
+
+fmt = "\n=== {:30} ===\n"
+collection_name = "doc_in_doc_out"
+milvus_client = MilvusClient("http://localhost:19530")
+
+has_collection = milvus_client.has_collection(collection_name, timeout=5)
+if has_collection:
+    milvus_client.drop_collection(collection_name)
+
+schema = milvus_client.create_schema()
+schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
+schema.add_field("document_content", DataType.VARCHAR, max_length=9000)
+schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
+
+bm25_function = Function(
+    name="bm25_fn",
+    inputs=["document_content"],
+    outputs=["sparse_vector"],
+    function_type=FunctionType.BM25,
+    params={"bm25_k1": 1.2, "bm25_b": 0.75},
+)
+schema.add_function(bm25_function)
+
+index_params = milvus_client.prepare_index_params()
+index_params.add_index(
+    field_name="sparse_vector",
+    index_name="sparse_inverted_index",
+    index_type="SPARSE_INVERTED_INDEX",
+    metric_type="BM25",
+    params={"bm25_k1": 1.2, "bm25_b": 0.75},
+)
+
+ret = milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
+print(ret)
+
+print(fmt.format(" all collections "))
+print(milvus_client.list_collections())
+
+print(fmt.format(f"schema of collection {collection_name}"))
+print(milvus_client.describe_collection(collection_name))
+
+rows = [
+    {"id": 1, "document_content": "hello world"},
+    {"id": 2, "document_content": "hello milvus"},
+    {"id": 3, "document_content": "hello zilliz"},
+]
+
+print(fmt.format("Start inserting entities"))
+insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
+print(fmt.format("Inserting entities done"))
+print(insert_result)
+
+texts_to_search = ["zilliz"]
+search_params = {
+    "metric_type": "BM25",
+    "params": {}
+}
+print(fmt.format("Start search, retrieving several output fields"))
+result = milvus_client.search(collection_name, texts_to_search, limit=3, output_fields=["document_content"], search_params=search_params)
+for hits in result:
+    for hit in hits:
+        print(f"hit: {hit}")
+
+print(fmt.format("Start query by specifying primary keys"))
+query_results = milvus_client.query(collection_name, ids=[3])
+print(query_results[0])
+
+upsert_ret = milvus_client.upsert(collection_name, {"id": 2, "document_content": "hello milvus again"})
+print(upsert_ret)
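+# A BM25 search for "again" should now hit the upserted document 2; an
+# illustrative sanity check reusing the search_params defined above:
+for hits in milvus_client.search(collection_name, ["again"], limit=1,
+                                 output_fields=["document_content"],
+                                 search_params=search_params):
+    for hit in hits:
+        print(f"post-upsert hit: {hit}")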
+print(fmt.format("Start query by specifying filtering expression"))
+query_results = milvus_client.query(collection_name, filter="document_content == 'hello milvus again'")
+for ret in query_results:
+    print(ret)
+
+print(f"start to delete by specifying primary keys in collection {collection_name}")
+delete_result = milvus_client.delete(collection_name, ids=[3])
+print(delete_result)
+
+print(fmt.format("Start query by specifying filtering expression"))
+query_results = milvus_client.query(collection_name, filter="document_content == 'hello zilliz'")
+print(f"Query results after deletion: {query_results}")
+
+milvus_client.drop_collection(collection_name)
diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py
index 80ffb2f49..21b5f99ed 100644
--- a/pymilvus/__init__.py
+++ b/pymilvus/__init__.py
@@ -18,6 +18,7 @@
 from .client.types import (
     BulkInsertState,
     DataType,
+    FunctionType,
     Group,
     IndexType,
     Replica,
@@ -38,7 +39,7 @@
 from .orm.index import Index
 from .orm.partition import Partition
 from .orm.role import Role
-from .orm.schema import CollectionSchema, FieldSchema
+from .orm.schema import CollectionSchema, FieldSchema, Function
 from .orm.utility import (
     create_resource_group,
     create_user,
@@ -101,6 +102,7 @@
     "Group",
     "Shard",
     "FieldSchema",
+    "Function",
     "CollectionSchema",
     "SearchFuture",
     "MutationFuture",
@@ -121,6 +123,7 @@
     "Prepare",
     "Status",
     "DataType",
+    "FunctionType",
     "MilvusException",
     "__version__",
     "MilvusClient",
diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py
index 4fd90156f..c85f1697a 100644
--- a/pymilvus/client/abstract.py
+++ b/pymilvus/client/abstract.py
@@ -9,7 +9,7 @@
 from . import entity_helper, utils
 from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
-from .types import DataType
+from .types import DataType, FunctionType


 class FieldSchema:
@@ -28,6 +28,7 @@ def __init__(self, raw: Any):
         self.is_dynamic = False
         self.nullable = False
         self.default_value = None
+        self.is_function_output = False
         # For array field
         self.element_type = None
         self.is_clustering_key = False
@@ -48,6 +49,7 @@ def __pack(self, raw: Any):
             self.default_value = None
         self.is_dynamic = raw.is_dynamic
         self.nullable = raw.nullable
+        self.is_function_output = raw.is_function_output

         for type_param in raw.type_params:
             if type_param.key == "params":
@@ -110,9 +112,55 @@ def dict(self):
             _dict["is_primary"] = self.is_primary
         if self.is_clustering_key:
             _dict["is_clustering_key"] = True
+        if self.is_function_output:
+            _dict["is_function_output"] = True
         return _dict


+class FunctionSchema:
+    """Client-side view of a function parsed from a DescribeCollection response."""
+
+    def __init__(self, raw: Any):
+        self._raw = raw
+
+        self.name = None
+        self.description = None
+        self.type = None
+        self.params = {}
+        self.input_field_names = []
+        self.input_field_ids = []
+        self.output_field_names = []
+        self.output_field_ids = []
+        self.id = 0
+
+        self.__pack(self._raw)
+
+    def __pack(self, raw: Any):
+        self.name = raw.name
+        self.description = raw.description
+        self.id = raw.id
+        self.type = FunctionType(raw.type)
+        self.params = {}
+        for param in raw.params:
+            self.params[param.key] = param.value
+        self.input_field_names = raw.input_field_names
+        self.input_field_ids = raw.input_field_ids
+        self.output_field_names = raw.output_field_names
+        self.output_field_ids = raw.output_field_ids
+
+    def dict(self):
+        return {
+            "name": self.name,
+            "id": self.id,
+            "description": self.description,
+            "type": self.type,
+            "params": self.params,
+            "input_field_names": self.input_field_names,
+            "input_field_ids": self.input_field_ids,
+            "output_field_names": self.output_field_names,
+            "output_field_ids": self.output_field_ids,
+        }
+
+
 class CollectionSchema:
     def __init__(self, raw: Any):
         self._raw = raw
@@ -121,6 +169,7 @@ def __init__(self, raw: Any):
         self.description = None
         self.params = {}
         self.fields = []
+        self.functions = []
         self.statistics = {}
         self.auto_id = False  # auto_id is not in collection level any more later
         self.aliases = []
@@ -158,6 +207,11 @@ def __pack(self, raw: Any):

         self.fields = [FieldSchema(f) for f in raw.schema.fields]

+        self.functions = [FunctionSchema(f) for f in raw.schema.functions]
+        function_output_field_names = [f for fn in self.functions for f in fn.output_field_names]
+        for field in self.fields:
+            if field.name in function_output_field_names:
+                field.is_function_output = True
+
         # for s in raw.statistics:

         for p in raw.properties:
@@ -183,6 +237,7 @@ def dict(self):
             "num_shards": self.num_shards,
             "description": self.description,
             "fields": [f.dict() for f in self.fields],
+            "functions": [f.dict() for f in self.functions],
             "aliases": self.aliases,
             "collection_id": self.collection_id,
             "consistency_level": self.consistency_level,
diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py
index efbe9ca49..ccb46843a 100644
--- a/pymilvus/client/check.py
+++ b/pymilvus/client/check.py
@@ -182,7 +182,7 @@ def is_legal_search_data(data: Any) -> bool:
     if not isinstance(data, (list, np.ndarray)):
         return False

-    return all(isinstance(vector, (list, bytes, np.ndarray)) for vector in data)
+    return all(isinstance(vector, (list, bytes, np.ndarray, str)) for vector in data)


 def is_legal_output_fields(output_fields: Any) -> bool:
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 1642e6a20..eac5aec2f 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -133,6 +133,7 @@ def get_schema_from_collection_schema(
                 is_dynamic=f.is_dynamic,
                 element_type=f.element_type,
                 is_clustering_key=f.is_clustering_key,
+                is_function_output=f.is_function_output,
             )
             for k, v in f.params.items():
                 kv_pair = common_types.KeyValuePair(
@@ -141,6 +142,20 @@
                 field_schema.type_params.append(kv_pair)
             schema.fields.append(field_schema)

+        for f in fields.functions:
+            function_schema = schema_types.FunctionSchema(
+                name=f.name,
+                description=f.description,
+                type=f.type,
+                input_field_names=f.input_field_names,
+                output_field_names=f.output_field_names,
+            )
+            for k, v in f.params.items():
+                kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
+                function_schema.params.append(kv_pair)
+            schema.functions.append(function_schema)
+
         return schema

     @staticmethod
@@ -369,6 +384,10 @@ def partition_name(cls, collection_name: str, partition_name: str):
             raise ParamError(message="partition_name must be of str type")
         return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name)

+    @staticmethod
+    def _num_input_fields(fields_info: List[Dict]):
+        # Function output fields are computed server side, so they do not count
+        # toward the fields a caller must supply.
+        return len(fields_info) - len(
+            [field for field in fields_info if field.get("is_function_output", False)]
+        )
+
    @staticmethod
    def _parse_row_request(
        request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
@@ -421,19 +440,25 @@
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

         request.fields_data.extend(
-            [fields_data[field["name"]] for field in fields_info if not field.get("auto_id", False)]
+            [
+                fields_data[field["name"]]
+                for field in fields_info
+                if not field.get("is_function_output", False) and not field.get("auto_id", False)
+            ]
         )

         if enable_dynamic:
             request.fields_data.append(d_field)

+        num_input_fields = Prepare._num_input_fields(fields_info)
+
         _, _, auto_id_loc = traverse_rows_info(fields_info, entities)
         if auto_id_loc is not None:
-            if (enable_dynamic and len(fields_data) != len(fields_info)) or (
-                not enable_dynamic and len(fields_data) + 1 != len(fields_info)
+            if (enable_dynamic and len(fields_data) != num_input_fields) or (
+                not enable_dynamic and len(fields_data) + 1 != num_input_fields
             ):
                 raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
-        elif enable_dynamic and len(fields_data) != len(fields_info) + 1:
+        elif enable_dynamic and len(fields_data) != num_input_fields + 1:
             raise ParamError(ExceptionsMessage.FieldsNumInconsistent)

         return request
@@ -504,12 +529,14 @@ def _pre_insert_batch_check(
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")

-        if auto_id_loc is None and len(entities) != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        num_input_fields = Prepare._num_input_fields(fields_info)
+
+        if auto_id_loc is None and len(entities) != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)

-        if auto_id_loc is not None and len(entities) + 1 != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        if auto_id_loc is not None and len(entities) + 1 != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)

         return location
@@ -536,8 +563,10 @@ def _pre_upsert_batch_check(
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")

-        if len(entities) != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        num_input_fields = Prepare._num_input_fields(fields_info)
+
+        if len(entities) != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)
         return location
@@ -668,6 +697,10 @@ def _prepare_placeholder_str(cls, data: Any):
             pl_type = PlaceholderType.BinaryVector
             pl_values = data  # data is already a list of bytes

+        elif isinstance(data[0], str):
+            pl_type = PlaceholderType.VARCHAR
+            pl_values = (value.encode("utf-8") for value in data)
+
         else:
             pl_type = PlaceholderType.FloatVector
             pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py
index adee2f10b..8233b79a2 100644
--- a/pymilvus/client/types.py
+++ b/pymilvus/client/types.py
@@ -107,6 +107,9 @@ class DataType(IntEnum):
     UNKNOWN = 999


+class FunctionType(IntEnum):
+    UNKNOWN = 0
+    BM25 = 1
+
+
 class RangeType(IntEnum):
     LT = 0  # less than
@@ -173,6 +176,7 @@ class PlaceholderType(IntEnum):
     FLOAT16_VECTOR = 102
     BFLOAT16_VECTOR = 103
     SparseFloatVector = 104
+    VARCHAR = 21
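+    # VARCHAR reuses DataType.VARCHAR's wire value (21). Combined with the new
+    # str branch in _prepare_placeholder_str above, it lets callers pass raw
+    # query text as search data, e.g. (illustrative sketch):
+    #   collection.search(["query text"], "sparse", {"metric_type": "BM25"}, limit=3)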


 class State(IntEnum):
diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py
index 75dc3726c..80a5bfe61 100644
--- a/pymilvus/exceptions.py
+++ b/pymilvus/exceptions.py
@@ -129,6 +129,10 @@ class FieldsTypeException(MilvusException):
     """Raise when fields is invalid"""


+class FunctionsTypeException(MilvusException):
+    """Raise when functions are invalid"""
+
+
 class FieldTypeException(MilvusException):
     """Raise when one field is invalid"""


@@ -204,8 +208,22 @@ class ExceptionsMessage:
     IndexNotExist = "Index doesn't exist."
     CollectionType = "The type of collection must be pymilvus.Collection."
     FieldsType = "The fields of schema must be type list."
+    FunctionsType = "The functions of collection must be type list."
+    FunctionIncorrectInputOutputType = "The type of function input and output must be str."
+    FunctionInvalidOutputField = "The output field must not be primary key, partition key, or clustering key."
+    FunctionDuplicateInputs = "Duplicate input field names are not allowed in function."
+    FunctionDuplicateOutputs = "Duplicate output field names are not allowed in function."
+    FunctionCommonInputOutput = "Input and output field names must be different."
+    BM25FunctionIncorrectInputOutputCount = "BM25 function must have exactly 1 input and 1 output field."
+    BM25FunctionIncorrectInputFieldType = "BM25 function input field must be VARCHAR."
+    BM25FunctionIncorrectOutputFieldType = "BM25 function output field must be SPARSE_FLOAT_VECTOR."
+    FunctionMissingInputField = "Function input field not found in collection schema."
+    FunctionMissingOutputField = "Function output field not found in collection schema."
+    UnknownFunctionType = "Unknown function type."
+    FunctionIncorrectType = "The function of schema must be of type Function."
     FieldType = "The field of schema type must be FieldSchema."
     FieldDtype = "Field dtype must be of DataType"
+    FieldNamesDuplicate = "Duplicate field names are not allowed."
     ExprType = "The type of expr must be string ,but %r is given."
     EnvConfigErr = "Environment variable %s has a wrong format, please check it: %s"
     AmbiguousIndexName = "There are multiple indexes, please specify the index_name."
diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py
index b4980204d..d6d83445d 100644
--- a/pymilvus/orm/constants.py
+++ b/pymilvus/orm/constants.py
@@ -51,3 +51,5 @@
 DEFAULT_SEARCH_EXTENSION_RATE: int = 10
 UNLIMITED: int = -1
 MAX_TRY_TIME: int = 20
+BM25_k1 = "bm25_k1"
+BM25_b = "bm25_b"
\ No newline at end of file
diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py
index 3af4bdbc1..e65401f43 100644
--- a/pymilvus/orm/prepare.py
+++ b/pymilvus/orm/prepare.py
@@ -59,12 +59,13 @@ def prepare_data(
             entities.append({"name": field.name, "type": field.dtype, "values": values})
         return entities

-    tmp_fields = copy.deepcopy(fields)
-    for i, field in enumerate(tmp_fields):
-        # TODO Goose: Checking auto_id and is_primary only, maybe different than
-        # schema.is_primary, schema.auto_id, need to check why and how schema is built.
-        if field.is_primary and field.auto_id and is_insert:
-            tmp_fields.pop(i)
+    # Skip auto-id primary keys on insert, and skip function output fields
+    # entirely: neither is supplied by the caller.
+    tmp_fields = list(
+        filter(
+            lambda field: not (field.is_primary and field.auto_id and is_insert)
+            and not field.is_function_output,
+            fields,
+        )
+    )

     vec_dtype_checker = {
         DataType.FLOAT_VECTOR: lambda ndarr: ndarr.dtype in ("float32", "float64"),
diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py
index ae8df3c97..4d8d4ef3c 100644
--- a/pymilvus/orm/schema.py
+++ b/pymilvus/orm/schema.py
@@ -24,6 +24,7 @@
     DataTypeNotSupportException,
     ExceptionsMessage,
     FieldsTypeException,
+    FunctionsTypeException,
     FieldTypeException,
     ParamError,
     PartitionKeyException,
@@ -32,9 +33,10 @@
 )
 from pymilvus.grpc_gen import schema_pb2 as schema_types

-from .constants import COMMON_TYPE_PARAMS
+from .constants import COMMON_TYPE_PARAMS, BM25_k1, BM25_b
 from .types import (
     DataType,
+    FunctionType,
     infer_dtype_by_scalar_data,
     infer_dtype_bydata,
     map_numpy_dtype_to_datatype,
@@ -85,7 +87,7 @@ def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field


 class CollectionSchema:
-    def __init__(self, fields: List, description: str = "", **kwargs):
+    def __init__(self, fields: List, description: str = "", functions: List = [], **kwargs):
         self._kwargs = copy.deepcopy(kwargs)
         self._fields = []
         self._description = description
@@ -95,10 +97,22 @@ def __init__(self, fields: List, description: str = "", **kwargs):
         self._partition_key_field = None
         self._clustering_key_field = None

+        if not isinstance(functions, list):
+            raise FunctionsTypeException(message=ExceptionsMessage.FunctionsType)
+        for function in functions:
+            if not isinstance(function, Function):
+                raise SchemaNotReadyException(message=ExceptionsMessage.FunctionIncorrectType)
+        self._functions = [copy.deepcopy(function) for function in functions]
+
         if not isinstance(fields, list):
             raise FieldsTypeException(message=ExceptionsMessage.FieldsType)
+        for field in fields:
+            if not isinstance(field, FieldSchema):
+                raise FieldTypeException(message=ExceptionsMessage.FieldType)
         self._fields = [copy.deepcopy(field) for field in fields]

+        self._mark_output_fields()
+
         self._check_kwargs()
         if kwargs.get("check_fields", True):
             self._check_fields()
@@ -114,10 +128,6 @@ def _check_kwargs(self):
         if clustering_key_field_name is not None and not isinstance(clustering_key_field_name, str):
             raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyFieldType)

-        for field in self._fields:
-            if not isinstance(field, FieldSchema):
-                raise FieldTypeException(message=ExceptionsMessage.FieldType)
-
         if "auto_id" in self._kwargs and not isinstance(self._kwargs["auto_id"], bool):
             raise AutoIDException(0, ExceptionsMessage.AutoIDType)

@@ -125,6 +135,9 @@ def _check_fields(self):
         primary_field_name = self._kwargs.get("primary_field", None)
         partition_key_field_name = self._kwargs.get("partition_key_field", None)
         clustering_key_field_name = self._kwargs.get("clustering_key_field", None)
+        field_names = [field.name for field in self._fields]
+        if len(field_names) != len(set(field_names)):
+            raise ParamError(message=ExceptionsMessage.FieldNamesDuplicate)
         for field in self._fields:
             if primary_field_name and primary_field_name == field.name:
                 field.is_primary = True
@@ -175,9 +188,34 @@ def _check_fields(self):
             if auto_id:
                 self._primary_field.auto_id = auto_id

+    def _check_functions(self):
+        for function in self._functions:
+            for output_field_name in function.output_field_names:
+                output_field = next(
+                    (field for field in self._fields if field.name == output_field_name), None
+                )
+                if output_field is None:
+                    raise ParamError(message=ExceptionsMessage.FunctionMissingOutputField)
+
+                if output_field.is_primary or output_field.is_partition_key or output_field.is_clustering_key:
+                    raise ParamError(message=ExceptionsMessage.FunctionInvalidOutputField)
+
+            for input_field_name in function.input_field_names:
+                input_field = next(
+                    (field for field in self._fields if field.name == input_field_name), None
+                )
+                if input_field is None:
+                    raise ParamError(message=ExceptionsMessage.FunctionMissingInputField)
+
+            function.verify(self)
+
+    def _mark_output_fields(self):
+        for function in self._functions:
+            for output_field_name in function.output_field_names:
+                output_field = next(
+                    (field for field in self._fields if field.name == output_field_name), None
+                )
+                if output_field is not None:
+                    output_field.is_function_output = True
+
     def _check(self):
         self._check_kwargs()
         self._check_fields()
+        self._check_functions()

     def __repr__(self) -> str:
         return str(self.to_dict())
@@ -192,9 +230,10 @@ def __eq__(self, other: object):
     @classmethod
     def construct_from_dict(cls, raw: Dict):
         fields = [FieldSchema.construct_from_dict(field_raw) for field_raw in raw["fields"]]
+        functions = [
+            Function.construct_from_dict(function_raw) for function_raw in raw.get("functions", [])
+        ]
         enable_dynamic_field = raw.get("enable_dynamic_field", False)
         return CollectionSchema(
-            fields, raw.get("description", ""), enable_dynamic_field=enable_dynamic_field
+            fields, raw.get("description", ""), functions, enable_dynamic_field=enable_dynamic_field
         )

     @property
@@ -222,6 +261,16 @@ def fields(self):
         """
         return self._fields

+    @property
+    def functions(self):
+        """
+        Returns the functions of the CollectionSchema.
+
+        :return list:
+            List of Function, return when operation is successful.
+        """
+        return self._functions
+
     @property
     def description(self):
         """
@@ -277,6 +326,7 @@ def to_dict(self):
             "auto_id": self.auto_id,
             "description": self._description,
             "fields": [s.to_dict() for s in self._fields],
+            "functions": [s.to_dict() for s in self._functions],
             "enable_dynamic_field": self.enable_dynamic_field,
         }

@@ -287,6 +337,14 @@ def verify(self):
     def add_field(self, field_name: str, datatype: DataType, **kwargs):
         field = FieldSchema(field_name, datatype, **kwargs)
         self._fields.append(field)
+        self._mark_output_fields()
+        return self
+
+    def add_function(self, function):
+        if not isinstance(function, Function):
+            raise ParamError(message=ExceptionsMessage.FunctionIncorrectType)
+        self._functions.append(function)
+        self._mark_output_fields()
         return self


@@ -333,6 +391,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
         if "mmap_enabled" in kwargs:
             self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]
         self._parse_type_params()
+        self.is_function_output = False

     def __repr__(self) -> str:
         return str(self.to_dict())
@@ -378,13 +437,17 @@ def construct_from_dict(cls, raw: Dict):
         kwargs["is_dynamic"] = raw.get("is_dynamic", False)
         kwargs["nullable"] = raw.get("nullable", False)
         kwargs["element_type"] = raw.get("element_type")
-        return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
+        is_function_output = raw.get("is_function_output", False)
+        fs = FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
+        fs.is_function_output = is_function_output
+        return fs

     def to_dict(self):
         _dict = {
             "name": self.name,
             "description": self._description,
             "type": self.dtype,
+            "is_function_output": self.is_function_output,
         }
         if self._type_params:
             _dict["params"] = copy.deepcopy(self.params)
@@ -456,6 +519,96 @@ def params(self):
     def dtype(self) -> DataType:
         return self._dtype


+class Function:
+    def __init__(
+        self,
+        name: str,
+        function_type: FunctionType,
+        inputs: List[str],
+        outputs: List[str],
+        description: str = "",
+        params: Dict = {},
+    ):
+        self._name = name
+        self._description = description
+        try:
+            self._type = FunctionType(function_type)
+        except ValueError as e:
+            raise ParamError(message=ExceptionsMessage.UnknownFunctionType) from e
+
+        for field_name in list(inputs) + list(outputs):
+            if not isinstance(field_name, str):
+                raise ParamError(message=ExceptionsMessage.FunctionIncorrectInputOutputType)
+        if len(inputs) != len(set(inputs)):
+            raise ParamError(message=ExceptionsMessage.FunctionDuplicateInputs)
+        if len(outputs) != len(set(outputs)):
+            raise ParamError(message=ExceptionsMessage.FunctionDuplicateOutputs)
+
+        if set(inputs) & set(outputs):
+            raise ParamError(message=ExceptionsMessage.FunctionCommonInputOutput)
+
+        self._input_field_names = inputs
+        self._output_field_names = outputs
+        # Copy so neither the caller's dict nor the shared default is mutated.
+        self._params = dict(params)
+        if BM25_k1 in self._params:
+            self._params[BM25_k1] = str(self._params[BM25_k1])
+        if BM25_b in self._params:
+            self._params[BM25_b] = str(self._params[BM25_b])
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def description(self):
+        return self._description
+
+    @property
+    def type(self):
+        return self._type
+
+    @property
+    def input_field_names(self):
+        return self._input_field_names
+
+    @property
+    def output_field_names(self):
+        return self._output_field_names
+
+    @property
+    def params(self):
+        return self._params
+
+    def verify(self, schema: CollectionSchema):
+        if self._type == FunctionType.BM25:
+            if len(self._input_field_names) != 1 or len(self._output_field_names) != 1:
+                raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputOutputCount)
+
+            for field in schema.fields:
+                if field.name == self._input_field_names[0]:
+                    if field.dtype != DataType.VARCHAR:
+                        raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType)
+                if field.name == self._output_field_names[0]:
+                    if field.dtype != DataType.SPARSE_FLOAT_VECTOR:
+                        raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType)
+
+        elif self._type == FunctionType.UNKNOWN:
+            raise ParamError(message=ExceptionsMessage.UnknownFunctionType)
+
+    @classmethod
+    def construct_from_dict(cls, raw: Dict):
+        return Function(
+            raw["name"],
+            raw["type"],
+            raw["input_field_names"],
+            raw["output_field_names"],
+            raw["description"],
+            raw["params"],
+        )
+
+    def __repr__(self) -> str:
+        return str(self.to_dict())
+
+    def to_dict(self):
+        return {
+            "name": self._name,
+            "description": self._description,
+            "type": self._type,
+            "input_field_names": self._input_field_names,
+            "output_field_names": self._output_field_names,
+            "params": self._params,
+        }
+
+    def __eq__(self, value: object) -> bool:
+        if not isinstance(value, Function):
+            return False
+        return self.to_dict() == value.to_dict()
+

 def is_valid_insert_data(data: Union[pd.DataFrame, list, dict]) -> bool:
     """DataFrame, list, dict are valid insert data"""
@@ -502,7 +655,7 @@ def _check_insert_data(data: Union[List[List], pd.DataFrame]):


 def _check_data_schema_cnt(fields: List, data: Union[List[List], pd.DataFrame]):
-    field_cnt = len(fields)
+    field_cnt = len([f for f in fields if not f.is_function_output])
     is_dataframe = isinstance(data, pd.DataFrame)
     data_cnt = len(data.columns) if is_dataframe else len(data)
     if field_cnt != data_cnt:
@@ -535,10 +688,12 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.Dat
             columns.remove(schema.primary_field)
         data = data[columns]

-    tmp_fields = copy.deepcopy(schema.fields)
-    for i, field in enumerate(tmp_fields):
-        if field.is_primary and field.auto_id:
-            tmp_fields.pop(i)
+    tmp_fields = list(
+        filter(
+            lambda field: not (field.is_primary and field.auto_id) and not field.is_function_output,
+            schema.fields,
+        )
+    )

     _check_data_schema_cnt(tmp_fields, data)
     _check_insert_data(data)
diff --git a/pymilvus/orm/types.py b/pymilvus/orm/types.py
index eebbb50bc..27fd03de2 100644
--- a/pymilvus/orm/types.py
+++ b/pymilvus/orm/types.py
@@ -21,7 +21,7 @@
     is_scalar,
 )

-from pymilvus.client.types import DataType
+from pymilvus.client.types import DataType, FunctionType

 dtype_str_map = {
     "string": DataType.VARCHAR,