Skip to content

Commit

Permalink
Make parse_grpc_uri() compatible with grpcio
Browse files Browse the repository at this point in the history
Now `parse_grpc_uri()` takes the channel type as an extra argument and
will create the appropriate type of channel based on this.

Signed-off-by: Leandro Lucarella <[email protected]>
  • Loading branch information
llucax committed May 31, 2024
1 parent 3c7f8ff commit 08bba2a
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 63 deletions.
3 changes: 2 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

## Upgrading

<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
- `channel.parse_grpc_uri()` takes an extra argument, the channel type (which can be either `grpclib.client.Channel` or `grpcio.aio.Channel`).

## New Features

- Add a `exception` module to provide client exceptions, including gRPC errors with one subclass per gRPC error status code.
- `channel.parse_grpc_uri()` can now be used with `grpcio` too.

## Bug Fixes

Expand Down
33 changes: 26 additions & 7 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

"""Handling of gRPC channels."""

from typing import TypeVar
from urllib.parse import parse_qs, urlparse

from grpclib.client import Channel
from . import _grpchacks


def _to_bool(value: str) -> bool:
Expand All @@ -17,7 +18,13 @@ def _to_bool(value: str) -> bool:
raise ValueError(f"Invalid boolean value '{value}'")


def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel:
ChannelT = TypeVar("ChannelT", _grpchacks.GrpclibChannel, _grpchacks.GrpcioChannel)
"""A `grpclib` or `grpcio` channel type."""


