Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support http2 keep-alive #90

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

* The `ExponentialBackoff` and `LinearBackoff` classes now require keyword arguments for their constructor. This change was made to make the classes easier to use and to avoid confusion with the order of the arguments.

- HTTP2 keep-alive is now enabled by default, with an interval of 60 seconds between pings, and a 20 second timeout for responses from the service. These values are configurable and may be updated based on specific requirements.
daniel-zullo-frequenz marked this conversation as resolved.
Show resolved Hide resolved

## New Features

<!-- Here goes the main new features and examples or instructions on how to use them -->
- Added support for HTTP2 keep-alive.

## Bug Fixes

Expand Down
105 changes: 103 additions & 2 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dataclasses
import pathlib
from datetime import timedelta
from typing import assert_never
from urllib.parse import parse_qs, urlparse

Expand Down Expand Up @@ -41,6 +42,20 @@ class SslOptions:
"""


@dataclasses.dataclass(frozen=True)
class KeepAliveOptions:
"""Options for HTTP2 keep-alive pings."""

enabled: bool = True
"""Whether HTTP2 keep-alive should be enabled."""

interval: timedelta = timedelta(seconds=60)
"""The interval between HTTP2 pings."""

timeout: timedelta = timedelta(seconds=20)
"""The time to wait for a HTTP2 keep-alive response."""


@dataclasses.dataclass(frozen=True)
class ChannelOptions:
"""Options for a gRPC channel."""
Expand All @@ -51,6 +66,9 @@ class ChannelOptions:
ssl: SslOptions = SslOptions()
"""SSL options for the channel."""

keep_alive: KeepAliveOptions = KeepAliveOptions()
"""HTTP2 keep-alive options for the channel."""


def parse_grpc_uri(
uri: str,
Expand Down Expand Up @@ -120,6 +138,40 @@ def parse_grpc_uri(
parsed_uri.netloc if parsed_uri.port else f"{parsed_uri.netloc}:{defaults.port}"
)

keep_alive = (
defaults.keep_alive.enabled
if options.keep_alive is None
else options.keep_alive
)
channel_options = (
[
("grpc.http2.max_pings_without_data", 0),
("grpc.keepalive_permit_without_calls", 1),
llucax marked this conversation as resolved.
Show resolved Hide resolved
(
"grpc.keepalive_time_ms",
(
(
defaults.keep_alive.interval
if options.keep_alive_interval is None
else options.keep_alive_interval
).total_seconds()
* 1000
),
),
(
"grpc.keepalive_timeout_ms",
(
defaults.keep_alive.timeout
if options.keep_alive_timeout is None
else options.keep_alive_timeout
).total_seconds()
* 1000,
),
]
if keep_alive
else None
)

ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
if ssl:
return secure_channel(
Expand All @@ -141,8 +193,9 @@ def parse_grpc_uri(
defaults.ssl.certificate_chain,
),
),
channel_options,
)
return insecure_channel(target)
return insecure_channel(target, channel_options)


def _to_bool(value: str) -> bool:
Expand All @@ -160,6 +213,9 @@ class _QueryParams:
ssl_root_certificates_path: pathlib.Path | None
ssl_private_key_path: pathlib.Path | None
ssl_certificate_chain_path: pathlib.Path | None
keep_alive: bool | None
keep_alive_interval: timedelta | None
keep_alive_timeout: timedelta | None


def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
Expand Down Expand Up @@ -200,6 +256,26 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
)

keep_alive_option = options.pop("keep_alive", None)
keep_alive: bool | None = None
if keep_alive_option is not None:
keep_alive = _to_bool(keep_alive_option)

keep_alive_opts = {
k: options.pop(k, None)
for k in ("keep_alive_interval_s", "keep_alive_timeout_s")
}

if keep_alive is False:
erros = []
for opt_name, opt in keep_alive_opts.items():
if opt is not None:
erros.append(opt_name)
if erros:
raise ValueError(
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled",
)

if options:
names = ", ".join(options)
raise ValueError(
Expand All @@ -209,7 +285,32 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:

return _QueryParams(
ssl=ssl,
**{k: pathlib.Path(v) if v is not None else None for k, v in ssl_opts.items()},
ssl_root_certificates_path=(
pathlib.Path(ssl_opts["ssl_root_certificates_path"])
if ssl_opts["ssl_root_certificates_path"] is not None
else None
),
ssl_private_key_path=(
pathlib.Path(ssl_opts["ssl_private_key_path"])
if ssl_opts["ssl_private_key_path"] is not None
else None
),
ssl_certificate_chain_path=(
pathlib.Path(ssl_opts["ssl_certificate_chain_path"])
if ssl_opts["ssl_certificate_chain_path"] is not None
else None
),
keep_alive=keep_alive,
keep_alive_interval=(
timedelta(seconds=float(keep_alive_opts["keep_alive_interval_s"]))
if keep_alive_opts["keep_alive_interval_s"] is not None
else None
),
keep_alive_timeout=(
timedelta(seconds=float(keep_alive_opts["keep_alive_timeout_s"]))
if keep_alive_opts["keep_alive_timeout_s"] is not None
else None
),
)


Expand Down
98 changes: 96 additions & 2 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dataclasses
import pathlib
from datetime import timedelta
from unittest import mock

import pytest
Expand All @@ -13,6 +14,7 @@

from frequenz.client.base.channel import (
ChannelOptions,
KeepAliveOptions,
SslOptions,
_to_bool,
parse_grpc_uri,
Expand Down Expand Up @@ -136,6 +138,67 @@ class _ValidUrlTestCase:
),
),
),
_ValidUrlTestCase(
title="Keep-alive no defaults",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300"
+ "&keep_alive_timeout_s=60",
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(minutes=5),
timeout=timedelta(minutes=1),
),
),
),
_ValidUrlTestCase(
title="Keep-alive default timeout",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300",
defaults=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(seconds=10),
timeout=timedelta(seconds=2),
),
),
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(seconds=300),
timeout=timedelta(seconds=2),
),
),
),
_ValidUrlTestCase(
title="Keep-alive default interval",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_timeout_s=60",
defaults=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True, interval=timedelta(minutes=30)
),
),
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
timeout=timedelta(minutes=1),
interval=timedelta(minutes=30),
),
),
),
_ValidUrlTestCase(
title="keep-alive disabled",
uri="grpc://localhost:1234?keep_alive=0",
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(enabled=False),
),
),
],
ids=lambda case: case.title,
)
Expand Down Expand Up @@ -198,6 +261,35 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals

assert channel == expected_channel
expected_target = f"{expected_host}:{expected_port}"
expected_keep_alive = (
expected_options.keep_alive if "keep_alive=" in uri else defaults.keep_alive
)
expected_keep_alive_interval = (
expected_keep_alive.interval
if "keep_alive_interval_s=" in uri
else defaults.keep_alive.interval
)
expected_keep_alive_timeout = (
expected_keep_alive.timeout
if "keep_alive_timeout_s=" in uri
else defaults.keep_alive.timeout
)
expected_channel_options = (
[
("grpc.http2.max_pings_without_data", 0),
("grpc.keepalive_permit_without_calls", 1),
(
"grpc.keepalive_time_ms",
(expected_keep_alive_interval.total_seconds() * 1000),
),
(
"grpc.keepalive_timeout_ms",
expected_keep_alive_timeout.total_seconds() * 1000,
),
]
if expected_keep_alive.enabled
else None
)
if expected_ssl:
if isinstance(expected_root_certificates, pathlib.Path):
get_contents_mock.assert_any_call(
Expand All @@ -223,10 +315,12 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
certificate_chain=expected_certificate_chain,
)
secure_channel_mock.assert_called_once_with(
expected_target, expected_credentials
expected_target, expected_credentials, expected_channel_options
)
else:
insecure_channel_mock.assert_called_once_with(expected_target)
insecure_channel_mock.assert_called_once_with(
expected_target, expected_channel_options
)


@pytest.mark.parametrize("value", ["true", "on", "1", "TrUe", "On", "ON", "TRUE"])
Expand Down
Loading