diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index efbe9ca49..0389daec0 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -9,6 +9,27 @@ from .singleton_utils import Singleton +def validate_strs(**kwargs): + """validate if all values are legal non-emtpy str""" + invalid_pair = {k: v for k, v in kwargs.items() if not validate_str(v)} + if invalid_pair: + msg = f"Illegal str variables: {invalid_pair}, expect non-empty str" + raise ParamError(message=msg) + + +def validate_nullable_strs(**kwargs): + """validate if all values are either None or legal non-empty str""" + invalid_pair = {k: v for k, v in kwargs.items() if v is not None and not validate_str(v)} + if invalid_pair: + msg = f"Illegal nullable str variables: {invalid_pair}, expect None or non-empty str" + raise ParamError(message=msg) + + +def validate_str(var: Any) -> bool: + """check if a variable is legal non-empty str""" + return var and isinstance(var, str) + + def is_legal_address(addr: Any) -> bool: if not isinstance(addr, str): return False @@ -60,7 +81,7 @@ def is_legal_index_size(index_size: Any) -> bool: def is_legal_table_name(table_name: Any) -> bool: - return table_name and isinstance(table_name, str) + return validate_str(table_name) def is_legal_db_name(db_name: Any) -> bool: diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 6889c7562..bd7418b4a 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -595,17 +595,15 @@ def delete( check_pass_param(collection_name=collection_name, timeout=timeout) try: req = Prepare.delete_request( - collection_name, - partition_name, - expression, - consistency_level=kwargs.get("consistency_level", 0), - param_name=kwargs.pop("param_name", None), + collection_name=collection_name, + filter=expression, + partition_name=partition_name, + consistency_level=kwargs.pop("consistency_level", 0), **kwargs, ) future = self._stub.Delete.future(req, timeout=timeout) - if kwargs.get("_async", False): - cb = kwargs.get("_callback") + cb = kwargs.pop("_callback") f = MutationFuture(future, cb, timeout=timeout, **kwargs) f.add_callback(ts_utils.update_ts_on_mutation(collection_name)) return f diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 8a4fc9ead..380939b38 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -11,7 +11,7 @@ from pymilvus.orm.schema import CollectionSchema from pymilvus.orm.types import infer_dtype_by_scalar_data -from . import __version__, blob, entity_helper, ts_utils, utils +from . import __version__, blob, check, entity_helper, ts_utils, utils from .check import check_pass_param, is_legal_collection_properties from .constants import ( DEFAULT_CONSISTENCY_LEVEL, @@ -660,29 +660,21 @@ def batch_upsert_param( def delete_request( cls, collection_name: str, - partition_name: str, - expr: str, - consistency_level: Optional[Union[int, str]], + filter: str, + partition_name: Optional[str] = None, + consistency_level: Optional[Union[int, str]] = None, **kwargs, ): - def check_str(instr: str, prefix: str): - if instr is None: - raise ParamError(message=f"{prefix} cannot be None") - if not isinstance(instr, str): - raise ParamError(message=f"{prefix} value {instr} is illegal") - if len(instr) == 0: - raise ParamError(message=f"{prefix} cannot be empty") - - check_str(collection_name, "collection_name") - if partition_name is not None and partition_name != "": - check_str(partition_name, "partition_name") - param_name = kwargs.get("param_name", "expr") - check_str(expr, param_name) + check.validate_strs( + collection_name=collection_name, + filter=filter, + ) + check.validate_nullable_strs(partition_name=partition_name) return milvus_types.DeleteRequest( collection_name=collection_name, partition_name=partition_name, - expr=expr, + expr=filter, consistency_level=get_consistency_level(consistency_level), expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})), ) diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index a25d5c620..988747ccf 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -400,8 +400,8 @@ def search( limit=limit, output_fields=output_fields, partition_names=partition_names, - timeout=timeout, expr_params=kwargs.pop("filter_params", {}), + timeout=timeout, **kwargs, ) except Exception as ex: @@ -543,10 +543,10 @@ def delete( collection_name: str, ids: Optional[Union[list, str, int]] = None, timeout: Optional[float] = None, - filter: Optional[str] = "", - partition_name: Optional[str] = "", + filter: Optional[str] = None, + partition_name: Optional[str] = None, **kwargs, - ) -> Dict: + ) -> Dict[str, int]: """Delete entries in the collection by their pk or by filter. Starting from version 2.3.2, Milvus no longer includes the primary keys in the result @@ -558,14 +558,17 @@ def delete( Milvus(previous 2.3.2) is not empty, the list of primary keys is still returned. Args: - ids (list, str, int): The pk's to delete. Depending on pk_field type it can be int - or str or alist of either. Default to None. - filter(str, optional): A filter to use for the deletion. Defaults to empty. + ids (list, str, int, optional): The pk's to delete. + Depending on pk_field type it can be int or str or a list of either. + Default to None. + filter(str, optional): A filter to use for the deletion. Defaults to none. timeout (int, optional): Timeout to use, overides the client level assigned at init. Defaults to None. + Note: You need to passin either ids or filter, and they cannot be used at the same time. + Returns: - Dict: Number of rows that were deleted. + Dict: with key 'deleted_count' and value number of rows that were deleted. """ pks = kwargs.get("pks", []) if isinstance(pks, (int, str)): @@ -589,35 +592,32 @@ def delete( msg = f"wrong type of argument ids, expect list, int or str, got '{type(ids).__name__}'" raise TypeError(msg) + # validate ambiguous delete filter param before describe collection rpc + if filter and len(pks) > 0: + raise ParamError(message=ExceptionsMessage.AmbiguousDeleteFilterParam) + expr = "" conn = self._get_connection() - if pks: + if len(pks) > 0: try: schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) except Exception as ex: logger.error("Failed to describe collection: %s", collection_name) raise ex from ex - expr = self._pack_pks_expr(schema_dict, pks) - - if filter: - if expr: - raise ParamError(message=ExceptionsMessage.AmbiguousDeleteFilterParam) - + else: if not isinstance(filter, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) - expr = filter ret_pks = [] try: res = conn.delete( - collection_name, - expr, - partition_name, - timeout=timeout, - param_name="filter or ids", + collection_name=collection_name, + expression=expr, + partition_name=partition_name, expr_params=kwargs.pop("filter_params", {}), + timeout=timeout, **kwargs, ) if res.primary_keys: @@ -626,6 +626,7 @@ def delete( logger.error("Failed to delete primary keys in collection: %s", collection_name) raise ex from ex + # compatible with deletions that returns primary keys if ret_pks: return ret_pks diff --git a/tests/test_prepare.py b/tests/test_prepare.py index b339abda1..ea5b76694 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -8,6 +8,18 @@ class TestPrepare: + @pytest.mark.parametrize("coll_name", [None, "", -1, 1.1, []]) + @pytest.mark.parametrize("expr", [None, "", -1, 1.1, []]) + def test_delete_request_wrong_coll_name(self, coll_name: str, expr: str): + with pytest.raises(MilvusException): + Prepare.delete_request(coll_name, expr, None, 0) + + @pytest.mark.parametrize("part_name", []) + def test_delete_request_wrong_part_name(self, part_name): + with pytest.raises(MilvusException): + Prepare.delete_request("coll", "id>1", part_name, 0) + + def test_search_requests_with_expr_offset(self): fields = [ FieldSchema("pk", DataType.INT64, is_primary=True), @@ -42,7 +54,7 @@ def test_search_requests_with_expr_offset(self): params = json.loads(p.value) if PAGE_RETAIN_ORDER_FIELD in params: page_retain_order_exists = True - assert params[PAGE_RETAIN_ORDER_FIELD] == True + assert params[PAGE_RETAIN_ORDER_FIELD] is True assert offset_exists is True assert page_retain_order_exists is True @@ -112,7 +124,7 @@ def test_get_schema_from_collection_schema(self): c_schema = Prepare.get_schema_from_collection_schema("random", schema) - assert c_schema.enable_dynamic_field == False + assert c_schema.enable_dynamic_field is False assert c_schema.name == "random" assert len(c_schema.fields) == 2 assert c_schema.fields[0].name == "field_vector"