Skip to content

Commit

Permalink
Add support for elastic-transport sniffing
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Oct 27, 2021
1 parent 620032d commit 1cca684
Show file tree
Hide file tree
Showing 11 changed files with 902 additions and 354 deletions.
3 changes: 1 addition & 2 deletions elasticsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@

from ._async.client import AsyncElasticsearch
from ._sync.client import Elasticsearch
from .exceptions import ElasticsearchDeprecationWarning # noqa: F401
from .exceptions import (
ApiError,
AuthenticationException,
AuthorizationException,
ConflictError,
ConnectionError,
ConnectionTimeout,
ElasticsearchDeprecationWarning,
ElasticsearchException,
ElasticsearchWarning,
NotFoundError,
Expand Down Expand Up @@ -73,5 +73,4 @@
"AuthorizationException",
"UnsupportedProductError",
"ElasticsearchWarning",
"ElasticsearchDeprecationWarning",
]
104 changes: 101 additions & 3 deletions elasticsearch/_async/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@

import logging
import warnings
from typing import Optional
from typing import Any, Callable, Dict, Optional, Union

from elastic_transport import AsyncTransport, TransportError
from elastic_transport import AsyncTransport, NodeConfig, TransportError
from elastic_transport.client_utils import DEFAULT

from ...exceptions import NotFoundError
from ...serializer import DEFAULT_SERIALIZERS
from ._base import BaseClient, resolve_auth_headers
from ._base import (
BaseClient,
create_sniff_callback,
default_sniff_callback,
resolve_auth_headers,
)
from .async_search import AsyncSearchClient
from .autoscaling import AutoscalingClient
from .cat import CatClient
Expand Down Expand Up @@ -148,9 +153,21 @@ def __init__(
sniff_on_node_failure=DEFAULT,
sniff_timeout=DEFAULT,
min_delay_between_sniffing=DEFAULT,
sniffed_node_callback: Optional[
Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]]
] = None,
meta_header=DEFAULT,
# Deprecated
timeout=DEFAULT,
randomize_hosts=DEFAULT,
host_info_callback: Optional[
Callable[
[Dict[str, Any], Dict[str, Union[str, int]]],
Optional[Dict[str, Union[str, int]]],
]
] = None,
sniffer_timeout=DEFAULT,
sniff_on_connection_fail=DEFAULT,
# Internal use only
_transport: Optional[AsyncTransport] = None,
) -> None:
Expand All @@ -170,6 +187,86 @@ def __init__(
)
request_timeout = timeout

if randomize_hosts is not DEFAULT:
if randomize_nodes_in_pool is not DEFAULT:
raise ValueError(
"Can't specify both 'randomize_hosts' and 'randomize_nodes_in_pool', "
"instead only specify 'randomize_nodes_in_pool'"
)
warnings.warn(
"The 'randomize_hosts' parameter is deprecated in favor of 'randomize_nodes_in_pool'",
category=DeprecationWarning,
stacklevel=2,
)
randomize_nodes_in_pool = randomize_hosts

if sniffer_timeout is not DEFAULT:
if min_delay_between_sniffing is not DEFAULT:
raise ValueError(
"Can't specify both 'sniffer_timeout' and 'min_delay_between_sniffing', "
"instead only specify 'min_delay_between_sniffing'"
)
warnings.warn(
"The 'sniffer_timeout' parameter is deprecated in favor of 'min_delay_between_sniffing'",
category=DeprecationWarning,
stacklevel=2,
)
min_delay_between_sniffing = sniffer_timeout

if sniff_on_connection_fail is not DEFAULT:
if sniff_on_node_failure is not DEFAULT:
raise ValueError(
"Can't specify both 'sniff_on_connection_fail' and 'sniff_on_node_failure', "
"instead only specify 'sniff_on_node_failure'"
)
warnings.warn(
"The 'sniff_on_connection_fail' parameter is deprecated in favor of 'sniff_on_node_failure'",
category=DeprecationWarning,
stacklevel=2,
)
sniff_on_node_failure = sniff_on_connection_fail

# Setting min_delay_between_sniffing=True implies sniff_before_requests=True
if min_delay_between_sniffing is not DEFAULT:
sniff_before_requests = True

sniffing_options = (
sniff_timeout,
sniff_on_start,
sniff_before_requests,
sniff_on_node_failure,
sniffed_node_callback,
min_delay_between_sniffing,
sniffed_node_callback,
)
if cloud_id is not None and any(
x is not DEFAULT and x is not None for x in sniffing_options
):
raise ValueError(
"Sniffing should not be enabled when connecting to Elastic Cloud"
)

