From 9fe4639611d870bd3d21a04eadba0a579f415e1c Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Fri, 8 May 2020 17:03:19 -0500 Subject: [PATCH] Add pytest for async REST API tests --- elasticsearch/_async/client/__init__.py | 3 + elasticsearch/_async/client/cat.py | 6 +- elasticsearch/_async/client/cluster.py | 2 +- elasticsearch/_async/client/indices.py | 14 +- elasticsearch/client/__init__.py | 3 + elasticsearch/client/cat.py | 6 +- elasticsearch/client/indices.py | 14 +- .../test_server/test_rest_api_spec.py | 434 ++++++++++++++++++ 8 files changed, 465 insertions(+), 17 deletions(-) create mode 100644 test_elasticsearch/test_async/test_server/test_rest_api_spec.py diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 3690d29e67..52482690d5 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -1297,6 +1297,7 @@ async def rank_eval(self, body, index=None, params=None, headers=None): @query_params( "max_docs", + "prefer_v2_templates", "refresh", "requests_per_second", "scroll", @@ -1316,6 +1317,8 @@ async def reindex(self, body, params=None, headers=None): prototype for the index request. :arg max_docs: Maximum number of documents to process (default: all documents) + :arg prefer_v2_templates: favor V2 templates instead of V1 + templates during index creation :arg refresh: Should the affected indexes be refreshed? :arg requests_per_second: The throttle to set on this request in sub-requests per second. -1 means no throttle. diff --git a/elasticsearch/_async/client/cat.py b/elasticsearch/_async/client/cat.py index f5839035d2..87bc7c8fe2 100644 --- a/elasticsearch/_async/client/cat.py +++ b/elasticsearch/_async/client/cat.py @@ -326,7 +326,7 @@ async def pending_tasks(self, params=None, headers=None): "GET", "/_cat/pending_tasks", params=params, headers=headers ) - @query_params("format", "h", "help", "local", "master_timeout", "s", "size", "v") + @query_params("format", "h", "help", "local", "master_timeout", "s", "time", "v") async def thread_pool(self, thread_pool_patterns=None, params=None, headers=None): """ Returns cluster-wide thread pool statistics per node. By default the active, @@ -345,8 +345,8 @@ async def thread_pool(self, thread_pool_patterns=None, params=None, headers=None to master node :arg s: Comma-separated list of column names or column aliases to sort by - :arg size: The multiplier in which to display values Valid - choices: , k, m, g, t, p + :arg time: The unit in which to display time values Valid + choices: d, h, m, s, ms, micros, nanos :arg v: Verbose mode. Display column headers """ return await self.transport.perform_request( diff --git a/elasticsearch/_async/client/cluster.py b/elasticsearch/_async/client/cluster.py index 73a551f049..d58c567d13 100644 --- a/elasticsearch/_async/client/cluster.py +++ b/elasticsearch/_async/client/cluster.py @@ -137,7 +137,7 @@ async def stats(self, node_id=None, params=None, headers=None): "GET", "/_cluster/stats" if node_id in SKIP_IN_PATH - else _make_path("_cluster/stats/nodes", node_id), + else _make_path("_cluster", "stats", "nodes", node_id), params=params, headers=headers, ) diff --git a/elasticsearch/_async/client/indices.py b/elasticsearch/_async/client/indices.py index 4c1d0169f6..b37ed877b8 100644 --- a/elasticsearch/_async/client/indices.py +++ b/elasticsearch/_async/client/indices.py @@ -1301,7 +1301,7 @@ async def get_index_template(self, name=None, params=None, headers=None): "GET", _make_path("_index_template", name), params=params, headers=headers ) - @query_params("create", "master_timeout", "order") + @query_params("cause", "create", "master_timeout") async def put_index_template(self, name, body, params=None, headers=None): """ Creates or updates an index template. @@ -1309,12 +1309,11 @@ async def put_index_template(self, name, body, params=None, headers=None): :arg name: The name of the template :arg body: The template definition + :arg cause: User defined reason for creating/updating the index + template :arg create: Whether the index template should only be added if new or can also replace an existing one :arg master_timeout: Specify timeout for connection to master - :arg order: The order for this template when merging multiple - matching ones (higher numbers are merged later, overriding the lower - numbers) """ for param in (name, body): if param in SKIP_IN_PATH: @@ -1349,7 +1348,7 @@ async def exists_index_template(self, name, params=None, headers=None): "HEAD", _make_path("_index_template", name), params=params, headers=headers ) - @query_params("master_timeout") + @query_params("cause", "create", "master_timeout") async def simulate_index_template(self, name, body=None, params=None, headers=None): """ Simulate matching the given index name against the index templates in the @@ -1360,6 +1359,11 @@ async def simulate_index_template(self, name, body=None, params=None, headers=No name) :arg body: New index template definition, which will be included in the simulation, as if it already exists in the system + :arg cause: User defined reason for dry-run creating the new + template for simulation purposes + :arg create: Whether the index template we optionally defined in + the body should only be dry-run added if new or can also replace an + existing one :arg master_timeout: Specify timeout for connection to master """ if name in SKIP_IN_PATH: diff --git a/elasticsearch/client/__init__.py b/elasticsearch/client/__init__.py index 17c0e01bf7..44cd5db953 100644 --- a/elasticsearch/client/__init__.py +++ b/elasticsearch/client/__init__.py @@ -1297,6 +1297,7 @@ def rank_eval(self, body, index=None, params=None, headers=None): @query_params( "max_docs", + "prefer_v2_templates", "refresh", "requests_per_second", "scroll", @@ -1316,6 +1317,8 @@ def reindex(self, body, params=None, headers=None): prototype for the index request. :arg max_docs: Maximum number of documents to process (default: all documents) + :arg prefer_v2_templates: favor V2 templates instead of V1 + templates during index creation :arg refresh: Should the affected indexes be refreshed? :arg requests_per_second: The throttle to set on this request in sub-requests per second. -1 means no throttle. diff --git a/elasticsearch/client/cat.py b/elasticsearch/client/cat.py index 87422c7fc2..84282850e3 100644 --- a/elasticsearch/client/cat.py +++ b/elasticsearch/client/cat.py @@ -326,7 +326,7 @@ def pending_tasks(self, params=None, headers=None): "GET", "/_cat/pending_tasks", params=params, headers=headers ) - @query_params("format", "h", "help", "local", "master_timeout", "s", "size", "v") + @query_params("format", "h", "help", "local", "master_timeout", "s", "time", "v") def thread_pool(self, thread_pool_patterns=None, params=None, headers=None): """ Returns cluster-wide thread pool statistics per node. By default the active, @@ -345,8 +345,8 @@ def thread_pool(self, thread_pool_patterns=None, params=None, headers=None): to master node :arg s: Comma-separated list of column names or column aliases to sort by - :arg size: The multiplier in which to display values Valid - choices: , k, m, g, t, p + :arg time: The unit in which to display time values Valid + choices: d, h, m, s, ms, micros, nanos :arg v: Verbose mode. Display column headers """ return self.transport.perform_request( diff --git a/elasticsearch/client/indices.py b/elasticsearch/client/indices.py index 12c8c4a4a2..1aaf522aaf 100644 --- a/elasticsearch/client/indices.py +++ b/elasticsearch/client/indices.py @@ -1299,7 +1299,7 @@ def get_index_template(self, name=None, params=None, headers=None): "GET", _make_path("_index_template", name), params=params, headers=headers ) - @query_params("create", "master_timeout", "order") + @query_params("cause", "create", "master_timeout") def put_index_template(self, name, body, params=None, headers=None): """ Creates or updates an index template. @@ -1307,12 +1307,11 @@ def put_index_template(self, name, body, params=None, headers=None): :arg name: The name of the template :arg body: The template definition + :arg cause: User defined reason for creating/updating the index + template :arg create: Whether the index template should only be added if new or can also replace an existing one :arg master_timeout: Specify timeout for connection to master - :arg order: The order for this template when merging multiple - matching ones (higher numbers are merged later, overriding the lower - numbers) """ for param in (name, body): if param in SKIP_IN_PATH: @@ -1347,7 +1346,7 @@ def exists_index_template(self, name, params=None, headers=None): "HEAD", _make_path("_index_template", name), params=params, headers=headers ) - @query_params("master_timeout") + @query_params("cause", "create", "master_timeout") def simulate_index_template(self, name, body=None, params=None, headers=None): """ Simulate matching the given index name against the index templates in the @@ -1358,6 +1357,11 @@ def simulate_index_template(self, name, body=None, params=None, headers=None): name) :arg body: New index template definition, which will be included in the simulation, as if it already exists in the system + :arg cause: User defined reason for dry-run creating the new + template for simulation purposes + :arg create: Whether the index template we optionally defined in + the body should only be dry-run added if new or can also replace an + existing one :arg master_timeout: Specify timeout for connection to master """ if name in SKIP_IN_PATH: diff --git a/test_elasticsearch/test_async/test_server/test_rest_api_spec.py b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py new file mode 100644 index 0000000000..8927c5489c --- /dev/null +++ b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py @@ -0,0 +1,434 @@ +# Licensed to Elasticsearch B.V under one or more agreements. +# Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +# See the LICENSE file in the project root for more information + +""" +Dynamically generated set of TestCases based on set of yaml files decribing +some integration tests. These files are shared among all official Elasticsearch +clients. +""" +import pytest +import sys +import re +from os import walk, environ +from os.path import exists, join, dirname, pardir, relpath +import yaml +from shutil import rmtree +import warnings +import inspect + +from elasticsearch import TransportError, RequestError, ElasticsearchDeprecationWarning +from elasticsearch.compat import string_types +from elasticsearch.helpers.test import _get_version + +from ..test_cases import SkipTest + +pytestmark = pytest.mark.asyncio + +# some params had to be changed in python, keep track of them so we can rename +# those in the tests accordingly +PARAMS_RENAMES = {"type": "doc_type", "from": "from_"} + +# mapping from catch values to http status codes +CATCH_CODES = {"missing": 404, "conflict": 409, "unauthorized": 401} + +# test features we have implemented +IMPLEMENTED_FEATURES = { + "gtelte", + "stash_in_path", + "headers", + "catch_unauthorized", + "default_shards", + "warnings", +} + +# broken YAML tests on some releases +SKIP_TESTS = { + "*": { + # Can't figure out the get_alias(expand_wildcards=open) failure. + "TestIndicesGetAlias10Basic", + # Disallowing expensive queries is 7.7+ + "TestSearch320DisallowQueries", + } +} + +# Test is inconsistent due to dictionaries not being ordered. +if sys.version_info < (3, 6): + SKIP_TESTS["*"].add("TestSearchAggregation250MovingFn") + + +XPACK_FEATURES = None +ES_VERSION = None + +YAML_DIR = environ.get( + "TEST_ES_YAML_DIR", + join( + dirname(__file__), + pardir, + pardir, + pardir, + pardir, + "elasticsearch", + "rest-api-spec", + "src", + "main", + "resources", + "rest-api-spec", + "test", + ), +) + + +YAML_TEST_SPECS = [] + +if exists(YAML_DIR): + # find all the test definitions in yaml files ... + for path, _, files in walk(YAML_DIR): + for filename in files: + if not filename.endswith((".yaml", ".yml")): + continue + + filepath = join(path, filename) + with open(filepath) as f: + tests = list(yaml.load_all(f)) + + setup_code = None + teardown_code = None + run_codes = [] + for i, test in enumerate(tests): + for test_name, definition in test.items(): + if test_name == "setup": + setup_code = definition + elif test_name == "teardown": + teardown_code = definition + else: + run_codes.append((i, definition)) + + for i, run_code in run_codes: + src = {"setup": setup_code, "run": run_code, "teardown": teardown_code} + # Pytest already replaces '.' with '_' so we do + # it ourselves so UI and 'SKIP_TESTS' match. + pytest_param_id = ( + "%s[%d]" % (relpath(filepath, YAML_DIR).rpartition(".")[0], i) + ).replace(".", "_") + + if pytest_param_id in SKIP_TESTS: + src["skip"] = True + + YAML_TEST_SPECS.append(pytest.param(src, id=pytest_param_id)) + + +async def await_if_coro(x): + if inspect.iscoroutine(x): + return await x + return x + + +class YamlRunner: + def __init__(self, client): + self.client = client + self.last_response = None + + self._run_code = None + self._setup_code = None + self._teardown_code = None + self._state = {} + + def use_spec(self, test_spec): + self._setup_code = test_spec.pop("setup", None) + self._run_code = test_spec.pop("run", None) + self._teardown_code = test_spec.pop("teardown") + + async def setup(self): + if self._setup_code: + await self.run_code(self._setup_code) + + async def teardown(self): + if self._teardown_code: + await self.run_code(self._teardown_code) + + for repo, definition in ( + await self.client.snapshot.get_repository(repository="_all") + ).items(): + await self.client.snapshot.delete_repository(repository=repo) + if definition["type"] == "fs": + rmtree( + "/tmp/%s" % definition["settings"]["location"], ignore_errors=True + ) + + # stop and remove all ML stuff + if await self._feature_enabled("ml"): + await self.client.ml.stop_datafeed(datafeed_id="*", force=True) + for feed in await self.client.ml.get_datafeeds(datafeed_id="*")[ + "datafeeds" + ]: + await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"]) + + await self.client.ml.close_job(job_id="*", force=True) + for job in await self.client.ml.get_jobs(job_id="*")["jobs"]: + await self.client.ml.delete_job( + job_id=job["job_id"], wait_for_completion=True, force=True + ) + + # stop and remove all Rollup jobs + if await self._feature_enabled("rollup"): + for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]: + await self.client.rollup.stop_job( + id=rollup["config"]["id"], wait_for_completion=True + ) + await self.client.rollup.delete_job(id=rollup["config"]["id"]) + + expand_wildcards = ["open", "closed"] + if (await self.es_version()) >= (7, 7): + expand_wildcards.append("hidden") + + await self.client.indices.delete( + index="*", ignore=404, expand_wildcards=expand_wildcards + ) + await self.client.indices.delete_template(name="*", ignore=404) + + async def es_version(self): + global ES_VERSION + if ES_VERSION is None: + version_string = (await self.client.info())["version"]["number"] + if "." not in version_string: + return () + version = version_string.strip().split(".") + ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version) + return ES_VERSION + + async def run(self): + await self.setup() + try: + await self.run_code(self._run_code) + finally: + await self.teardown() + + async def run_code(self, test): + """ Execute an instruction based on it's type. """ + print(test) + for action in test: + assert len(action) == 1 + action_type, action = list(action.items())[0] + + if hasattr(self, "run_" + action_type): + await await_if_coro(getattr(self, "run_" + action_type)(action)) + else: + raise InvalidActionType(action_type) + + async def run_do(self, action): + api = self.client + headers = action.pop("headers", None) + catch = action.pop("catch", None) + warn = action.pop("warnings", ()) + assert len(action) == 1 + + method, args = list(action.items())[0] + args["headers"] = headers + + # locate api endpoint + for m in method.split("."): + assert hasattr(api, m) + api = getattr(api, m) + + # some parameters had to be renamed to not clash with python builtins, + # compensate + for k in PARAMS_RENAMES: + if k in args: + args[PARAMS_RENAMES[k]] = args.pop(k) + + # resolve vars + for k in args: + args[k] = self._resolve(args[k]) + + warnings.simplefilter("always", category=ElasticsearchDeprecationWarning) + with warnings.catch_warnings(record=True) as caught_warnings: + try: + self.last_response = await api(**args) + except Exception as e: + if not catch: + raise + self.run_catch(catch, e) + else: + if catch: + raise AssertionError( + "Failed to catch %r in %r." % (catch, self.last_response) + ) + + # Filter out warnings raised by other components. + caught_warnings = [ + str(w.message) + for w in caught_warnings + if w.category == ElasticsearchDeprecationWarning + ] + + # Sorting removes the issue with order raised. We only care about + # if all warnings are raised in the single API call. + if sorted(warn) != sorted(caught_warnings): + raise AssertionError( + "Expected warnings not equal to actual warnings: expected=%r actual=%r" + % (warn, caught_warnings) + ) + + def run_catch(self, catch, exception): + if catch == "param": + assert isinstance(exception, TypeError) + return + + assert isinstance(exception, TransportError) + if catch in CATCH_CODES: + assert CATCH_CODES[catch] == exception.status_code + elif catch[0] == "/" and catch[-1] == "/": + assert ( + re.search(catch[1:-1], exception.error + " " + repr(exception.info)), + "%s not in %r" % (catch, exception.info), + ) is not None + self.last_response = exception.info + + async def run_skip(self, skip): + global IMPLEMENTED_FEATURES + + if "features" in skip: + features = skip["features"] + if not isinstance(features, (tuple, list)): + features = [features] + for feature in features: + if feature in IMPLEMENTED_FEATURES: + continue + pytest.skip("feature '%s' is not supported" % feature) + + if "version" in skip: + version, reason = skip["version"], skip["reason"] + if version == "all": + pytest.skip(reason) + min_version, max_version = version.split("-") + min_version = _get_version(min_version) or (0,) + max_version = _get_version(max_version) or (999,) + if min_version <= (await self.es_version()) <= max_version: + pytest.skip(reason) + + def run_gt(self, action): + for key, value in action.items(): + value = self._resolve(value) + assert self._lookup(key) > value + + def run_gte(self, action): + for key, value in action.items(): + value = self._resolve(value) + assert self._lookup(key) >= value + + def run_lt(self, action): + for key, value in action.items(): + value = self._resolve(value) + assert self._lookup(key) < value + + def run_lte(self, action): + for key, value in action.items(): + value = self._resolve(value) + assert self._lookup(key) <= value + + def run_set(self, action): + for key, value in action.items(): + value = self._resolve(value) + self._state[value] = self._lookup(key) + + def run_is_false(self, action): + try: + value = self._lookup(action) + except AssertionError: + pass + else: + assert value in ("", None, False, 0) + + def run_is_true(self, action): + value = self._lookup(action) + assert value not in ("", None, False, 0) + + def run_length(self, action): + for path, expected in action.items(): + value = self._lookup(path) + expected = self._resolve(expected) + assert expected == len(value) + + def run_match(self, action): + for path, expected in action.items(): + value = self._lookup(path) + expected = self._resolve(expected) + + if ( + isinstance(expected, string_types) + and expected.startswith("/") + and expected.endswith("/") + ): + expected = re.compile(expected[1:-1], re.VERBOSE | re.MULTILINE) + assert expected.search(value), "%r does not match %r" % ( + value, + expected, + ) + else: + assert expected == value, "%r does not match %r" % (value, expected) + + def _resolve(self, value): + # resolve variables + if isinstance(value, string_types) and value.startswith("$"): + value = value[1:] + assert value in self._state + value = self._state[value] + if isinstance(value, string_types): + value = value.strip() + elif isinstance(value, dict): + value = dict((k, self._resolve(v)) for (k, v) in value.items()) + elif isinstance(value, list): + value = list(map(self._resolve, value)) + return value + + def _lookup(self, path): + # fetch the possibly nested value from last_response + value = self.last_response + if path == "$body": + return value + path = path.replace(r"\.", "\1") + for step in path.split("."): + if not step: + continue + step = step.replace("\1", ".") + step = self._resolve(step) + if step.isdigit() and step not in value: + step = int(step) + assert isinstance(value, list) + assert len(value) > step + else: + assert step in value + value = value[step] + return value + + async def _feature_enabled(self, name): + global XPACK_FEATURES, IMPLEMENTED_FEATURES + if XPACK_FEATURES is None: + try: + xinfo = await self.client.xpack.info() + XPACK_FEATURES = set( + f for f in xinfo["features"] if xinfo["features"][f]["enabled"] + ) + IMPLEMENTED_FEATURES.add("xpack") + except RequestError: + XPACK_FEATURES = set() + IMPLEMENTED_FEATURES.add("no_xpack") + return name in XPACK_FEATURES + + +@pytest.fixture(scope="function") +def runner(async_client): + return YamlRunner(async_client) + + +@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) +async def test_rest_api_spec(test_spec, runner): + if test_spec.get("skip", False): + pytest.skip("Manually skipped in 'SKIP_TESTS'") + runner.use_spec(test_spec) + await runner.run() + + +class InvalidActionType(Exception): + pass