Skip to content

Commit

Permalink
Add TypeChecker (#1315)
Browse files Browse the repository at this point in the history
Signed-off-by: yangxuan <[email protected]>
  • Loading branch information
XuanYang-cn authored Feb 27, 2023
1 parent 49c79bd commit 51fb7c1
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 26 deletions.
4 changes: 2 additions & 2 deletions pymilvus/aio/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from typing import List
import grpc

from ..grpc_gen import milvus_pb2_grpc
from ..abstract_grpc_handler import SecureMixin
from ..mix_in import SecureMixin
from ..settings import DefaultConfig
from ..grpc_gen import milvus_pb2_grpc


class AsyncGrpcHandler(milvus_pb2_grpc.MilvusServiceStub, SecureMixin):
Expand Down
51 changes: 51 additions & 0 deletions pymilvus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,57 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from typing import List, Any
from .exceptions import ParamError


# TODO
def generate_address(host, port) -> str:
return f"{host}:{port}"


def str_checker(var: Any) -> bool:
return isinstance(var, str)

def list_str_checker(var: Any) -> bool:
if var and isinstance(var, list):
return all(isinstance(v, str) for v in var)

return False

def int_checker(var: Any) -> bool:
return isinstance(var, int)


CHECKERS = {
"collection_name": str_checker,
"partition_name": str_checker,
"index_name": str_checker,
"field_name": str_checker,
"alias": str_checker,
"name": str_checker,
"round_decimal": int_checker,
"num_replica": int_checker,
"dim": int_checker,
"partition_names": list_str_checker,
"output_fields": list_str_checker,
# TODO
"anns_field": str_checker,
"limit": int_checker,
"topk": int_checker,
}

class TypeChecker:
@classmethod
def check(cls, **kwargs):
for k, var in kwargs.items():
checker = CHECKERS.get(k)
if checker is None:
raise ParamError(message=f"No checker for parameter {k}")

if not checker(var):
raise ParamError(message=f"Invalid type {type(var)} for {k}")

@property
def list_checkers(self) -> List[str]:
return CHECKERS.keys()
9 changes: 6 additions & 3 deletions pymilvus/v2/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..settings import DefaultConfig
from ..grpc_gen import milvus_pb2 as milvus_types
from ..grpc_gen import common_pb2, milvus_pb2_grpc
from ..utils import generate_address
from ..utils import generate_address, TypeChecker


class MilvusClient:
Expand Down Expand Up @@ -61,7 +61,6 @@ def get_server_version(self, **kwargs) -> str:

return resp.version

@NotImplementedError
def create_alias(self, alias: str, collection_name: str, **kwargs) -> None:
""" Create an alias for a collection
Expand All @@ -81,7 +80,11 @@ def create_alias(self, alias: str, collection_name: str, **kwargs) -> None:
>>> client = MilvusClient("localhost", "19530")
>>> client.create_alias("Gandalf", "hello_milvus")
"""
pass
TypeChecker.check(alias=alias, collection_name=collection_name)
req = milvus_types.CreateAliasRequest(alias=alias, collection_name=collection_name)
status = self.handler.CreateAlias(req)
if status.error_code != common_pb2.Success:
raise MilvusException(status.error_code, status.reason)

@NotImplementedError
def alter_alias(self, alias: str, collection_name: str, **kwargs) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/v2/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import List
import grpc

from ..abstract_grpc_handler import SecureMixin
from ..mix_in import SecureMixin
from ..settings import DefaultConfig
from ..grpc_gen import milvus_pb2_grpc, common_pb2

Expand Down
23 changes: 6 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,19 @@ def channel(request):
return channel


@pytest.fixture(scope="module")
@pytest.fixture(scope="class")
def client_thread(request):
client_execution_thread_pool = logging_pool.pool(2)

def teardown():
client_execution_thread_pool.shutdown(wait=True)
return client_execution_thread_pool

@pytest.fixture(scope="module")
def client(request):
channel = grpc_testing.channel([descriptor], grpc_testing.strict_real_time())

client = MilvusClient("fake", "fake", _channel=channel)
return client
request.addfinalizer(teardown)
return client_execution_thread_pool

@pytest.fixture(scope="function")
def rpc_future_GetVersion(client_thread):
def client_channel(request):
channel = grpc_testing.channel([descriptor], grpc_testing.strict_real_time())
client = MilvusClient("fake", "fake", _channel=channel)

get_server_version_future = client_thread.submit(client.get_server_version)
(invocation_metadata, request, rpc) = (
channel.take_unary_unary(descriptor.methods_by_name['GetVersion']))

rpc.send_initial_metadata(())
return rpc, get_server_version_future

client = MilvusClient("fake", "fake", _channel=channel)
return client, channel
87 changes: 87 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
from pymilvus.utils import TypeChecker
from pymilvus.exceptions import ParamError


class TestTypeChecker:
@pytest.mark.parametrize("type_str", [
"s",
"_",
])
def test_str_type(self, type_str):
TypeChecker.check(
collection_name=type_str,
alias=type_str,
name=type_str,
partition_name=type_str,
index_name=type_str,
field_name=type_str,
anns_field=type_str,
)

@pytest.mark.parametrize("not_type_str", [
None,
1,
1.0,
[1],
{1: 2},
{1, 2, 3},
(1, 2, 3),
])
def test_str_type_error(self, not_type_str):
with pytest.raises(ParamError):
TypeChecker.check(collection_name=not_type_str)

@pytest.mark.parametrize("type_int", [
1,
0,
-1,
1024,
65536,
])
def test_int_type(self, type_int):
TypeChecker.check(
round_decimal=type_int,
num_replica=type_int,
dim=type_int,
limit=type_int,
topk=type_int,
)

@pytest.mark.parametrize("not_type_int", [
None,
"abc",
1.0,
[1],
{1: 2},
{1, 2, 3},
(1, 2, 3),
])
def test_int_type_error(self, not_type_int):
with pytest.raises(ParamError):
TypeChecker.check(topk=not_type_int)

@pytest.mark.parametrize("type_list_str", [
["a"],
["1", "2"]
])
def test_list_str_type(self, type_list_str):
TypeChecker.check(
partition_names=type_list_str,
output_fields=type_list_str,
)

@pytest.mark.parametrize("not_type_list_str", [
None,
"abc",
1.0,
{1: 2},
{1, 2, 3},
(1, 2, 3),
[1],
[],
[1, "1"],
])
def test_int_type_error(self, not_type_list_str):
with pytest.raises(ParamError):
TypeChecker.check(topk=not_type_list_str)
43 changes: 40 additions & 3 deletions tests/test_v2_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,31 @@
from pymilvus.exceptions import MilvusException
from pymilvus.grpc_gen import common_pb2, milvus_pb2

descriptor = milvus_pb2.DESCRIPTOR.services_by_name['MilvusService']

def prep_channel(channel, method_name):
(invocation_metadata, request, rpc) = (
channel.take_unary_unary(descriptor.methods_by_name[method_name]))

rpc.send_initial_metadata(())
return rpc


@pytest.mark.usefixtures("client_thread")
class TestGetServerVersion:

@pytest.mark.parametrize("error_code", [
common_pb2.Success,
common_pb2.UnexpectedError,
common_pb2.ConnectFailed,
])
def test_normal(self, rpc_future_GetVersion, error_code):
rpc, future = rpc_future_GetVersion
def test_normal(self, client_channel, client_thread, error_code):
client, channel = client_channel
future = client_thread.submit(client.get_server_version)

rpc = prep_channel(channel, 'GetVersion')

reason = f"error: {error_code}" if error_code != common_pb2.Success else ""
reason = f"mock error: {error_code}" if error_code != common_pb2.Success else ""
version = "test.test.test" if error_code == common_pb2.Success else ""

expected_result = milvus_pb2.GetVersionResponse(
Expand All @@ -30,3 +44,26 @@ def test_normal(self, rpc_future_GetVersion, error_code):
else:
got_result = future.result()
assert got_result == version

class TestCreateAlias:
@pytest.mark.parametrize("error_code", [
common_pb2.Success,
common_pb2.UnexpectedError,
])
def test_normal(self, client_channel, client_thread, error_code):
client, channel = client_channel
future = client_thread.submit(client.create_alias, "alias", "coll")

rpc = prep_channel(channel, 'CreateAlias')

reason = f"mock error: {error_code}" if error_code != common_pb2.Success else ""

expected_result = common_pb2.Status(error_code=error_code, reason=reason)
rpc.terminate(expected_result, (), grpc.StatusCode.OK, '')

if error_code != common_pb2.Success:
with pytest.raises(MilvusException) as excinfo:
future.result()
assert error_code == excinfo.value.code
else:
future.result()

0 comments on commit 51fb7c1

Please sign in to comment.