Skip to content

Commit

Permalink
Merge branch 'milvus-27469-normalization' of https://github.com/CaoHa…
Browse files Browse the repository at this point in the history
…iNam/pymilvus into milvus-27469-normalization
  • Loading branch information
CaoHaiNam committed Nov 7, 2024
2 parents 6c5354e + a0a3c56 commit 729b453
Show file tree
Hide file tree
Showing 26 changed files with 812 additions and 481 deletions.
2 changes: 1 addition & 1 deletion examples/hello_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000, enable_tokenizer=True),
FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000, enable_analyzer=True),
]

bm25_function = Function(
Expand Down
2 changes: 1 addition & 1 deletion examples/hello_hybrid_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def random_embedding(texts):
# Use auto generated id as primary key
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
# Store the original text to retrieve based on semantically distance
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512, enable_tokenizer=True),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512, enable_analyzer=True),
# We need a sparse vector field to perform full text search with BM25,
# but you don't need to provide data for it when inserting data.
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
Expand Down
2 changes: 1 addition & 1 deletion examples/milvus_client/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

schema = milvus_client.create_schema()
schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
schema.add_field("document_content", DataType.VARCHAR, max_length=9000, enable_tokenizer=True)
schema.add_field("document_content", DataType.VARCHAR, max_length=9000, enable_analyzer=True)
schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)

bm25_function = Function(
Expand Down
83 changes: 83 additions & 0 deletions examples/milvus_client/compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
)

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
milvus_client.drop_collection(collection_name)
milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2")

