Skip to content

Commit

Permalink
enhance: Enable set_properties and describe_database api (#2082)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliu1031 authored May 13, 2024
1 parent 7e8187a commit b362af0
Show file tree
Hide file tree
Showing 13 changed files with 490 additions and 313 deletions.
8 changes: 8 additions & 0 deletions examples/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ def collection_read_write(collection, db_name):
# list collections within db1
print("\nlist collections of database db1:")
print(utility.list_collections())

# set properties of db1
print("\nset properties of db1:")
db_info = db.describe_database(db_name="db1")
print(db_info)
db.set_properties(db_name="db1", properties={"key": "value"})
db_info = db.describe_database(db_name="db1")
print(db_info)

print("\ndrop collection: col1_db2 from db1")
col2_db1.drop()
Expand Down
16 changes: 16 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
BulkInsertState,
CompactionPlans,
CompactionState,
DatabaseInfo,
DataType,
ExtraList,
GrantInfo,
Expand Down Expand Up @@ -1290,6 +1291,21 @@ def list_database(self, timeout: Optional[float] = None):
check_status(response.status)
return list(response.db_names)

@retry_on_rpc_failure()
def alter_database(
self, db_name: str, properties: dict, timeout: Optional[float] = None, **kwargs
):
request = Prepare.alter_database_req(db_name, properties)
status = self._stub.AlterDatabase(request, timeout=timeout)
check_status(status)

@retry_on_rpc_failure()
def describe_database(self, db_name: str, timeout: Optional[float] = None):
request = Prepare.describe_database_req(db_name=db_name)
resp = self._stub.DescribeDatabase(request, timeout=timeout)
check_status(resp.status)
return DatabaseInfo(resp)

@retry_on_rpc_failure()
def get_load_state(
self,
Expand Down
11 changes: 11 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,3 +1238,14 @@ def drop_database_req(cls, db_name: str):
@classmethod
def list_database_req(cls):
return milvus_types.ListDatabasesRequest()

@classmethod
def alter_database_req(cls, db_name: str, properties: Dict):
check_pass_param(db_name=db_name)
kvs = [common_types.KeyValuePair(key=k, value=str(v)) for k, v in properties.items()]
return milvus_types.AlterDatabaseRequest(db_name=db_name, properties=kvs)

@classmethod
def describe_database_req(cls, db_name: str):
check_pass_param(db_name=db_name)
return milvus_types.DescribeDatabaseRequest(db_name=db_name)
29 changes: 29 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,32 @@ def get_cost_extra(status: Optional[common_pb2.Status] = None):
# Construct extra dict, the cost unit is the vcu, similar to tokenlike the
def construct_cost_extra(cost: int):
return {"cost": cost}


class DatabaseInfo:
"""
Represents the information of a database.
Atributes:
name (str): The name of the database.
properties (dict): The properties of the database.
Example:
DatabaseInfo(name="test_db", id=1, properties={"key": "value"})
"""

@property
def name(self) -> str:
return self._name

@property
def properties(self) -> Dict:
return self._properties

def __init__(self, info: Any) -> None:
self._name = info.db_name
self._properties = {}

for p in info.properties:
self.properties[p.key] = p.value

def __str__(self) -> str:
return f"DatabaseInfo(name={self.name}, properties={self.properties})"
36 changes: 18 additions & 18 deletions pymilvus/grpc_gen/common_pb2.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions pymilvus/grpc_gen/common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class MsgType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
CreateDatabase: _ClassVar[MsgType]
DropDatabase: _ClassVar[MsgType]
ListDatabases: _ClassVar[MsgType]
AlterDatabase: _ClassVar[MsgType]
DescribeDatabase: _ClassVar[MsgType]

class DslType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
Expand Down Expand Up @@ -292,6 +294,7 @@ class ObjectPrivilege(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
PrivilegeListAliases: _ClassVar[ObjectPrivilege]
PrivilegeUpdateResourceGroups: _ClassVar[ObjectPrivilege]
PrivilegeAlterDatabase: _ClassVar[ObjectPrivilege]
PrivilegeDescribeDatabase: _ClassVar[ObjectPrivilege]

class StateCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
Expand Down Expand Up @@ -490,6 +493,8 @@ UpdateResourceGroups: MsgType
CreateDatabase: MsgType
DropDatabase: MsgType
ListDatabases: MsgType
AlterDatabase: MsgType
DescribeDatabase: MsgType
Dsl: DslType
BoolExprV1: DslType
UndefiedState: CompactionState
Expand Down Expand Up @@ -560,6 +565,7 @@ PrivilegeDescribeAlias: ObjectPrivilege
PrivilegeListAliases: ObjectPrivilege
PrivilegeUpdateResourceGroups: ObjectPrivilege
PrivilegeAlterDatabase: ObjectPrivilege
PrivilegeDescribeDatabase: ObjectPrivilege
Initializing: StateCode
Healthy: StateCode
Abnormal: StateCode
Expand Down
492 changes: 250 additions & 242 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

35 changes: 33 additions & 2 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ class FlushRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_names: _Optional[_Iterable[str]] = ...) -> None: ...

class FlushResponse(_message.Message):
__slots__ = ("status", "db_name", "coll_segIDs", "flush_coll_segIDs", "coll_seal_times", "coll_flush_ts")
__slots__ = ("status", "db_name", "coll_segIDs", "flush_coll_segIDs", "coll_seal_times", "coll_flush_ts", "channel_cps")
class CollSegIDsEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
Expand Down Expand Up @@ -860,19 +860,28 @@ class FlushResponse(_message.Message):
key: str
value: int
def __init__(self, key: _Optional[str] = ..., value: _Optional[int] = ...) -> None: ...
class ChannelCpsEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: _msg_pb2.MsgPosition
def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_msg_pb2.MsgPosition, _Mapping]] = ...) -> None: ...
STATUS_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLL_SEGIDS_FIELD_NUMBER: _ClassVar[int]
FLUSH_COLL_SEGIDS_FIELD_NUMBER: _ClassVar[int]
COLL_SEAL_TIMES_FIELD_NUMBER: _ClassVar[int]
COLL_FLUSH_TS_FIELD_NUMBER: _ClassVar[int]
CHANNEL_CPS_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
db_name: str
coll_segIDs: _containers.MessageMap[str, _schema_pb2.LongArray]
flush_coll_segIDs: _containers.MessageMap[str, _schema_pb2.LongArray]
coll_seal_times: _containers.ScalarMap[str, int]
coll_flush_ts: _containers.ScalarMap[str, int]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., db_name: _Optional[str] = ..., coll_segIDs: _Optional[_Mapping[str, _schema_pb2.LongArray]] = ..., flush_coll_segIDs: _Optional[_Mapping[str, _schema_pb2.LongArray]] = ..., coll_seal_times: _Optional[_Mapping[str, int]] = ..., coll_flush_ts: _Optional[_Mapping[str, int]] = ...) -> None: ...
channel_cps: _containers.MessageMap[str, _msg_pb2.MsgPosition]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., db_name: _Optional[str] = ..., coll_segIDs: _Optional[_Mapping[str, _schema_pb2.LongArray]] = ..., flush_coll_segIDs: _Optional[_Mapping[str, _schema_pb2.LongArray]] = ..., coll_seal_times: _Optional[_Mapping[str, int]] = ..., coll_flush_ts: _Optional[_Mapping[str, int]] = ..., channel_cps: _Optional[_Mapping[str, _msg_pb2.MsgPosition]] = ...) -> None: ...

class QueryRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "expr", "output_fields", "partition_names", "travel_timestamp", "guarantee_timestamp", "query_params", "not_return_all_meta", "consistency_level", "use_default_consistency")
Expand Down Expand Up @@ -1897,6 +1906,28 @@ class AlterDatabaseRequest(_message.Message):
properties: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., db_id: _Optional[str] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class DescribeDatabaseRequest(_message.Message):
__slots__ = ("base", "db_name")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ...) -> None: ...

class DescribeDatabaseResponse(_message.Message):
__slots__ = ("status", "db_name", "dbID", "created_timestamp", "properties")
STATUS_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
DBID_FIELD_NUMBER: _ClassVar[int]
CREATED_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
PROPERTIES_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
db_name: str
dbID: int
created_timestamp: int
properties: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., db_name: _Optional[str] = ..., dbID: _Optional[int] = ..., created_timestamp: _Optional[int] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class ReplicateMessageRequest(_message.Message):
__slots__ = ("base", "channel_name", "BeginTs", "EndTs", "Msgs", "StartPositions", "EndPositions")
BASE_FIELD_NUMBER: _ClassVar[int]
Expand Down
33 changes: 33 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,11 @@ def __init__(self, channel):
request_serializer=milvus__pb2.AlterDatabaseRequest.SerializeToString,
response_deserializer=common__pb2.Status.FromString,
)
self.DescribeDatabase = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/DescribeDatabase',
request_serializer=milvus__pb2.DescribeDatabaseRequest.SerializeToString,
response_deserializer=milvus__pb2.DescribeDatabaseResponse.FromString,
)
self.ReplicateMessage = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/ReplicateMessage',
request_serializer=milvus__pb2.ReplicateMessageRequest.SerializeToString,
Expand Down Expand Up @@ -957,6 +962,12 @@ def AlterDatabase(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def DescribeDatabase(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ReplicateMessage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -1386,6 +1397,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
request_deserializer=milvus__pb2.AlterDatabaseRequest.FromString,
response_serializer=common__pb2.Status.SerializeToString,
),
'DescribeDatabase': grpc.unary_unary_rpc_method_handler(
servicer.DescribeDatabase,
request_deserializer=milvus__pb2.DescribeDatabaseRequest.FromString,
response_serializer=milvus__pb2.DescribeDatabaseResponse.SerializeToString,
),
'ReplicateMessage': grpc.unary_unary_rpc_method_handler(
servicer.ReplicateMessage,
request_deserializer=milvus__pb2.ReplicateMessageRequest.FromString,
Expand Down Expand Up @@ -2829,6 +2845,23 @@ def AlterDatabase(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def DescribeDatabase(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/DescribeDatabase',
milvus__pb2.DescribeDatabaseRequest.SerializeToString,
milvus__pb2.DescribeDatabaseResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def ReplicateMessage(request,
target,
Expand Down
Loading

0 comments on commit b362af0

Please sign in to comment.