sniff_callback = None
if host_info_callback is not None:
if sniffed_node_callback is not None:
raise ValueError(
"Can't specify both 'host_info_callback' and 'sniffed_node_callback', "
"instead only specify 'sniffed_node_callback'"
)
sniff_callback = create_sniff_callback(
host_info_callback=host_info_callback
)
elif sniffed_node_callback is not None:
sniff_callback = create_sniff_callback(
sniffed_node_callback=sniffed_node_callback
)
elif (
sniff_on_start is True
or sniff_before_requests is True
or sniff_on_node_failure is True
):
sniff_callback = default_sniff_callback

if _transport is None:
node_configs = client_node_configs(
hosts,
Expand Down Expand Up @@ -222,6 +319,7 @@ def __init__(
_transport = transport_class(
node_configs,
client_meta_service=CLIENT_META_SERVICE,
sniff_callback=sniff_callback,
**transport_kwargs,
)

Expand Down
123 changes: 118 additions & 5 deletions elasticsearch/_async/client/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,31 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Collection, Mapping, Optional, Tuple, TypeVar, Union

from elastic_transport import AsyncTransport, HttpHeaders
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

from elastic_transport import AsyncTransport, HttpHeaders, NodeConfig, SniffOptions
from elastic_transport.client_utils import DEFAULT, DefaultType, resolve_default

from ...compat import urlencode
from ...exceptions import HTTP_EXCEPTIONS, ApiError, UnsupportedProductError
from .utils import _base64_auth_header
from ...exceptions import (
HTTP_EXCEPTIONS,
ApiError,
ConnectionError,
SerializationError,
UnsupportedProductError,
)
from .utils import _TYPE_ASYNC_SNIFF_CALLBACK, _base64_auth_header

SelfType = TypeVar("SelfType", bound="BaseClient")
SelfNamespacedType = TypeVar("SelfNamespacedType", bound="NamespacedClient")
Expand Down Expand Up @@ -74,6 +91,102 @@ def resolve_auth_headers(
return headers


def create_sniff_callback(
host_info_callback: Optional[
Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]]
] = None,
sniffed_node_callback: Optional[
Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]]
] = None,
) -> _TYPE_ASYNC_SNIFF_CALLBACK:
assert (host_info_callback is None) != (sniffed_node_callback is None)

# Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback'
if host_info_callback is not None:

def _sniffed_node_callback(
node_info: Dict[str, Any], node_config: NodeConfig
) -> Optional[NodeConfig]:
assert host_info_callback is not None
if (
host_info_callback( # type ignore[misc]
node_info, {"host": node_config.host, "port": node_config.port}
)
is None
):
return None
return node_config

sniffed_node_callback = _sniffed_node_callback

async def sniff_callback(
transport: AsyncTransport, sniff_options: SniffOptions
) -> List[NodeConfig]:
for _ in transport.node_pool.all():
try:
meta, node_infos = await transport.perform_request(
"GET",
"/_nodes/_all/http",
headers={"accept": "application/json"},
request_timeout=(
sniff_options.sniff_timeout
if not sniff_options.is_initial_sniff
else None
),
)
except (SerializationError, ConnectionError):
continue

if not 200 <= meta.status <= 299:
continue

node_configs = []
for node_info in node_infos.get("nodes", {}).values():
address = node_info.get("http", {}).get("publish_address")
if not address or ":" not in address:
continue

if "/" in address:
# Support 7.x host/ip:port behavior where http.publish_host has been set.
fqdn, ipaddress = address.split("/", 1)
host = fqdn
_, port_str = ipaddress.rsplit(":", 1)
port = int(port_str)
else:
host, port_str = address.rsplit(":", 1)
port = int(port_str)

assert sniffed_node_callback is not None
sniffed_node = sniffed_node_callback(
node_info, meta.node.replace(host=host, port=port)
)
if sniffed_node is None:
continue

# Use the node which was able to make the request as a base.
node_configs.append(sniffed_node)

if node_configs:
return node_configs

return []

return sniff_callback


def _default_sniffed_node_callback(
node_info: Dict[str, Any], node_config: NodeConfig
) -> Optional[NodeConfig]:
if node_info.get("roles", []) == ["master"]:
return None
return node_config


default_sniff_callback = create_sniff_callback(
sniffed_node_callback=_default_sniffed_node_callback
)


class BaseClient:
def __init__(self, _transport: AsyncTransport) -> None:
self._transport = _transport
Expand Down
2 changes: 2 additions & 0 deletions elasticsearch/_async/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from ..._sync.client.utils import (
_TYPE_ASYNC_SNIFF_CALLBACK,
_TYPE_HOSTS,
CLIENT_META_SERVICE,
SKIP_IN_PATH,
Expand All @@ -30,6 +31,7 @@

__all__ = [
"CLIENT_META_SERVICE",
"_TYPE_ASYNC_SNIFF_CALLBACK",
"_deprecated_options",
"_TYPE_HOSTS",
"SKIP_IN_PATH",
Expand Down
Loading

0 comments on commit 1cca684

Please sign in to comment.