Skip to content

Commit

Permalink
fix: upserting rows fails when auto_id is set to true
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Oct 9, 2024
1 parent c7de801 commit 9b22e8a
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pymilvus/bulk_writer/local_bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def commit(self, **kwargs):
f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}"
)
_async = kwargs.get("_async", False)
call_back = kwargs.get("call_back", None)
call_back = kwargs.get("call_back")

x = Thread(target=self._flush, args=(call_back,))
logger.info(f"Flush thread begin, name: {x.name}")
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/asynch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def result(self, **kwargs):
self.exception()
with self._condition:
# future not finished. wait callback being called.
to = kwargs.get("timeout", None)
to = kwargs.get("timeout")
if to is None:
to = self._kwargs.get("timeout", None)

Expand Down
36 changes: 18 additions & 18 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def __init__(
self._address = addr if addr is not None else self.__get_address(uri, host, port)
self._log_level = None
self._request_id = None
self._user = kwargs.get("user", None)
self._user = kwargs.get("user")
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name", None))
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel()
self.callbacks = []

Expand Down Expand Up @@ -125,9 +125,9 @@ def _set_authorization(self, **kwargs):

self._authorization_interceptor = None
self._setup_authorization_interceptor(
kwargs.get("user", None),
kwargs.get("password", None),
kwargs.get("token", None),
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)

def __enter__(self):
Expand Down Expand Up @@ -470,7 +470,7 @@ def get_partition_stats(
return response.stats

def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
schema = kwargs.get("schema", None)
schema = kwargs.get("schema")
if not schema:
schema = self.describe_collection(collection_name, timeout=timeout)

Expand Down Expand Up @@ -567,7 +567,7 @@ def batch_insert(
)
rf = self._stub.Insert.future(request, timeout=timeout)
if kwargs.get("_async", False):
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(rf, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
Expand Down Expand Up @@ -599,12 +599,12 @@ def delete(
partition_name,
expression,
consistency_level=kwargs.get("consistency_level", 0),
param_name=kwargs.get("param_name", None),
param_name=kwargs.get("param_name"),
)
future = self._stub.Delete.future(req, timeout=timeout)

if kwargs.get("_async", False):
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(future, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
Expand Down Expand Up @@ -664,7 +664,7 @@ def upsert(
)
rf = self._stub.Upsert.future(request, timeout=timeout)
if kwargs.get("_async", False) is True:
cb = kwargs.get("_callback", None)
cb = kwargs.get("_callback")
f = MutationFuture(rf, cb, timeout=timeout, **kwargs)
f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
return f
Expand Down Expand Up @@ -727,7 +727,7 @@ def _execute_search(
try:
if kwargs.get("_async", False):
future = self._stub.Search.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
func = kwargs.get("_callback")
return SearchFuture(future, func)

response = self._stub.Search(request, timeout=timeout)
Expand All @@ -746,7 +746,7 @@ def _execute_hybrid_search(
try:
if kwargs.get("_async", False):
future = self._stub.HybridSearch.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
func = kwargs.get("_callback")
return SearchFuture(future, func)

response = self._stub.HybridSearch(request, timeout=timeout)
Expand Down Expand Up @@ -781,7 +781,7 @@ def search(
search_data=data,
partition_name_array=partition_names,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
timeout=timeout,
)

Expand Down Expand Up @@ -817,7 +817,7 @@ def hybrid_search(
round_decimal=round_decimal,
partition_name_array=partition_names,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
timeout=timeout,
)

Expand Down Expand Up @@ -977,7 +977,7 @@ def _check():

index_future = CreateIndexFuture(future)
index_future.add_callback(_check)
user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
index_future.add_callback(user_cb)
return index_future
Expand Down Expand Up @@ -1252,7 +1252,7 @@ def _check():
load_partitions_future = LoadPartitionsFuture(future)
load_partitions_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
load_partitions_future.add_callback(user_cb)

Expand Down Expand Up @@ -1461,7 +1461,7 @@ def _check():
flush_future = FlushFuture(future)
flush_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
flush_future.add_callback(user_cb)

Expand Down Expand Up @@ -1988,7 +1988,7 @@ def _check():
flush_future = FlushFuture(future)
flush_future.add_callback(_check)

user_cb = kwargs.get("_callback", None)
user_cb = kwargs.get("_callback")
if user_cb:
flush_future.add_callback(user_cb)

Expand Down
77 changes: 72 additions & 5 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def create_collection_request(
raise ParamError(message=msg)
req.shards_num = num_shards

num_partitions = kwargs.get("num_partitions", None)
num_partitions = kwargs.get("num_partitions")
if num_partitions is not None:
if not isinstance(num_partitions, int) or isinstance(num_partitions, bool):
msg = f"invalid num_partitions type, got {type(num_partitions)}, expected int"
Expand Down Expand Up @@ -424,6 +424,73 @@ def _parse_row_request(
raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
return request

@staticmethod
def _parse_upsert_row_request(
    request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
    fields_info: dict,
    enable_dynamic: bool,
    entities: List,
):
    """Pack row-based ``entities`` into the column-based ``fields_data`` of ``request``.

    One ``FieldData`` column is created for every entry of ``fields_info``
    (no entry is filtered out), plus one JSON column collecting the
    undeclared keys of each row when ``enable_dynamic`` is true.

    Args:
        request: Request message; its ``fields_data`` is extended in place.
        fields_info: Per-field descriptions. Each entry must provide
            ``"name"`` and ``"type"`` and may provide ``"is_dynamic"``.
            NOTE(review): annotated as ``dict`` but iterated as a sequence
            of per-field mappings -- confirm against callers.
        enable_dynamic: Whether keys that are not declared fields are
            accepted and stored in the dynamic JSON column.
        entities: Rows to pack; every row must be a dict of field name to value.

    Returns:
        The same ``request`` object, with its ``fields_data`` populated.

    Raises:
        DataNotMatchException: A row carries an undeclared key while
            ``enable_dynamic`` is false, or a ``TypeError``/``ValueError``
            is raised while packing (this includes a row that is not a dict).
        ParamError: A field flagged ``is_dynamic`` appears as a key of a
            row, or the number of columns does not match ``fields_info``.
    """
    fields_data = {
        field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
        for field in fields_info
    }
    field_info_map = {field["name"]: field for field in fields_info}

    if enable_dynamic:
        # The dynamic column is registered under its own field_name so the
        # lookups below treat it like any other column.
        d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
        fields_data[d_field.field_name] = d_field
        field_info_map[d_field.field_name] = d_field

    try:
        for entity in entities:
            if not isinstance(entity, Dict):
                msg = f"expected Dict, got '{type(entity).__name__}'"
                # Deliberately a TypeError: the handler below converts it
                # into DataNotMatchException.
                raise TypeError(msg)

            # Single pass over the row: declared keys are packed into their
            # column, the rest is collected for the dynamic JSON column.
            json_dict = {}
            for k, v in entity.items():
                if k in fields_data:
                    field_info, field_data = field_info_map[k], fields_data[k]
                    entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
                elif enable_dynamic:
                    json_dict[k] = v
                else:
                    raise DataNotMatchException(
                        message=ExceptionsMessage.InsertUnexpectedField % k
                    )

            if enable_dynamic:
                # One JSON value is appended per row, even when the row has
                # no undeclared keys.
                json_value = entity_helper.convert_to_json(json_dict)
                d_field.scalars.json_data.data.append(json_value)

    except (TypeError, ValueError) as e:
        raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

    # Columns are emitted in schema order, with the dynamic column last.
    request.fields_data.extend([fields_data[field["name"]] for field in fields_info])

    if enable_dynamic:
        request.fields_data.append(d_field)

    # Only fields flagged as dynamic need the per-row scan.
    for field in fields_info:
        if not field.get("is_dynamic", False):
            continue
        field_name = field["name"]
        for j, entity in enumerate(entities):
            if field_name in entity:
                raise ParamError(
                    message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]"
                )

    # fields_data is keyed by name, so its size can differ from the schema
    # size (presumably when field names collide -- TODO confirm).
    expected_columns = len(fields_info) + 1 if enable_dynamic else len(fields_info)
    if len(fields_data) != expected_columns:
        raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
    return request

@classmethod
def row_insert_param(
cls,
Expand Down Expand Up @@ -466,7 +533,7 @@ def row_upsert_param(
num_rows=len(entities),
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)

@staticmethod
def _pre_insert_batch_check(
Expand Down Expand Up @@ -986,11 +1053,11 @@ def query_request(
consistency_level=kwargs.get("consistency_level", 0),
)

limit = kwargs.get("limit", None)
limit = kwargs.get("limit")
if limit is not None:
req.query_params.append(common_types.KeyValuePair(key="limit", value=str(limit)))

offset = kwargs.get("offset", None)
offset = kwargs.get("offset")
if offset is not None:
req.query_params.append(common_types.KeyValuePair(key="offset", value=str(offset)))

Expand Down Expand Up @@ -1063,7 +1130,7 @@ def get_replicas(cls, collection_id: int):

@classmethod
def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, **kwargs):
channel_names = kwargs.get("channel_names", None)
channel_names = kwargs.get("channel_names")
req = milvus_types.ImportRequest(
collection_name=collection_name,
partition_name=partition_name,
Expand Down
8 changes: 4 additions & 4 deletions pymilvus/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def wrapper(func: Any):
def handler(*args, **kwargs):
# This has to make sure every timeout parameter is passed
# through kwargs in the form `timeout=10`
_timeout = kwargs.get("timeout", None)
_retry_times = kwargs.get("retry_times", None)
_timeout = kwargs.get("timeout")
_retry_times = kwargs.get("retry_times")
_retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True)

retry_timeout = _timeout if _timeout is not None and isinstance(_timeout, int) else None
Expand Down Expand Up @@ -174,8 +174,8 @@ def tracing_request():
def wrapper(func: Callable):
@functools.wraps(func)
def handler(self: Callable, *args, **kwargs):
level = kwargs.get("log_level", None)
req_id = kwargs.get("client_request_id", None)
level = kwargs.get("log_level")
req_id = kwargs.get("client_request_id")
if level:
self.set_onetime_loglevel(level)
if req_id:
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def construct_from_dataframe(cls, name: str, dataframe: pd.DataFrame, **kwargs):
pk_index = i
if pk_index == -1:
raise SchemaNotReadyException(message=ExceptionsMessage.PrimaryKeyNotExist)
if "auto_id" in kwargs and not isinstance(kwargs.get("auto_id", None), bool):
if "auto_id" in kwargs and not isinstance(kwargs.get("auto_id"), bool):
raise AutoIDException(message=ExceptionsMessage.AutoIDType)
auto_id = kwargs.pop("auto_id", False)
if auto_id:
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
raise ClusteringKeyException(message=ExceptionsMessage.IsClusteringKeyType)
self.is_partition_key = kwargs.get("is_partition_key", False)
self.is_clustering_key = kwargs.get("is_clustering_key", False)
self.element_type = kwargs.get("element_type", None)
self.element_type = kwargs.get("element_type")
if "mmap_enabled" in kwargs:
self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]
self._parse_type_params()
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ def list_indexes(
:rtype: str list
"""
indexes = _get_connection(using).list_indexes(collection_name, timeout, **kwargs)
field_name = kwargs.get("field_name", None)
field_name = kwargs.get("field_name")
index_name_list = []
for index in indexes:
if index is not None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,21 @@ def test_row_insert_param_with_auto_id(self):

Prepare.row_insert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)

def test_row_upsert_param_with_auto_id(self):
    """Smoke test: row_upsert_param accepts rows that carry the auto_id primary key.

    The rows also contain keys ("a", "b") that are not part of the schema,
    and the call is made with enable_dynamic=True. The test passes as long
    as building the request raises nothing.
    """
    import numpy as np

    dim = 8
    rng = np.random.default_rng(seed=19530)

    fields = [
        FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
        FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("float", DataType.DOUBLE),
    ]
    schema = CollectionSchema(fields)

    # Vectors are drawn in row order so the seeded generator yields the
    # same values for each row on every run.
    first_vector = rng.random((1, dim))[0]
    second_vector = rng.random((1, dim))[0]
    rows = [
        {"pk_field": 1, "float": 1.0, "float_vector": first_vector, "a": 1},
        {"pk_field": 2, "float": 1.0, "float_vector": second_vector, "b": 1},
    ]

    fields_info = schema.to_dict()["fields"]
    Prepare.row_upsert_param("", rows, "", fields_info=fields_info, enable_dynamic=True)

class TestAlterCollectionRequest:
def test_alter_collection_request(self):
Expand Down

0 comments on commit 9b22e8a

Please sign in to comment.