def parse_grpc_uri(
uri: str, channel_type: type[ChannelT], /, *, default_port: int = 9090
) -> ChannelT:
"""Create a grpclib client channel from a URI.
The URI must have the following format:
Expand All @@ -38,6 +45,7 @@ def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel:
Args:
uri: The gRPC URI specifying the connection parameters.
channel_type: The type of channel to create.
default_port: The default port number to use if the URI does not specify one.
Returns:
Expand Down Expand Up @@ -68,8 +76,19 @@ def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel:
uri,
)

return Channel(
host=parsed_uri.hostname,
port=parsed_uri.port or default_port,
ssl=ssl,
)
host = parsed_uri.hostname
port = parsed_uri.port or default_port
match channel_type:
case _grpchacks.GrpcioChannel:
target = f"{host}:{port}"
return (
_grpchacks.grpcio_secure_channel(
target, _grpchacks.grpcio_ssl_channel_credentials()
)
if ssl
else _grpchacks.grpcio_insecure_channel(target)
)
case _grpchacks.GrpclibChannel:
return _grpchacks.GrpclibChannel(host=host, port=port, ssl=ssl)
case _:
assert False, "Unexpected channel type: {channel_type}"
159 changes: 104 additions & 55 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,83 +3,132 @@

"""Test cases for the channel module."""

import unittest.mock
from dataclasses import dataclass
from unittest import mock

import pytest

from frequenz.client.base import _grpchacks
from frequenz.client.base.channel import parse_grpc_uri

VALID_URLS = [
("grpc://localhost", "localhost", 9090, False),
("grpc://localhost:1234", "localhost", 1234, False),
("grpc://localhost:1234?ssl=true", "localhost", 1234, True),
("grpc://localhost:1234?ssl=false", "localhost", 1234, False),
("grpc://localhost:1234?ssl=1", "localhost", 1234, True),
("grpc://localhost:1234?ssl=0", "localhost", 1234, False),
("grpc://localhost:1234?ssl=on", "localhost", 1234, True),
("grpc://localhost:1234?ssl=off", "localhost", 1234, False),
("grpc://localhost:1234?ssl=TRUE", "localhost", 1234, True),
("grpc://localhost:1234?ssl=FALSE", "localhost", 1234, False),
("grpc://localhost:1234?ssl=ON", "localhost", 1234, True),
("grpc://localhost:1234?ssl=OFF", "localhost", 1234, False),
("grpc://localhost:1234?ssl=0&ssl=1", "localhost", 1234, True),
("grpc://localhost:1234?ssl=1&ssl=0", "localhost", 1234, False),
]

@dataclass(frozen=True)
class _FakeChannel:
host: str
port: int
ssl: bool


@pytest.mark.parametrize(
"uri, host, port, ssl",
[
("grpc://localhost", "localhost", 9090, False),
("grpc://localhost:1234", "localhost", 1234, False),
("grpc://localhost:1234?ssl=true", "localhost", 1234, True),
("grpc://localhost:1234?ssl=false", "localhost", 1234, False),
("grpc://localhost:1234?ssl=1", "localhost", 1234, True),
("grpc://localhost:1234?ssl=0", "localhost", 1234, False),
("grpc://localhost:1234?ssl=on", "localhost", 1234, True),
("grpc://localhost:1234?ssl=off", "localhost", 1234, False),
("grpc://localhost:1234?ssl=TRUE", "localhost", 1234, True),
("grpc://localhost:1234?ssl=FALSE", "localhost", 1234, False),
("grpc://localhost:1234?ssl=ON", "localhost", 1234, True),
("grpc://localhost:1234?ssl=OFF", "localhost", 1234, False),
("grpc://localhost:1234?ssl=0&ssl=1", "localhost", 1234, True),
("grpc://localhost:1234?ssl=1&ssl=0", "localhost", 1234, False),
],
)
def test_parse_uri_ok(
@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpclib_parse_uri_ok(
uri: str,
host: str,
port: int,
ssl: bool,
) -> None:
"""Test successful parsing of gRPC URIs."""
with unittest.mock.patch(
"frequenz.client.base.channel.Channel",
"""Test successful parsing of gRPC URIs using grpclib."""

@dataclass(frozen=True)
class _FakeChannel:
host: str
port: int
ssl: bool

with mock.patch(
"frequenz.client.base.channel._grpchacks.GrpclibChannel",
return_value=_FakeChannel(host, port, ssl),
):
channel = parse_grpc_uri(uri)
channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel)

assert isinstance(channel, _FakeChannel)
assert channel.host == host
assert channel.port == port
assert channel.ssl == ssl


@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpcio_parse_uri_ok(
uri: str,
host: str,
port: int,
ssl: bool,
) -> None:
"""Test successful parsing of gRPC URIs using grpcio."""
expected_channel = mock.MagicMock(
name="mock_channel", spec=_grpchacks.GrpcioChannel
)
expected_credentials = mock.MagicMock(
name="mock_credentials", spec=_grpchacks.GrpcioChannel
)

with (
mock.patch(
"frequenz.client.base.channel._grpchacks.grpcio_insecure_channel",
return_value=expected_channel,
) as insecure_channel_mock,
mock.patch(
"frequenz.client.base.channel._grpchacks.grpcio_secure_channel",
return_value=expected_channel,
) as secure_channel_mock,
mock.patch(
"frequenz.client.base.channel._grpchacks.grpcio_ssl_channel_credentials",
return_value=expected_credentials,
) as ssl_channel_credentials_mock,
):
channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel)

assert channel == expected_channel
expected_target = f"{host}:{port}"
if ssl:
ssl_channel_credentials_mock.assert_called_once_with()
secure_channel_mock.assert_called_once_with(
expected_target, expected_credentials
)
else:
insecure_channel_mock.assert_called_once_with(expected_target)


INVALID_URLS = [
("http://localhost", "Invalid scheme 'http' in the URI, expected 'grpc'"),
("grpc://", "Host name is missing in URI 'grpc://'"),
("grpc://localhost:1234?ssl=invalid", "Invalid boolean value 'invalid'"),
("grpc://localhost:1234?ssl=1&ssl=invalid", "Invalid boolean value 'invalid'"),
("grpc://:1234", "Host name is missing"),
("grpc://host:1234;param", "Port could not be cast to integer value"),
("grpc://host:1234/path", "Unexpected path '/path'"),
("grpc://host:1234#frag", "Unexpected fragment 'frag'"),
("grpc://user@host:1234", "Unexpected username 'user'"),
("grpc://:pass@host:1234?user:pass", "Unexpected password 'pass'"),
(
"grpc://localhost?ssl=1&ssl=1&ssl=invalid",
"Invalid boolean value 'invalid'",
),
(
"grpc://localhost:1234?ssl=1&ffl=true",
"Unexpected query parameters {'ffl': 'true'}",
),
]


@pytest.mark.parametrize("uri, error_msg", INVALID_URLS)
@pytest.mark.parametrize(
"uri, error_msg",
[
("http://localhost", "Invalid scheme 'http' in the URI, expected 'grpc'"),
("grpc://", "Host name is missing in URI 'grpc://'"),
("grpc://localhost:1234?ssl=invalid", "Invalid boolean value 'invalid'"),
("grpc://localhost:1234?ssl=1&ssl=invalid", "Invalid boolean value 'invalid'"),
("grpc://:1234", "Host name is missing"),
("grpc://host:1234;param", "Port could not be cast to integer value"),
("grpc://host:1234/path", "Unexpected path '/path'"),
("grpc://host:1234#frag", "Unexpected fragment 'frag'"),
("grpc://user@host:1234", "Unexpected username 'user'"),
("grpc://:pass@host:1234?user:pass", "Unexpected password 'pass'"),
(
"grpc://localhost?ssl=1&ssl=1&ssl=invalid",
"Invalid boolean value 'invalid'",
),
(
"grpc://localhost:1234?ssl=1&ffl=true",
"Unexpected query parameters {'ffl': 'true'}",
),
],
"channel_type", [_grpchacks.GrpclibChannel, _grpchacks.GrpcioChannel], ids=str
)
def test_parse_uri_error(uri: str, error_msg: str) -> None:
"""Test parsing of invalid gRPC URIs."""
def test_grpclib_parse_uri_error(
uri: str,
error_msg: str,
channel_type: type,
) -> None:
"""Test parsing of invalid gRPC URIs for grpclib."""
with pytest.raises(ValueError, match=error_msg):
parse_grpc_uri(uri)
parse_grpc_uri(uri, channel_type)

0 comments on commit 08bba2a

Please sign in to comment.