fix: upsert rows when set autoid==true fail(#2286) #2287

Merged · 1 commit · Oct 10, 2024
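
This PR fixes #2286: `upsert` failed for collections whose primary key is declared with `auto_id=True`. The fix gives the upsert path its own row parser, `_parse_upsert_row_request`, instead of reusing the insert-oriented `_parse_row_request`, so rows that carry the primary key (which upsert requires even when `auto_id` is enabled) are accepted. A minimal sketch of the affected scenario, assuming a running Milvus instance at the URI below; the collection and field names are illustrative, not taken from the PR:

```python
from pymilvus import MilvusClient, DataType

client = MilvusClient("http://localhost:19530")

# An auto-id primary key plus a dynamic field, matching the shape of the
# schema used by the new unit test at the bottom of this diff.
schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field("pk_field", DataType.INT64, is_primary=True)
schema.add_field("float_vector", DataType.FLOAT_VECTOR, dim=8)
client.create_collection("upsert_demo", schema=schema)

# Upsert rows must include the primary key so the server knows which
# entity to replace; before this fix, building the upsert request failed
# client-side whenever the primary key had auto_id=True.
client.upsert("upsert_demo", data=[{"pk_field": 1, "float_vector": [0.1] * 8}])
```
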
2 changes: 1 addition & 1 deletion pymilvus/bulk_writer/local_bulk_writer.py
@@ -109,7 +109,7 @@ def commit(self, **kwargs):
f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}"
)
_async = kwargs.get("_async", False)
call_back = kwargs.get("call_back", None)
call_back = kwargs.get("call_back")

x = Thread(target=self._flush, args=(call_back,))
logger.info(f"Flush thread begin, name: {x.name}")
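
Most of the one-line changes in this PR are the same mechanical cleanup, repeated across the files below: `kwargs.get(key, None)` becomes `kwargs.get(key)`. This is behavior-preserving, since `dict.get` already returns `None` when the key is missing and no default is given. A quick illustration (not part of the diff):

```python
kwargs = {"timeout": 10}

# dict.get defaults to None, so the explicit ", None" is redundant.
assert kwargs.get("call_back") is None
assert kwargs.get("call_back", None) is None
assert kwargs.get("timeout") == 10
```
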
2 changes: 1 addition & 1 deletion pymilvus/client/asynch.py
@@ -105,7 +105,7 @@ def result(self, **kwargs):
self.exception()
with self._condition:
# future not finished. wait callback being called.
to = kwargs.get("timeout", None)
to = kwargs.get("timeout")
if to is None:
to = self._kwargs.get("timeout", None)

36 changes: 18 additions & 18 deletions pymilvus/client/grpc_handler.py
@@ -87,9 +87,9 @@ def __init__(
self._address = addr if addr is not None else self.__get_address(uri, host, port)
self._log_level = None
self._request_id = None
self._user = kwargs.get("user", None)
self._user = kwargs.get("user")
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name", None))
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel()
self.callbacks = []

@@ -125,9 +125,9 @@ def _set_authorization(self, **kwargs):

self._authorization_interceptor = None
self._setup_authorization_interceptor(
kwargs.get("user", None),
kwargs.get("password", None),
kwargs.get("token", None),
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)

def __enter__(self):
@@ -470,7 +470,7 @@ def get_partition_stats(
return response.stats

def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
schema = kwargs.get("schema", None)
schema = kwargs.get("schema")
if not schema:
schema = self.describe_collection(collection_name, timeout=timeout)

@@ -567,7 +567,7 @@ def batch_insert(
)
rf = self._stub.Insert.future(request, timeout=timeout)
if kwargs.get("_async", False):
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(rf, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
@@ -599,12 +599,12 @@ def delete(
partition_name,
expression,
consistency_level=kwargs.get("consistency_level", 0),
param_name=kwargs.get("param_name", None),
param_name=kwargs.get("param_name"),
)
future = self._stub.Delete.future(req, timeout=timeout)

if kwargs.get("_async", False):
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(future, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
@@ -664,7 +664,7 @@ def upsert(
)
rf = self._stub.Upsert.future(request, timeout=timeout)
if kwargs.get("_async", False) is True:
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(rf, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
@@ -727,7 +727,7 @@ def _execute_search(
try:
if kwargs.get("_async", False):
future = self._stub.Search.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
func = kwargs.get("_callback")
return SearchFuture(future, func)

response = self._stub.Search(request, timeout=timeout)
@@ -746,7 +746,7 @@ def _execute_hybrid_search(
try:
if kwargs.get("_async", False):
future = self._stub.HybridSearch.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
func = kwargs.get("_callback")
return SearchFuture(future, func)

response = self._stub.HybridSearch(request, timeout=timeout)
@@ -781,7 +781,7 @@ def search(
search_data=data,
partition_name_array=partition_names,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
timeout=timeout,
)

@@ -817,7 +817,7 @@ def hybrid_search(
round_decimal=round_decimal,
partition_name_array=partition_names,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
timeout=timeout,
)

@@ -977,7 +977,7 @@ def _check():

index_future = CreateIndexFuture(future)
index_future.add_callback(_check)
user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
index_future.add_callback(user_cb)
return index_future
@@ -1252,7 +1252,7 @@ def _check():
load_partitions_future = LoadPartitionsFuture(future)
load_partitions_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
load_partitions_future.add_callback(user_cb)

@@ -1461,7 +1461,7 @@ def _check():
flush_future = FlushFuture(future)
flush_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
flush_future.add_callback(user_cb)

@@ -1988,7 +1988,7 @@ def _check():
flush_future = FlushFuture(future)
flush_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
flush_future.add_callback(user_cb)

77 changes: 72 additions & 5 deletions pymilvus/client/prepare.py
@@ -85,7 +85,7 @@ def create_collection_request(
raise ParamError(message=msg)
req.shards_num = num_shards

num_partitions = kwargs.get("num_partitions", None)
num_partitions = kwargs.get("num_partitions")
if num_partitions is not None:
if not isinstance(num_partitions, int) or isinstance(num_partitions, bool):
msg = f"invalid num_partitions type, got {type(num_partitions)}, expected int"
@@ -424,6 +424,73 @@ def _parse_row_request(
raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
return request

+ @staticmethod
+ def _parse_upsert_row_request(
+ request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
+ fields_info: dict,
+ enable_dynamic: bool,
+ entities: List,
+ ):
+ fields_data = {
+ field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
+ for field in fields_info
+ }
+ field_info_map = {field["name"]: field for field in fields_info}
+
+ if enable_dynamic:
+ d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
+ fields_data[d_field.field_name] = d_field
+ field_info_map[d_field.field_name] = d_field
+
+ try:
+ for entity in entities:
+ if not isinstance(entity, Dict):
+ msg = f"expected Dict, got '{type(entity).__name__}'"
+ raise TypeError(msg)
+ for k, v in entity.items():
+ if k not in fields_data and not enable_dynamic:
+ raise DataNotMatchException(
+ message=ExceptionsMessage.InsertUnexpectedField % k
+ )
+
+ if k in fields_data:
+ field_info, field_data = field_info_map[k], fields_data[k]
+ entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
+
+ json_dict = {
+ k: v for k, v in entity.items() if k not in fields_data and enable_dynamic
+ }
+
+ if enable_dynamic:
+ json_value = entity_helper.convert_to_json(json_dict)
+ d_field.scalars.json_data.data.append(json_value)
+
+ except (TypeError, ValueError) as e:
+ raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
+
+ request.fields_data.extend([fields_data[field["name"]] for field in fields_info])
+
+ if enable_dynamic:
+ request.fields_data.append(d_field)
+
+ for _, field in enumerate(fields_info):
+ is_dynamic = False
+ field_name = field["name"]
+
+ if field.get("is_dynamic", False):
+ is_dynamic = True
+
+ for j, entity in enumerate(entities):
+ if is_dynamic and field_name in entity:
+ raise ParamError(
+ message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]"
+ )
+ if (enable_dynamic and len(fields_data) != len(fields_info) + 1) or (
+ not enable_dynamic and len(fields_data) != len(fields_info)
+ ):
+ raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
+ return request

@classmethod
def row_insert_param(
cls,
@@ -466,7 +533,7 @@ def row_upsert_param(
num_rows=len(entities),
)

- return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
+ return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)

@staticmethod
def _pre_insert_batch_check(
@@ -986,11 +1053,11 @@ def query_request(
consistency_level=kwargs.get("consistency_level", 0),
)

limit = kwargs.get("limit", None)
limit = kwargs.get("limit")
if limit is not None:
req.query_params.append(common_types.KeyValuePair(key="limit", value=str(limit)))

offset = kwargs.get("offset", None)
offset = kwargs.get("offset")
if offset is not None:
req.query_params.append(common_types.KeyValuePair(key="offset", value=str(offset)))

@@ -1063,7 +1130,7 @@ def get_replicas(cls, collection_id: int):

@classmethod
def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, **kwargs):
channel_names = kwargs.get("channel_names", None)
channel_names = kwargs.get("channel_names")
req = milvus_types.ImportRequest(
collection_name=collection_name,
partition_name=partition_name,
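
This file carries the substantive fix: `row_upsert_param` now routes through the new `_parse_upsert_row_request` rather than the insert-oriented `_parse_row_request`. Judging from the diff, the upsert variant builds a `FieldData` entry for every field in `fields_info`, including an `auto_id` primary key, so rows that supply the primary key (which upsert needs to identify the entity being replaced) are no longer rejected, while keys outside the schema still flow into the dynamic JSON field when `enable_dynamic` is set.
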
8 changes: 4 additions & 4 deletions pymilvus/decorators.py
@@ -54,8 +54,8 @@ def wrapper(func: Any):
def handler(*args, **kwargs):
# This has to make sure every timeout parameter is passed
# through in kwargs form, e.g. `timeout=10`
_timeout = kwargs.get("timeout", None)
_retry_times = kwargs.get("retry_times", None)
_timeout = kwargs.get("timeout")
_retry_times = kwargs.get("retry_times")
_retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True)

retry_timeout = _timeout if _timeout is not None and isinstance(_timeout, int) else None
@@ -174,8 +174,8 @@ def tracing_request():
def wrapper(func: Callable):
@functools.wraps(func)
def handler(self: Callable, *args, **kwargs):
level = kwargs.get("log_level", None)
req_id = kwargs.get("client_request_id", None)
level = kwargs.get("log_level")
req_id = kwargs.get("client_request_id")
if level:
self.set_onetime_loglevel(level)
if req_id:
2 changes: 1 addition & 1 deletion pymilvus/orm/collection.py
@@ -185,7 +185,7 @@ def construct_from_dataframe(cls, name: str, dataframe: pd.DataFrame, **kwargs):
pk_index = i
if pk_index == -1:
raise SchemaNotReadyException(message=ExceptionsMessage.PrimaryKeyNotExist)
if "auto_id" in kwargs and not isinstance(kwargs.get("auto_id", None), bool):
if "auto_id" in kwargs and not isinstance(kwargs.get("auto_id"), bool):
raise AutoIDException(message=ExceptionsMessage.AutoIDType)
auto_id = kwargs.pop("auto_id", False)
if auto_id:
2 changes: 1 addition & 1 deletion pymilvus/orm/schema.py
@@ -317,7 +317,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
raise ClusteringKeyException(message=ExceptionsMessage.IsClusteringKeyType)
self.is_partition_key = kwargs.get("is_partition_key", False)
self.is_clustering_key = kwargs.get("is_clustering_key", False)
self.element_type = kwargs.get("element_type", None)
self.element_type = kwargs.get("element_type")
if "mmap_enabled" in kwargs:
self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]
self._parse_type_params()
2 changes: 1 addition & 1 deletion pymilvus/orm/utility.py
@@ -1299,7 +1299,7 @@ def list_indexes(
:rtype: str list
"""
indexes = _get_connection(using).list_indexes(collection_name, timeout, **kwargs)
field_name = kwargs.get("field_name", None)
field_name = kwargs.get("field_name")
index_name_list = []
for index in indexes:
if index is not None:
15 changes: 15 additions & 0 deletions tests/test_prepare.py
@@ -191,6 +191,21 @@ def test_row_insert_param_with_auto_id(self):

Prepare.row_insert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)

+ def test_row_upsert_param_with_auto_id(self):
+ import numpy as np
+ rng = np.random.default_rng(seed=19530)
+ dim = 8
+ schema = CollectionSchema([
+ FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
+ FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
+ FieldSchema("float", DataType.DOUBLE)
+ ])
+ rows = [
+ {"pk_field":1, "float": 1.0, "float_vector": rng.random((1, dim))[0], "a": 1},
+ {"pk_field":2, "float": 1.0, "float_vector": rng.random((1, dim))[0], "b": 1},
+ ]
+
+ Prepare.row_upsert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)
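
The new test mirrors the reported failure: each row carries the `auto_id` primary key plus a key outside the schema ("a" or "b") that must land in the dynamic field, and building the upsert request is expected to succeed without raising. It can be run in isolation with, for example, `pytest tests/test_prepare.py -k test_row_upsert_param_with_auto_id`.
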

class TestAlterCollectionRequest:
def test_alter_collection_request(self):