rng = np.random.default_rng(seed=19530)
rows = [
{"id": 1, "vector": rng.random((1, dim))[0], "a": 100},
{"id": 2, "vector": rng.random((1, dim))[0], "b": 200},
{"id": 3, "vector": rng.random((1, dim))[0], "c": 300},
{"id": 4, "vector": rng.random((1, dim))[0], "d": 400},
{"id": 5, "vector": rng.random((1, dim))[0], "e": 500},
{"id": 6, "vector": rng.random((1, dim))[0], "f": 600},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

upsert_ret = milvus_client.upsert(collection_name, {"id": 2 , "vector": rng.random((1, dim))[0], "g": 100})
print(upsert_ret)

print(fmt.format("Start flush"))
milvus_client.flush(collection_name)
print(fmt.format("flush done"))

result = milvus_client.query(collection_name, "", output_fields = ["count(*)"])
print(f"final entities in {collection_name} is {result[0]['count(*)']}")

rows = [
{"id": 7, "vector": rng.random((1, dim))[0], "g": 700},
{"id": 8, "vector": rng.random((1, dim))[0], "h": 800},
{"id": 9, "vector": rng.random((1, dim))[0], "i": 900},
{"id": 10, "vector": rng.random((1, dim))[0], "j": 1000},
{"id": 11, "vector": rng.random((1, dim))[0], "k": 1100},
{"id": 12, "vector": rng.random((1, dim))[0], "l": 1200},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start flush"))
milvus_client.flush(collection_name)
print(fmt.format("flush done"))

result = milvus_client.query(collection_name, "", output_fields = ["count(*)"])
print(f"final entities in {collection_name} is {result[0]['count(*)']}")

print(fmt.format("Start compact"))
job_id = milvus_client.compact(collection_name)
print(f"job_id:{job_id}")

cnt = 0
state = milvus_client.get_compaction_state(job_id)
while (state != "Completed" and cnt < 10):
time.sleep(1.0)
state = milvus_client.get_compaction_state(job_id)
print(f"compaction state: {state}")
cnt += 1

if state == "Completed":
print(fmt.format("compact done"))
else:
print(fmt.format("compact timeout"))

result = milvus_client.query(collection_name, "", output_fields = ["count(*)"])
print(f"final entities in {collection_name} is {result[0]['count(*)']}")

milvus_client.drop_collection(collection_name)
57 changes: 57 additions & 0 deletions examples/milvus_client/flush.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
)

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
milvus_client.drop_collection(collection_name)
milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2")

rng = np.random.default_rng(seed=19530)
rows = [
{"id": 1, "vector": rng.random((1, dim))[0], "a": 100},
{"id": 2, "vector": rng.random((1, dim))[0], "b": 200},
{"id": 3, "vector": rng.random((1, dim))[0], "c": 300},
{"id": 4, "vector": rng.random((1, dim))[0], "d": 400},
{"id": 5, "vector": rng.random((1, dim))[0], "e": 500},
{"id": 6, "vector": rng.random((1, dim))[0], "f": 600},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

upsert_ret = milvus_client.upsert(collection_name, {"id": 2 , "vector": rng.random((1, dim))[0], "g": 100})
print(upsert_ret)

print(fmt.format("Start flush"))
milvus_client.flush(collection_name)
print(fmt.format("flush done"))


result = milvus_client.query(collection_name, "", output_fields = ["count(*)"])
print(f"final entities in {collection_name} is {result[0]['count(*)']}")


print(f"start to delete by specifying filter in collection {collection_name}")
delete_result = milvus_client.delete(collection_name, ids=[6])
print(delete_result)


print(fmt.format("Start flush"))
milvus_client.flush(collection_name)
print(fmt.format("flush done"))


result = milvus_client.query(collection_name, "", output_fields = ["count(*)"])
print(f"final entities in {collection_name} is {result[0]['count(*)']}")

milvus_client.drop_collection(collection_name)
8 changes: 8 additions & 0 deletions examples/milvus_client/get_server_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pymilvus import (
MilvusClient,
)

milvus_client = MilvusClient("http://localhost:19530")

version = milvus_client.get_server_version()
print(f"server version: {version}")
2 changes: 2 additions & 0 deletions pymilvus/bulk_writer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
for field in schema.fields:
if field.is_primary and field.auto_id:
continue
if field.is_function_output:
continue
self._buffer[field.name] = []
self._fields[field.name] = field

Expand Down
5 changes: 5 additions & 0 deletions pymilvus/bulk_writer/bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def _verify_row(self, row: dict):
)
else:
continue
if field.is_function_output:
if field.name in row:
self._throw(f"Field '{field.name}' is function output, no need to provide")
else:
continue

if field.name not in row:
self._throw(f"The field '{field.name}' is missed in the row")
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
GROUP_BY_FIELD = "group_by_field"
GROUP_SIZE = "group_size"
RANK_GROUP_SCORER = "rank_group_scorer"
GROUP_STRICT_SIZE = "group_strict_size"
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"
Expand Down
6 changes: 3 additions & 3 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def convert_to_array(obj: List[Any], field_info: Any):
field_data.string_data.data.extend(obj)
return field_data
raise ParamError(
message=f"UnSupported element type: {element_type} for Array field: {field_info.get('name')}"
message=f"Unsupported element type: {element_type} for Array field: {field_info.get('name')}"
)


Expand Down Expand Up @@ -424,7 +424,7 @@ def pack_field_value_to_field_data(
% (field_name, "array", type(field_value))
) from e
else:
raise ParamError(message=f"UnSupported data type: {field_type}")
raise ParamError(message=f"Unsupported data type: {field_type}")


# TODO: refactor here.
Expand Down Expand Up @@ -562,7 +562,7 @@ def entity_to_field_data(entity: Any, field_info: Any, num_rows: int):
% (field_name, "sparse_float_vector", type(entity.get("values")[0]))
) from e
else:
raise ParamError(message=f"UnSupported data type: {entity_type}")
raise ParamError(message=f"Unsupported data type: {entity_type}")

return field_data

