diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 6d2e1c7d2..4ae6e54d7 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -220,7 +220,6 @@ class ExceptionsMessage: "Ambiguous parameter, either ids or filter should be specified, cannot support both." ) JSONKeyMustBeStr = "JSON key must be str." - ClusteringKeyNotPrimary = "Clustering key field should not be primary field" ClusteringKeyType = ( "Clustering key field type must be DataType.INT8, DataType.INT16, " "DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE, " diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 89e08dd27..3729a40bd 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -62,12 +62,8 @@ def validate_partition_key( ) -def validate_clustering_key( - clustering_key_field_name: Any, clustering_key_field: Any, primary_field_name: Any -): +def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field: Any): if clustering_key_field is not None: - if clustering_key_field.name == primary_field_name: - raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyNotPrimary) if clustering_key_field.dtype not in [ DataType.INT8, DataType.INT16, @@ -81,7 +77,7 @@ def validate_clustering_key( raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyType) elif clustering_key_field_name is not None: raise ClusteringKeyException( - message=ExceptionsMessage.PartitionKeyFieldNotExist % clustering_key_field_name + message=ExceptionsMessage.ClusteringKeyFieldNotExist % clustering_key_field_name ) @@ -170,9 +166,7 @@ def _check_fields(self): validate_partition_key( partition_key_field_name, self._partition_key_field, self._primary_field.name ) - validate_clustering_key( - clustering_key_field_name, self._clustering_key_field, self._primary_field.name - ) + validate_clustering_key(clustering_key_field_name, self._clustering_key_field) auto_id = self._kwargs.get("auto_id", False) if auto_id: diff --git a/tests/test_create_collection.py b/tests/test_create_collection.py index 11fc1d502..1b6b23425 100644 --- a/tests/test_create_collection.py +++ b/tests/test_create_collection.py @@ -202,3 +202,45 @@ def test_create_bf16_collection(self, collection_name): return_value = future.result() assert return_value.code == 0 assert return_value.reason == "success" + + def test_create_clustering_key_collection(self, collection_name): + id_field = { + "name": "my_id", + "type": DataType.INT64, + "auto_id": True, + "is_primary": True, + "is_clustering_key": True, + } + vector_field = { + "name": "embedding", + "type": DataType.FLOAT_VECTOR, + "metric_type": "L2", + "params": {"dim": "4"}, + } + fields = {"fields": [id_field, vector_field]} + future = self._milvus.create_collection( + collection_name=collection_name, fields=fields, _async=True + ) + + invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary( + self._servicer.methods_by_name["CreateCollection"] + ) + rpc.send_initial_metadata(()) + rpc.terminate( + common_pb2.Status( + code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success" + ), + (), + grpc.StatusCode.OK, + "", + ) + + request_schema = schema_pb2.CollectionSchema() + request_schema.ParseFromString(request.schema) + + assert request.collection_name == collection_name + assert Fields.equal(request_schema.fields, fields["fields"]) + + return_value = future.result() + assert return_value.code == 0 + assert return_value.reason == "success"