From 4ed5d3e4d89c5989652c5cbeaaedc1ac1533f6f7 Mon Sep 17 00:00:00 2001
From: Buqian Zheng
Date: Fri, 11 Oct 2024 14:41:14 +0800
Subject: [PATCH] support new Function feature (#2257)

currently only BM25 Function is supported.

issue: https://github.com/milvus-io/milvus/issues/35853

Signed-off-by: Buqian Zheng
---
 examples/hello_bm25.py         | 215 ++++++++++++++++++++++++++++++
 examples/hello_hybrid_bm25.py  | 178 +++++++++++++++++++++++++
 examples/milvus_client/bm25.py |  89 +++++++++++++
 pymilvus/__init__.py           |   5 +-
 pymilvus/client/abstract.py    |  56 +++++++-
 pymilvus/client/check.py       |   2 +-
 pymilvus/client/prepare.py     |  55 ++++++--
 pymilvus/client/types.py       |   6 +
 pymilvus/exceptions.py         |  22 ++++
 pymilvus/orm/constants.py      |   2 +
 pymilvus/orm/prepare.py        |  16 ++-
 pymilvus/orm/schema.py         | 234 +++++++++++++++++++++++++++++++--
 12 files changed, 847 insertions(+), 33 deletions(-)
 create mode 100644 examples/hello_bm25.py
 create mode 100644 examples/hello_hybrid_bm25.py
 create mode 100644 examples/milvus_client/bm25.py

diff --git a/examples/hello_bm25.py b/examples/hello_bm25.py
new file mode 100644
index 000000000..37d3139eb
--- /dev/null
+++ b/examples/hello_bm25.py
@@ -0,0 +1,215 @@
+# hello_bm25.py demonstrates how to insert only raw text data into Milvus and
+# perform sparse vector ANN search on it using the BM25 algorithm.
+# 1. connect to Milvus
+# 2. create collection
+# 3. insert data
+# 4. create index
+# 5. search, query, and filtering search on entities
+# 6. delete entities by PK
+# 7. upsert entities by PK
+# 8. drop collection
+import time
+
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema, CollectionSchema, Function, DataType, FunctionType,
+    Collection,
+)
+
+fmt = "\n=== {:30} ===\n"
+search_latency_fmt = "search latency = {:.4f}s"
+
+#################################################################################
+# 1. connect to Milvus
+# Add a new connection alias `default` for Milvus server in `localhost:19530`
+print(fmt.format("start connecting to Milvus"))
+connections.connect("default", host="localhost", port="19530")
+
+has = utility.has_collection("hello_bm25")
+print(f"Does collection hello_bm25 exist in Milvus: {has}")
+
+#################################################################################
+# 2. create collection
+# We're going to create a collection with 2 explicit data fields and a function.
+# +-+------------+------------+------------------+------------------------------+
+# | | field name | field type | other attributes |       field description      |
+# +-+------------+------------+------------------+------------------------------+
+# |1|    "id"    |   INT64    | is_primary=True  |       "primary field"        |
+# | |            |            |   auto_id=True   |                              |
+# +-+------------+------------+------------------+------------------------------+
+# |2| "document" |  VarChar   |                  |     "raw text document"      |
+# +-+------------+------------+------------------+------------------------------+
+#
+# Function 'bm25' is used to convert the raw text document into a sparse vector
+# representation and store it in the 'sparse' field.
+# +-+------------+-------------------+-----------+------------------------------+
+# | | field name |     field type    | other attr|       field description      |
+# +-+------------+-------------------+-----------+------------------------------+
+# |3|  "sparse"  |SPARSE_FLOAT_VECTOR|           |                              |
+# +-+------------+-------------------+-----------+------------------------------+
+#
+fields = [
+    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+    FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
+    FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000, enable_tokenizer=True),
+]
+
+bm25_function = Function(
+    name="bm25",
+    function_type=FunctionType.BM25,
+    input_field_names=["document"],
+    output_field_names="sparse",
+)
+
+schema = CollectionSchema(fields, "hello_bm25 demo")
+schema.add_function(bm25_function)
+
+print(fmt.format("Create collection `hello_bm25`"))
+hello_bm25 = Collection("hello_bm25", schema, consistency_level="Strong")
+
+################################################################################
+# 3. insert data
+# We are going to insert 6 rows of data into `hello_bm25`.
+# Data to be inserted must be organized in fields.
+#
+# The insert() method returns:
+# - either the primary keys automatically generated by Milvus, if auto_id=True in the schema;
+# - or the primary key field taken from the entities, if auto_id=False in the schema.
+
+print(fmt.format("Start inserting entities"))
+
+num_entities = 6
+
+entities = [
+    [f"This is a test document {i + hello_bm25.num_entities}" for i in range(num_entities)],
+]
+
+insert_result = hello_bm25.insert(entities)
+ids = insert_result.primary_keys
+
+time.sleep(3)
+
+hello_bm25.flush()
+print(f"Number of entities in Milvus: {hello_bm25.num_entities}")  # check the num_entities
+
+################################################################################
+# 4. create index
+# We are going to create an index for the hello_bm25 collection; here we simply
+# use AUTOINDEX so Milvus can pick the default parameters.
+print(fmt.format("Start Creating index AUTOINDEX"))
+index = {
+    "index_type": "AUTOINDEX",
+    "metric_type": "BM25",
+}
+
+hello_bm25.create_index("sparse", index)
+
+################################################################################
+# 5. search, query, and scalar filtering search
+# After data is inserted into Milvus and indexed, you can perform:
+# - text relevance search with BM25, via sparse vector ANN search
+# - query based on scalar filtering (boolean, int, etc.)
+# - scalar filtering search.
+#
+
+# Before conducting a search or a query, you need to load the data in `hello_bm25` into memory.
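+# Aside (a hedged sketch; assumes a recent pymilvus that exposes
+# utility.load_state): instead of blocking on load(), another process can
+# poll the load state of the collection:
+#
+#   from pymilvus import utility
+#   print(utility.load_state("hello_bm25"))  # reports Loaded once ready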
+print(fmt.format("Start loading"))
+hello_bm25.load()
+
+# -----------------------------------------------------------------------------
+print(fmt.format("Start searching based on BM25 text relevance using sparse vector ANN search"))
+texts_to_search = entities[-1][-2:]
+print(fmt.format(f"texts_to_search: {texts_to_search}"))
+search_params = {
+    "metric_type": "BM25",
+    "params": {},
+}
+
+start_time = time.time()
+result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, output_fields=["document"], consistency_level="Strong")
+end_time = time.time()
+
+for hits, text in zip(result, texts_to_search):
+    print(f"result of text: {text}")
+    for hit in hits:
+        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
+print(search_latency_fmt.format(end_time - start_time))
+
+# -----------------------------------------------------------------------------
+# query based on scalar filtering (boolean, int, etc.)
+filter_id = ids[num_entities // 2 - 1]
+print(fmt.format(f"Start querying with `id > {filter_id}`"))
+
+start_time = time.time()
+result = hello_bm25.query(expr=f"id > {filter_id}", output_fields=["document"])
+end_time = time.time()
+
+print(f"query result:\n-{result[0]}")
+print(search_latency_fmt.format(end_time - start_time))
+
+# -----------------------------------------------------------------------------
+# pagination
+r1 = hello_bm25.query(expr=f"id > {filter_id}", limit=3, output_fields=["document"])
+r2 = hello_bm25.query(expr=f"id > {filter_id}", offset=1, limit=2, output_fields=["document"])
+print(f"query pagination (limit=3):\n\t{r1}")
+print(f"query pagination (offset=1, limit=2):\n\t{r2}")
+
+
+# -----------------------------------------------------------------------------
+# scalar filtering search
+print(fmt.format(f"Start filtered searching with `id > {filter_id}`"))
+
+start_time = time.time()
+result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, expr=f"id > {filter_id}", output_fields=["document"])
+end_time = time.time()
+
+for hits, text in zip(result, texts_to_search):
+    print(f"result of text: {text}")
+    for hit in hits:
+        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
+print(search_latency_fmt.format(end_time - start_time))
+
+###############################################################################
+# 6. delete entities by PK
+# You can delete entities by their PK values using boolean expressions.
+
+expr = f'id in [{ids[0]}, {ids[1]}]'
+print(fmt.format(f"Start deleting with expr `{expr}`"))
+
+result = hello_bm25.query(expr=expr, output_fields=["document"])
+print(f"query before delete by expr=`{expr}` -> result: \n- {result[0]}\n- {result[1]}\n")
+
+hello_bm25.delete(expr)
+
+result = hello_bm25.query(expr=expr, output_fields=["document"])
+print(f"query after delete by expr=`{expr}` -> result: {result}\n")
+
+###############################################################################
+# 7. upsert by PK
+# You can upsert data to replace existing data.
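+# Note: "sparse" is a BM25 function output field, so the upsert below provides
+# only the primary key and the raw document text; Milvus recomputes the sparse
+# vector from the new text on the server side.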
+target_id = ids[2]
+print(fmt.format(f"Start upsert operation for id {target_id}"))
+
+# Query before upsert
+result_before = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
+print(f"Query before upsert (id={target_id}):\n{result_before}")
+
+# Prepare data for upsert
+upsert_data = [
+    [target_id],
+    ["This is an upserted document for testing purposes."]
+]
+
+# Perform upsert operation
+hello_bm25.upsert(upsert_data)
+
+# Query after upsert
+result_after = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
+print(f"Query after upsert (id={target_id}):\n{result_after}")
+
+
+###############################################################################
+# 8. drop collection
+# Finally, drop the hello_bm25 collection
+print(fmt.format("Drop collection `hello_bm25`"))
+utility.drop_collection("hello_bm25")
diff --git a/examples/hello_hybrid_bm25.py b/examples/hello_hybrid_bm25.py
new file mode 100644
index 000000000..df8ce1704
--- /dev/null
+++ b/examples/hello_hybrid_bm25.py
@@ -0,0 +1,178 @@
+# A demo showing hybrid search that combines dense semantic search with BM25
+# full text search using Milvus.
+#
+# You can optionally choose to use the BGE-M3 model to embed the text as dense
+# vectors, or simply use randomly generated vectors as an example.
+#
+# You can also use the BGE CrossEncoder model to rerank the search results.
+#
+# Note that the full text search feature is only available in Milvus 2.5.0 or
+# higher version. Make sure you follow https://milvus.io/docs/install_standalone-docker.md
+# to set up the latest version of Milvus in your local environment.
+
+# To connect to the Milvus server, you need the python client library called pymilvus.
+# To use the BGE-M3 model, you need to install the optional `model` module in pymilvus.
+# You can get them by simply running the following commands:
+#
+# pip install pymilvus
+# pip install pymilvus[model]
+
+# If true, use the BGE-M3 model to generate dense vectors.
+# If false, use random numbers to compose dense vectors.
+use_bge_m3 = False
+# If true, the search result will be reranked using the BGE CrossEncoder model.
+use_reranker = False
+
+# The overall steps are as follows:
+# 1. embed the text as dense vectors (the sparse BM25 vectors are computed by Milvus)
+# 2. set up a Milvus collection to store the dense and sparse vectors
+# 3. insert the data into Milvus
+# 4. search and inspect the result!
+import random
+import string
+import numpy as np
+
+from pymilvus import (
+    utility,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+    AnnSearchRequest,
+    RRFRanker,
+    connections,
+    Function,
+    FunctionType,
+)
+
+# 1. prepare a small corpus to search
+docs = [
+    "Artificial intelligence was founded as an academic discipline in 1956.",
+    "Alan Turing was the first person to conduct substantial research in AI.",
+    "Born in Maida Vale, London, Turing was raised in southern England.",
+]
+# add some randomly generated texts
+docs.extend(
+    [
+        " ".join(
+            "".join(random.choice(string.ascii_lowercase) for _ in range(random.randint(1, 8)))
+            for _ in range(10)
+        )
+        for _ in range(1000)
+    ]
+)
+query = "Who started AI research?"
+
+
+def random_embedding(texts):
+    rng = np.random.default_rng()
+    return {
+        "dense": rng.random((len(texts), 768)),
+    }
+
+
+dense_dim = 768
+ef = random_embedding
+
+if use_bge_m3:
+    # The BGE-M3 model is included in the optional `model` module in pymilvus;
+    # to install it, simply run "pip install pymilvus[model]".
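+    # Hedged note: the embedding function call returns a dict keyed by
+    # embedding type; this demo reads only its "dense" entry, since the sparse
+    # BM25 vectors are computed server-side by the Function declared below.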
+    from pymilvus.model.hybrid import BGEM3EmbeddingFunction
+
+    ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
+    dense_dim = ef.dim["dense"]
+
+docs_embeddings = ef(docs)
+query_embeddings = ef([query])
+
+# 2. set up the Milvus collection and index
+connections.connect("default", host="localhost", port="19530")
+
+# Specify the data schema for the new Collection.
+fields = [
+    # Use auto generated id as primary key
+    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
+    # Store the original text to retrieve based on semantic distance
+    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512, enable_tokenizer=True),
+    # We need a sparse vector field to perform full text search with BM25,
+    # but you don't need to provide data for it when inserting data.
+    FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
+    FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim),
+]
+functions = [
+    Function(
+        name="bm25",
+        function_type=FunctionType.BM25,
+        input_field_names=["text"],
+        output_field_names="sparse_vector",
+    )
+]
+schema = CollectionSchema(fields, "", functions=functions)
+col_name = "hybrid_bm25_demo"
+# Now we can create the new collection with the above name and schema.
+col = Collection(col_name, schema, consistency_level="Strong")
+
+# We need to create indices for the vector fields. The indices will be loaded
+# into memory for efficient search.
+sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25"}
+col.create_index("sparse_vector", sparse_index)
+dense_index = {"index_type": "FLAT", "metric_type": "IP"}
+col.create_index("dense_vector", dense_index)
+col.load()
+
+# 3. insert the text and its dense vector representation into the collection;
+# the sparse BM25 vectors are generated server-side by the Function
+entities = [docs, docs_embeddings["dense"]]
+col.insert(entities)
+col.flush()
+
+# 4. search and inspect the result!
+k = 2  # we want to get the top 2 docs closest to the query
+
+# Prepare the search requests for both full text search and dense vector search
+full_text_search_params = {"metric_type": "BM25"}
+# Provide the raw text query for full text search, using the sparse vector
+# field as the ANNS field
+full_text_search_req = AnnSearchRequest([query], "sparse_vector", full_text_search_params, limit=k)
+dense_search_params = {"metric_type": "IP"}
+dense_req = AnnSearchRequest(
+    query_embeddings["dense"], "dense_vector", dense_search_params, limit=k
+)
+
+# Search topK docs based on dense and sparse vectors and rerank with RRF.
+res = col.hybrid_search(
+    [full_text_search_req, dense_req], rerank=RRFRanker(), limit=k, output_fields=["text"]
+)
+
+# Currently Milvus supports only one query per hybrid search request, so
+# we inspect res[0] directly. A future release will accept batched hybrid
+# search queries in the same call.
+res = res[0]
+
+if use_reranker:
+    result_texts = [hit.fields["text"] for hit in res]
+    from pymilvus.model.reranker import BGERerankFunction
+
+    bge_rf = BGERerankFunction(device="cpu")
+    # rerank the results using the BGE CrossEncoder model
+    results = bge_rf(query, result_texts, top_k=2)
+    for hit in results:
+        print(f"text: {hit.text} distance {hit.score}")
+else:
+    for hit in res:
+        print(f'text: {hit.fields["text"]} distance {hit.distance}')
+
+# If you used both BGE-M3 and the reranker, you should see the following:
+# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.9306981017573297
+# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.03217001154515051
+#
+# If you used only BGE-M3, you should see the following:
+# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.032786883413791656
+# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.016129031777381897
+
+# In this simple example the reranker yields the same result as the
+# embedding-based hybrid search, but in more complex scenarios the reranker
+# can provide more accurate results.
+
+# If you used random vectors, the result will be different each time you run the script.
+
+# Drop the collection to clean up the data.
+utility.drop_collection(col_name)
diff --git a/examples/milvus_client/bm25.py b/examples/milvus_client/bm25.py
new file mode 100644
index 000000000..ef344dffb
--- /dev/null
+++ b/examples/milvus_client/bm25.py
@@ -0,0 +1,89 @@
+from pymilvus import (
+    MilvusClient,
+    Function,
+    FunctionType,
+    DataType,
+)
+
+fmt = "\n=== {:30} ===\n"
+collection_name = "doc_in_doc_out"
+milvus_client = MilvusClient("http://localhost:19530")
+
+has_collection = milvus_client.has_collection(collection_name, timeout=5)
+if has_collection:
+    milvus_client.drop_collection(collection_name)
+
+schema = milvus_client.create_schema()
+schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
+schema.add_field("document_content", DataType.VARCHAR, max_length=9000, enable_tokenizer=True)
+schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
+
+bm25_function = Function(
+    name="bm25_fn",
+    input_field_names=["document_content"],
+    output_field_names="sparse_vector",
+    function_type=FunctionType.BM25,
+)
+schema.add_function(bm25_function)
+
+index_params = milvus_client.prepare_index_params()
+index_params.add_index(
+    field_name="sparse_vector",
+    index_name="sparse_inverted_index",
+    index_type="SPARSE_INVERTED_INDEX",
+    metric_type="BM25",
+    params={"bm25_k1": 1.2, "bm25_b": 0.75},
+)
+
+ret = milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
+print(ret)
+
+print(fmt.format(" all collections "))
+print(milvus_client.list_collections())
+
+print(fmt.format(f"schema of collection {collection_name}"))
+print(milvus_client.describe_collection(collection_name))
+
+rows = [
+    {"id": 1, "document_content": "hello world"},
+    {"id": 2, "document_content": "hello milvus"},
+    {"id": 3, "document_content": "hello zilliz"},
+]
+
+print(fmt.format("Start inserting entities"))
+insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
+print(fmt.format("Inserting entities done"))
+print(insert_result)
+
+texts_to_search = ["zilliz"]
+search_params = {
+    "metric_type": "BM25",
+    "params": {}
+}
+print(fmt.format("Start searching with several output fields"))
+result = milvus_client.search(collection_name, texts_to_search, limit=3, output_fields=["document_content"], search_params=search_params)
+for hits in result:
+    for hit in hits:
+        print(f"hit: {hit}")
+
+print(fmt.format("Start query by specifying primary keys"))
+query_results = milvus_client.query(collection_name, ids=[3])
+print(query_results[0])
+
+upsert_ret = milvus_client.upsert(collection_name, {"id": 2, "document_content": "hello milvus again"})
+print(upsert_ret)
+
+print(fmt.format("Start query by specifying filtering expression"))
+query_results = milvus_client.query(collection_name, filter="document_content == 'hello milvus again'")
+for ret in query_results:
+    print(ret)
+
+print(f"start to delete by specifying primary keys in collection {collection_name}")
+delete_result = milvus_client.delete(collection_name, ids=[3])
+print(delete_result)
+
+print(fmt.format("Start query by specifying filtering expression"))
+query_results = milvus_client.query(collection_name, filter="document_content == 'hello zilliz'")
+print(f"Query results after deletion: {query_results}")
+
+milvus_client.drop_collection(collection_name)
diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py
index 80ffb2f49..21b5f99ed 100644
--- a/pymilvus/__init__.py
+++ b/pymilvus/__init__.py
@@ -18,6 +18,7 @@
 from .client.types import (
     BulkInsertState,
     DataType,
+    FunctionType,
     Group,
     IndexType,
     Replica,
@@ -38,7 +39,7 @@
 from .orm.index import Index
 from .orm.partition import Partition
 from .orm.role import Role
-from .orm.schema import CollectionSchema, FieldSchema
+from .orm.schema import CollectionSchema, FieldSchema, Function
 from .orm.utility import (
     create_resource_group,
     create_user,
@@ -101,6 +102,7 @@
     "Group",
     "Shard",
     "FieldSchema",
+    "Function",
     "CollectionSchema",
     "SearchFuture",
     "MutationFuture",
@@ -121,6 +123,7 @@
     "Prepare",
     "Status",
     "DataType",
+    "FunctionType",
     "MilvusException",
     "__version__",
     "MilvusClient",
diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py
index 4b509d9df..57cdbdc8b 100644
--- a/pymilvus/client/abstract.py
+++ b/pymilvus/client/abstract.py
@@ -9,7 +9,7 @@
 from . import entity_helper, utils
 from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
-from .types import DataType
+from .types import DataType, FunctionType
 
 
 class FieldSchema:
@@ -28,6 +28,7 @@ def __init__(self, raw: Any):
         self.is_dynamic = False
         self.nullable = False
         self.default_value = None
+        self.is_function_output = False
         # For array field
         self.element_type = None
         self.is_clustering_key = False
@@ -48,6 +49,7 @@ def __pack(self, raw: Any):
             self.default_value = None
         self.is_dynamic = raw.is_dynamic
         self.nullable = raw.nullable
+        self.is_function_output = raw.is_function_output
 
         for type_param in raw.type_params:
             if type_param.key == "params":
@@ -114,9 +116,54 @@ def dict(self):
             _dict["is_primary"] = self.is_primary
         if self.is_clustering_key:
             _dict["is_clustering_key"] = True
+        if self.is_function_output:
+            _dict["is_function_output"] = True
         return _dict
 
 
+class FunctionSchema:
+    def __init__(self, raw: Any):
+        self._raw = raw
+
+        self.name = None
+        self.description = None
+        self.type = None
+        self.params = {}
+        self.input_field_names = []
+        self.input_field_ids = []
+        self.output_field_names = []
+        self.output_field_ids = []
+        self.id = 0
+
+        self.__pack(self._raw)
+
+    def __pack(self, raw: Any):
+        self.name = raw.name
+        self.description = raw.description
+        self.id = raw.id
+        self.type = FunctionType(raw.type)
+        self.params = {}
+        for param in raw.params:
+            self.params[param.key] = param.value
+        self.input_field_names = raw.input_field_names
+        self.input_field_ids = raw.input_field_ids
+        self.output_field_names = raw.output_field_names
+        self.output_field_ids = raw.output_field_ids
+
+    def dict(self):
+        return {
+            "name": self.name,
+            "id": self.id,
+            "description": self.description,
+            "type": self.type,
+            "params": self.params,
+            "input_field_names": self.input_field_names,
+            "input_field_ids": self.input_field_ids,
+            "output_field_names": self.output_field_names,
+            "output_field_ids": self.output_field_ids,
+        }
+
+
 class CollectionSchema:
     def __init__(self, raw: Any):
         self._raw = raw
@@ -125,6 +172,7 @@ def __init__(self, raw: Any):
         self.description = None
         self.params = {}
         self.fields = []
+        self.functions = []
         self.statistics = {}
         self.auto_id = False  # auto_id is not in collection level any more later
         self.aliases = []
@@ -162,6 +210,11 @@ def __pack(self, raw: Any):
 
         self.fields = [FieldSchema(f) for f in raw.schema.fields]
 
+        self.functions = [FunctionSchema(f) for f in raw.schema.functions]
+        function_output_field_names = [f for fn in self.functions for f in fn.output_field_names]
+        for field in self.fields:
+            if field.name in function_output_field_names:
+                field.is_function_output = True
         # for s in raw.statistics:
 
         for p in raw.properties:
@@ -187,6 +240,7 @@ def dict(self):
             "num_shards": self.num_shards,
             "description": self.description,
             "fields": [f.dict() for f in self.fields],
+            "functions": [f.dict() for f in self.functions],
             "aliases": self.aliases,
             "collection_id": self.collection_id,
             "consistency_level": self.consistency_level,
diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py
index 311ee9321..c76ab7f7a 100644
--- a/pymilvus/client/check.py
+++ b/pymilvus/client/check.py
@@ -187,7 +187,7 @@ def is_legal_search_data(data: Any) -> bool:
     if not isinstance(data, (list, np.ndarray)):
         return False
 
-    return all(isinstance(vector, (list, bytes, np.ndarray)) for vector in data)
+    return all(isinstance(vector, (list, bytes, np.ndarray, str)) for vector in data)
 
 
 def is_legal_output_fields(output_fields: Any) -> bool:
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 3de4a6867..7adf2fba5 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -134,6 +134,7 @@ def get_schema_from_collection_schema(
                 is_dynamic=f.is_dynamic,
                 element_type=f.element_type,
                 is_clustering_key=f.is_clustering_key,
+                is_function_output=f.is_function_output,
             )
             for k, v in f.params.items():
                 kv_pair = common_types.KeyValuePair(
@@ -142,6 +143,20 @@
                 field_schema.type_params.append(kv_pair)
 
             schema.fields.append(field_schema)
+
+        for f in fields.functions:
+            function_schema = schema_types.FunctionSchema(
+                name=f.name,
+                description=f.description,
+                type=f.type,
+                input_field_names=f.input_field_names,
+                output_field_names=f.output_field_names,
+            )
+            for k, v in f.params.items():
+                kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
+                function_schema.params.append(kv_pair)
+            schema.functions.append(function_schema)
+
         return schema
 
     @staticmethod
@@ -370,6 +385,12 @@ def partition_name(cls, collection_name: str, partition_name: str):
             raise ParamError(message="partition_name must be of str type")
         return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name)
 
+    @staticmethod
+    def _num_input_fields(fields_info: List[Dict]):
+        return len(fields_info) - len(
+            [field for field in fields_info if field.get("is_function_output", False)]
+        )
+
     @staticmethod
     def _parse_row_request(
         request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
@@ -434,19 +455,25 @@ def _parse_row_request(
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
 
         request.fields_data.extend(
-            [fields_data[field["name"]] for field in fields_info if not field.get("auto_id", False)]
+            [
+                fields_data[field["name"]]
+                for field in fields_info
+                if not field.get("is_function_output", False) and not field.get("auto_id", False)
+            ]
         )
 
         if enable_dynamic:
             request.fields_data.append(d_field)
 
+        num_input_fields = Prepare._num_input_fields(fields_info)
+
         _, _, auto_id_loc = traverse_rows_info(fields_info, entities)
         if auto_id_loc is not None:
-            if (enable_dynamic and len(fields_data) != len(fields_info)) or (
-                not enable_dynamic and len(fields_data) + 1 != len(fields_info)
+            if (enable_dynamic and len(fields_data) != num_input_fields) or (
+                not enable_dynamic and len(fields_data) + 1 != num_input_fields
             ):
                 raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
-        elif enable_dynamic and len(fields_data) != len(fields_info) + 1:
+        elif enable_dynamic and len(fields_data) != num_input_fields + 1:
             raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
 
         return request
@@ -600,12 +627,14 @@ def _pre_insert_batch_check(
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")
 
-        if auto_id_loc is None and len(entities) != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        num_input_fields = Prepare._num_input_fields(fields_info)
+
+        if auto_id_loc is None and len(entities) != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)
 
-        if auto_id_loc is not None and len(entities) + 1 != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        if auto_id_loc is not None and len(entities) + 1 != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)
         return location
 
@@ -632,8 +661,10 @@ def _pre_upsert_batch_check(
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")
 
-        if len(entities) != len(fields_info):
-            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+        num_input_fields = Prepare._num_input_fields(fields_info)
+
+        if len(entities) != num_input_fields:
+            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
             raise ParamError(msg)
         return location
 
@@ -764,6 +795,10 @@ def _prepare_placeholder_str(cls, data: Any):
             pl_type = PlaceholderType.BinaryVector
             pl_values = data  # data is already a list of bytes
 
+        elif isinstance(data[0], str):
+            pl_type = PlaceholderType.VARCHAR
+            pl_values = (value.encode("utf-8") for value in data)
+
         else:
             pl_type = PlaceholderType.FloatVector
             pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py
index adee2f10b..48b8aefcd 100644
--- a/pymilvus/client/types.py
+++ b/pymilvus/client/types.py
@@ -108,6 +108,11 @@ class DataType(IntEnum):
     UNKNOWN = 999
 
 
+class FunctionType(IntEnum):
+    UNKNOWN = 0
+    BM25 = 1
+
+
 class RangeType(IntEnum):
     LT = 0  # less than
     LTE = 1  # less than or equal
@@ -173,6 +178,7 @@ class PlaceholderType(IntEnum):
     FLOAT16_VECTOR = 102
     BFLOAT16_VECTOR = 103
     SparseFloatVector = 104
+    VARCHAR = 21
 
 
 class State(IntEnum):
diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py
index 50347f6a3..73051d72c 100644
--- a/pymilvus/exceptions.py
+++ b/pymilvus/exceptions.py
@@ -129,6 +129,10 @@ class FieldsTypeException(MilvusException):
     """Raise when fields is invalid"""
 
 
+class FunctionsTypeException(MilvusException):
+    """Raise when functions are invalid"""
+
+
 class FieldTypeException(MilvusException):
     """Raise when one field is invalid"""
 
@@ -204,8 +208,26 @@ class ExceptionsMessage:
     IndexNotExist = "Index doesn't exist."
     CollectionType = "The type of collection must be pymilvus.Collection."
     FieldsType = "The fields of schema must be type list."
+    FunctionsType = "The functions of collection must be type list."
+    FunctionIncorrectInputOutputType = "The type of function input and output must be str."
+    FunctionInvalidOutputField = (
+        "The output field must not be primary key, partition key, clustering key."
+    )
+    FunctionDuplicateInputs = "Duplicate input field names are not allowed in function."
+    FunctionDuplicateOutputs = "Duplicate output field names are not allowed in function."
+    FunctionCommonInputOutput = "Input and output field names must be different."
+    BM25FunctionIncorrectInputOutputCount = (
+        "BM25 function must have exactly 1 input and 1 output field."
+    )
+    BM25FunctionIncorrectInputFieldType = "BM25 function input field must be VARCHAR."
+    BM25FunctionIncorrectOutputFieldType = "BM25 function output field must be SPARSE_FLOAT_VECTOR."
+    FunctionMissingInputField = "Function input field not found in collection schema."
+    FunctionMissingOutputField = "Function output field not found in collection schema."
+    UnknownFunctionType = "Unknown function type."
+    FunctionIncorrectType = "The function of schema type must be Function."
     FieldType = "The field of schema type must be FieldSchema."
     FieldDtype = "Field dtype must be of DataType"
+    FieldNamesDuplicate = "Duplicate field names are not allowed."
     ExprType = "The type of expr must be string ,but %r is given."
     EnvConfigErr = "Environment variable %s has a wrong format, please check it: %s"
     AmbiguousIndexName = "There are multiple indexes, please specify the index_name."
diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py
index 7d903bfd0..2f9fc58ab 100644
--- a/pymilvus/orm/constants.py
+++ b/pymilvus/orm/constants.py
@@ -62,3 +62,5 @@
 MAX_TRY_TIME: int = 20
 GUARANTEE_TIMESTAMP = "guarantee_timestamp"
 ITERATOR_SESSION_CP_FILE = "iterator_cp_file"
+BM25_k1 = "bm25_k1"
+BM25_b = "bm25_b"
diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py
index 3af4bdbc1..44080b68d 100644
--- a/pymilvus/orm/prepare.py
+++ b/pymilvus/orm/prepare.py
@@ -10,7 +10,6 @@
 # or implied. See the License for the specific language governing permissions and limitations under
 # the License.
 
-import copy
 from typing import List, Tuple, Union
 
 import numpy as np
@@ -53,18 +52,21 @@ def prepare_data(
         for field in fields:
             if field.is_primary and field.auto_id and is_insert:
                 continue
+            if field.is_function_output:
+                continue
             values = []
             if field.name in list(data.columns):
                 values = list(data[field.name])
            entities.append({"name": field.name, "type": field.dtype, "values": values})
         return entities
 
-    tmp_fields = copy.deepcopy(fields)
-    for i, field in enumerate(tmp_fields):
-        # TODO Goose: Checking auto_id and is_primary only, maybe different than
-        # schema.is_primary, schema.auto_id, need to check why and how schema is built.
-        if field.is_primary and field.auto_id and is_insert:
-            tmp_fields.pop(i)
+    tmp_fields = list(
+        filter(
+            lambda field: not (field.is_primary and field.auto_id and is_insert)
+            and not field.is_function_output,
+            fields,
+        )
+    )
 
     vec_dtype_checker = {
         DataType.FLOAT_VECTOR: lambda ndarr: ndarr.dtype in ("float32", "float64"),
diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py
index 54bc8b1a5..5f33309dd 100644
--- a/pymilvus/orm/schema.py
+++ b/pymilvus/orm/schema.py
@@ -16,6 +16,7 @@
 import pandas as pd
 from pandas.api.types import is_list_like, is_scalar
 
+from pymilvus.client.types import FunctionType
 from pymilvus.exceptions import (
     AutoIDException,
     CannotInferSchemaException,
@@ -25,6 +26,7 @@
     ExceptionsMessage,
     FieldsTypeException,
     FieldTypeException,
+    FunctionsTypeException,
     ParamError,
     PartitionKeyException,
     PrimaryKeyException,
@@ -85,7 +87,9 @@ def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field):
 
 
 class CollectionSchema:
-    def __init__(self, fields: List, description: str = "", **kwargs):
+    def __init__(
+        self, fields: List, description: str = "", functions: Optional[List] = None, **kwargs
+    ):
         self._kwargs = copy.deepcopy(kwargs)
         self._fields = []
         self._description = description
@@ -95,10 +99,25 @@ def __init__(self, fields: List, description: str = "", **kwargs):
         self._partition_key_field = None
         self._clustering_key_field = None
 
+        if functions is None:
+            functions = []
+
+        if not isinstance(functions, list):
+            raise FunctionsTypeException(message=ExceptionsMessage.FunctionsType)
+        for function in functions:
+            if not isinstance(function, Function):
+                raise SchemaNotReadyException(message=ExceptionsMessage.FunctionIncorrectType)
+        self._functions = [copy.deepcopy(function) for function in functions]
+
         if not isinstance(fields, list):
             raise FieldsTypeException(message=ExceptionsMessage.FieldsType)
+        for field in fields:
+            if not isinstance(field, FieldSchema):
+                raise FieldTypeException(message=ExceptionsMessage.FieldType)
         self._fields = [copy.deepcopy(field) for field in fields]
 
+        self._mark_output_fields()
+
         self._check_kwargs()
         if kwargs.get("check_fields", True):
             self._check_fields()
@@ -114,10 +133,6 @@ def _check_kwargs(self):
         if clustering_key_field_name is not None and not isinstance(clustering_key_field_name, str):
             raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyFieldType)
 
-        for field in self._fields:
-            if not isinstance(field, FieldSchema):
-                raise FieldTypeException(message=ExceptionsMessage.FieldType)
-
         if "auto_id" in self._kwargs and not isinstance(self._kwargs["auto_id"], bool):
             raise AutoIDException(0, ExceptionsMessage.AutoIDType)
 
@@ -125,6 +140,9 @@ def _check_fields(self):
         primary_field_name = self._kwargs.get("primary_field", None)
         partition_key_field_name = self._kwargs.get("partition_key_field", None)
         clustering_key_field_name = self._kwargs.get("clustering_key_field", None)
+        field_names = [field.name for field in self._fields]
+        if len(field_names) != len(set(field_names)):
+            raise ParamError(message=ExceptionsMessage.FieldNamesDuplicate)
         for field in self._fields:
             if primary_field_name and primary_field_name == field.name:
                 field.is_primary = True
@@ -175,9 +193,48 @@ def _check_fields(self):
             if auto_id:
                 self._primary_field.auto_id = auto_id
 
+    def _check_functions(self):
+        for function in self._functions:
+            for output_field_name in function.output_field_names:
+                output_field = next(
+                    (field for field in self._fields if field.name == output_field_name), None
+                )
+                if output_field is None:
+                    raise ParamError(
+                        message=f"{ExceptionsMessage.FunctionMissingOutputField}: {output_field_name}"
+                    )
+
+                if output_field is not None and (
+                    output_field.is_primary
+                    or output_field.is_partition_key
+                    or output_field.is_clustering_key
+                ):
+                    raise ParamError(message=ExceptionsMessage.FunctionInvalidOutputField)
+
+            for input_field_name in function.input_field_names:
+                input_field = next(
+                    (field for field in self._fields if field.name == input_field_name), None
+                )
+                if input_field is None:
+                    raise ParamError(
+                        message=f"{ExceptionsMessage.FunctionMissingInputField}: {input_field_name}"
+                    )
+
+            function.verify(self)
+
+    def _mark_output_fields(self):
+        for function in self._functions:
+            for output_field_name in function.output_field_names:
+                output_field = next(
+                    (field for field in self._fields if field.name == output_field_name), None
+                )
+                if output_field is not None:
+                    output_field.is_function_output = True
+
     def _check(self):
         self._check_kwargs()
         self._check_fields()
+        self._check_functions()
 
     def __repr__(self) -> str:
         return str(self.to_dict())
@@ -192,9 +249,15 @@ def __eq__(self, other: object):
     @classmethod
     def construct_from_dict(cls, raw: Dict):
         fields = [FieldSchema.construct_from_dict(field_raw) for field_raw in raw["fields"]]
+        if "functions" in raw:
+            functions = [
+                Function.construct_from_dict(function_raw) for function_raw in raw["functions"]
+            ]
+        else:
+            functions = []
         enable_dynamic_field = raw.get("enable_dynamic_field", False)
         return CollectionSchema(
-            fields, raw.get("description", ""), enable_dynamic_field=enable_dynamic_field
+            fields, raw.get("description", ""), functions, enable_dynamic_field=enable_dynamic_field
         )
 
     @property
@@ -222,6 +285,16 @@ def fields(self):
         """
         return self._fields
 
+    @property
+    def functions(self):
+        """
+        Returns the functions of the CollectionSchema.
+
+        :return list:
+            List of Function, return when operation is successful.
+        """
+        return self._functions
+
     @property
     def description(self):
         """
@@ -273,12 +346,15 @@ def enable_dynamic_field(self, value: bool):
         self._enable_dynamic_field = bool(value)
 
     def to_dict(self):
-        return {
+        res = {
             "auto_id": self.auto_id,
             "description": self._description,
             "fields": [s.to_dict() for s in self._fields],
             "enable_dynamic_field": self.enable_dynamic_field,
         }
+        if self._functions is not None and len(self._functions) > 0:
+            res["functions"] = [s.to_dict() for s in self._functions]
+        return res
 
     def verify(self):
         # final check, detect obvious problems
@@ -287,6 +363,14 @@ def verify(self):
 
     def add_field(self, field_name: str, datatype: DataType, **kwargs):
         field = FieldSchema(field_name, datatype, **kwargs)
         self._fields.append(field)
+        self._mark_output_fields()
+        return self
+
+    def add_function(self, function: "Function"):
+        if not isinstance(function, Function):
+            raise ParamError(message=ExceptionsMessage.FunctionIncorrectType)
+        self._functions.append(function)
+        self._mark_output_fields()
         return self
 
@@ -333,6 +417,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
         if "mmap_enabled" in kwargs:
             self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]
         self._parse_type_params()
+        self.is_function_output = False
 
     def __repr__(self) -> str:
         return str(self.to_dict())
@@ -361,6 +446,13 @@ def _parse_type_params(self):
             if k in self._kwargs:
                 if self._type_params is None:
                     self._type_params = {}
+                if isinstance(self._kwargs[k], str):
+                    if self._kwargs[k].lower() == "true":
+                        self._type_params[k] = True
+                        continue
+                    if self._kwargs[k].lower() == "false":
+                        self._type_params[k] = False
+                        continue
                 self._type_params[k] = self._kwargs[k]
 
     @classmethod
@@ -377,7 +469,10 @@ def construct_from_dict(cls, raw: Dict):
         kwargs["is_dynamic"] = raw.get("is_dynamic", False)
         kwargs["nullable"] = raw.get("nullable", False)
         kwargs["element_type"] = raw.get("element_type")
-        return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
+        is_function_output = raw.get("is_function_output", False)
+        fs = FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
+        fs.is_function_output = is_function_output
+        return fs
 
     def to_dict(self):
         _dict = {
@@ -404,6 +499,8 @@ def to_dict(self):
             _dict["element_type"] = self.element_type
         if self.is_clustering_key:
             _dict["is_clustering_key"] = True
+        if self.is_function_output:
+            _dict["is_function_output"] = True
         return _dict
 
     def __getattr__(self, item: str):
@@ -456,6 +553,115 @@ def dtype(self) -> DataType:
         return self._dtype
 
 
+class Function:
+    def __init__(
+        self,
+        name: str,
+        function_type: FunctionType,
+        input_field_names: Union[str, List[str]],
+        output_field_names: Union[str, List[str]],
+        description: str = "",
+        params: Optional[Dict] = None,
+    ):
+        self._name = name
+        self._description = description
+        input_field_names = (
+            [input_field_names] if isinstance(input_field_names, str) else input_field_names
+        )
+        output_field_names = (
+            [output_field_names] if isinstance(output_field_names, str) else output_field_names
+        )
+        try:
+            self._type = FunctionType(function_type)
+        except ValueError as err:
+            raise ParamError(message=ExceptionsMessage.UnknownFunctionType) from err
+
+        for field_name in list(input_field_names) + list(output_field_names):
+            if not isinstance(field_name, str):
+                raise ParamError(message=ExceptionsMessage.FunctionIncorrectInputOutputType)
+        if len(input_field_names) != len(set(input_field_names)):
+            raise ParamError(message=ExceptionsMessage.FunctionDuplicateInputs)
+        if len(output_field_names) != len(set(output_field_names)):
+            raise ParamError(message=ExceptionsMessage.FunctionDuplicateOutputs)
+
+        if set(input_field_names) & set(output_field_names):
+            raise ParamError(message=ExceptionsMessage.FunctionCommonInputOutput)
+
+        self._input_field_names = input_field_names
+        self._output_field_names = output_field_names
+        self._params = params if params is not None else {}
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def description(self):
+        return self._description
+
+    @property
+    def type(self):
+        return self._type
+
+    @property
+    def input_field_names(self):
+        return self._input_field_names
+
+    @property
+    def output_field_names(self):
+        return self._output_field_names
+
+    @property
+    def params(self):
+        return self._params
+
+    def verify(self, schema: CollectionSchema):
+        if self._type == FunctionType.BM25:
+            if len(self._input_field_names) != 1 or len(self._output_field_names) != 1:
+                raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputOutputCount)
+
+            for field in schema.fields:
+                if field.name == self._input_field_names[0] and field.dtype != DataType.VARCHAR:
+                    raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType)
+                if (
+                    field.name == self._output_field_names[0]
+                    and field.dtype != DataType.SPARSE_FLOAT_VECTOR
+                ):
+                    raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType)
+
+        elif self._type == FunctionType.UNKNOWN:
+            raise ParamError(message=ExceptionsMessage.UnknownFunctionType)
+
+    @classmethod
+    def construct_from_dict(cls, raw: Dict):
+        return Function(
+            raw["name"],
+            raw["type"],
+            raw["input_field_names"],
+            raw["output_field_names"],
+            raw["description"],
+            raw["params"],
+        )
+
+    def __repr__(self) -> str:
+        return str(self.to_dict())
+
+    def to_dict(self):
+        return {
+            "name": self._name,
+            "description": self._description,
+            "type": self._type,
+            "input_field_names": self._input_field_names,
+            "output_field_names": self._output_field_names,
+            "params": self._params,
+        }
+
+    def __eq__(self, value: object) -> bool:
+        if not isinstance(value, Function):
+            return False
+        return self.to_dict() == value.to_dict()
+
+
 def is_valid_insert_data(data: Union[pd.DataFrame, list, dict]) -> bool:
     """DataFrame, list, dict are valid insert data"""
     return isinstance(data, (pd.DataFrame, list, dict))
@@ -501,7 +707,7 @@ def _check_insert_data(data: Union[List[List], pd.DataFrame]):
 
 
 def _check_data_schema_cnt(fields: List, data: Union[List[List], pd.DataFrame]):
-    field_cnt = len(fields)
+    field_cnt = len([f for f in fields if not f.is_function_output])
     is_dataframe = isinstance(data, pd.DataFrame)
     data_cnt = len(data.columns) if is_dataframe else len(data)
     if field_cnt != data_cnt:
@@ -534,10 +740,12 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
         columns.remove(schema.primary_field)
         data = data[columns]
 
-    tmp_fields = copy.deepcopy(schema.fields)
-    for i, field in enumerate(tmp_fields):
-        if field.is_primary and field.auto_id:
-            tmp_fields.pop(i)
+    tmp_fields = list(
+        filter(
+            lambda field: not (field.is_primary and field.auto_id) and not field.is_function_output,
+            schema.fields,
+        )
+    )
 
     _check_data_schema_cnt(tmp_fields, data)
     _check_insert_data(data)
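A quick, hedged usage sketch of the schema-level API added by this patch (not
part of the diff; the collection and field names here are illustrative):

    from pymilvus import (
        CollectionSchema, DataType, FieldSchema, Function, FunctionType,
    )

    fields = [
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("text", DataType.VARCHAR, max_length=512, enable_tokenizer=True),
        FieldSchema("sparse", DataType.SPARSE_FLOAT_VECTOR),
    ]
    bm25 = Function(
        name="bm25",
        function_type=FunctionType.BM25,
        input_field_names=["text"],
        output_field_names="sparse",  # a single str is promoted to [str]
    )
    schema = CollectionSchema(fields, "demo", functions=[bm25])
    schema.verify()  # runs the field checks plus the new _check_functions()

    # "sparse" is marked as a function output and serialized accordingly:
    assert schema.to_dict()["fields"][2]["is_function_output"]
    # The schema round-trips through its dict form:
    restored = CollectionSchema.construct_from_dict(schema.to_dict())
    assert restored.functions[0].name == "bm25"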