Expand Down
6 changes: 3 additions & 3 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _prepare_batch_insert_request(
if param and not isinstance(param, milvus_types.InsertRequest):
raise ParamError(message="The value of key 'insert_param' is invalid")
if not isinstance(entities, list):
raise ParamError(message="None entities, please provide valid entities.")
raise ParamError(message="'entities' must be a list, please provide valid entity data.")

schema = kwargs.get("schema")
if not schema:
Expand Down Expand Up @@ -634,7 +634,7 @@ def _prepare_batch_upsert_request(
if param and not isinstance(param, milvus_types.UpsertRequest):
raise ParamError(message="The value of key 'upsert_param' is invalid")
if not isinstance(entities, list):
raise ParamError(message="None entities, please provide valid entities.")
raise ParamError(message="'entities' must be a list, please provide valid entity data.")

schema = kwargs.get("schema")
if not schema:
Expand Down Expand Up @@ -691,7 +691,7 @@ def _prepare_row_upsert_request(
**kwargs,
):
if not isinstance(rows, list):
raise ParamError(message="None rows, please provide valid row data.")
raise ParamError(message="'rows' must be a list, please provide valid row data.")

fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
return Prepare.row_upsert_param(
Expand Down
69 changes: 44 additions & 25 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
DYNAMIC_FIELD_NAME,
GROUP_BY_FIELD,
GROUP_SIZE,
GROUP_STRICT_SIZE,
ITERATOR_FIELD,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
REDUCE_STOP_FOR_BEST,
STRICT_GROUP_SIZE,
)
from .types import (
DataType,
Expand Down Expand Up @@ -808,12 +808,48 @@ def _prepare_placeholder_str(cls, data: Any):

@classmethod
def prepare_expression_template(cls, values: Dict) -> Any:
def all_elements_same_type(lst: List):
return all(isinstance(item, type(lst[0])) for item in lst)

def add_array_data(v: List) -> schema_types.TemplateArrayValue:
data = schema_types.TemplateArrayValue()
if len(v) == 0:
return data
element_type = (
infer_dtype_by_scalar_data(v[0]) if all_elements_same_type(v) else schema_types.JSON
)
if element_type in (schema_types.Bool,):
data.bool_data.data.extend(v)
return data
if element_type in (
schema_types.Int8,
schema_types.Int16,
schema_types.Int32,
schema_types.Int64,
):
data.long_data.data.extend(v)
return data
if element_type in (schema_types.Float, schema_types.Double):
data.double_data.data.extend(v)
return data
if element_type in (schema_types.VarChar, schema_types.String):
data.string_data.data.extend(v)
return data
if element_type in (schema_types.Array,):
for e in v:
data.array_data.data.append(add_array_data(e))
return data
if element_type in (schema_types.JSON,):
for e in v:
data.json_data.data.append(entity_helper.convert_to_json(e))
return data
raise ParamError(message=f"Unsupported element type: {element_type}")

def add_data(v: Any) -> schema_types.TemplateValue:
dtype = infer_dtype_by_scalar_data(v)
data = schema_types.TemplateValue()
if dtype in (schema_types.Bool,):
data.bool_val = v
data.type = schema_types.Bool
return data
if dtype in (
schema_types.Int8,
Expand All @@ -822,38 +858,21 @@ def add_data(v: Any) -> schema_types.TemplateValue:
schema_types.Int64,
):
data.int64_val = v
data.type = schema_types.Int64
return data
if dtype in (schema_types.Float, schema_types.Double):
data.float_val = v
data.type = schema_types.Double
return data
if dtype in (schema_types.VarChar, schema_types.String):
data.string_val = v
data.type = schema_types.VarChar
return data
if dtype in (schema_types.Array,):
element_datas = schema_types.TemplateArrayValue()
same_type = True
element_type = None
for element in v:
rdata = add_data(element)
element_datas.array.append(rdata)
if element_type is None:
element_type = rdata.type
elif element_type != rdata.type:
same_type = False
element_datas.element_type = element_type if same_type else schema_types.JSON
element_datas.same_type = same_type
data.array_val.CopyFrom(element_datas)
data.type = schema_types.Array
data.array_val.CopyFrom(add_array_data(v))
return data
raise ParamError(message=f"Unsupported element type: {dtype}")

expression_template_values = {}
for k, v in values.items():
expression_template_values[k] = add_data(v)

return expression_template_values

@classmethod
Expand Down Expand Up @@ -920,9 +939,9 @@ def search_requests_with_expr(
if group_size is not None:
search_params[GROUP_SIZE] = group_size

group_strict_size = kwargs.get(GROUP_STRICT_SIZE)
if group_strict_size is not None:
search_params[GROUP_STRICT_SIZE] = group_strict_size
strict_group_size = kwargs.get(STRICT_GROUP_SIZE)
if strict_group_size is not None:
search_params[STRICT_GROUP_SIZE] = strict_group_size

if param.get("metric_type") is not None:
search_params["metric_type"] = param["metric_type"]
Expand Down Expand Up @@ -1016,11 +1035,11 @@ def hybrid_search_request_with_ranker(
]
)

if kwargs.get(GROUP_STRICT_SIZE) is not None:
if kwargs.get(STRICT_GROUP_SIZE) is not None:
request.rank_params.extend(
[
common_types.KeyValuePair(
key=GROUP_STRICT_SIZE, value=utils.dumps(kwargs.get(GROUP_STRICT_SIZE))
key=STRICT_GROUP_SIZE, value=utils.dumps(kwargs.get(STRICT_GROUP_SIZE))
)
]
)
Expand Down
Loading

0 comments on commit 729b453

Please sign in to comment.