From 01037c24ed3ac8ca55aa36c4dfb8d77a77d76743 Mon Sep 17 00:00:00 2001
From: Seth Michael Larson
Date: Fri, 8 May 2020 17:03:19 -0500
Subject: [PATCH] Add pytest for async REST API tests

---
 elasticsearch/_async/client/__init__.py       |   3 +
 elasticsearch/_async/client/cat.py            |   6 +-
 elasticsearch/_async/client/cluster.py        |   2 +-
 elasticsearch/_async/client/indices.py        |  14 +-
 elasticsearch/client/__init__.py              |   3 +
 elasticsearch/client/cat.py                   |   6 +-
 elasticsearch/client/indices.py               |  14 +-
 setup.py                                      |   1 +
 test_elasticsearch/run_tests.py               |   7 +-
 test_elasticsearch/test_async/conftest.py     |  10 +
 .../test_async/test_connection.py             |  18 +-
 .../test_async/test_connection_pool.py        |   4 +
 test_elasticsearch/test_async/test_helpers.py |  96 ----
 .../test_async/test_server/test_helpers.py    |  10 +-
 .../test_server/test_rest_api_spec.py         | 432 ++++++++++++++++++
 .../test_async/test_transport.py              |   4 +
 16 files changed, 508 insertions(+), 122 deletions(-)
 create mode 100644 test_elasticsearch/test_async/conftest.py
 delete mode 100644 test_elasticsearch/test_async/test_helpers.py
 create mode 100644 test_elasticsearch/test_async/test_server/test_rest_api_spec.py

diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py
index 3690d29e67..52482690d5 100644
--- a/elasticsearch/_async/client/__init__.py
+++ b/elasticsearch/_async/client/__init__.py
@@ -1297,6 +1297,7 @@ async def rank_eval(self, body, index=None, params=None, headers=None):
     @query_params(
         "max_docs",
+        "prefer_v2_templates",
         "refresh",
         "requests_per_second",
         "scroll",
@@ -1316,6 +1317,8 @@ async def reindex(self, body, params=None, headers=None):
             prototype for the index request.
         :arg max_docs: Maximum number of documents to process (default:
             all documents)
+        :arg prefer_v2_templates: favor V2 templates instead of V1
+            templates during index creation
         :arg refresh: Should the affected indexes be refreshed?
         :arg requests_per_second: The throttle to set on this request in
             sub-requests per second. -1 means no throttle.
diff --git a/elasticsearch/_async/client/cat.py b/elasticsearch/_async/client/cat.py
index f5839035d2..87bc7c8fe2 100644
--- a/elasticsearch/_async/client/cat.py
+++ b/elasticsearch/_async/client/cat.py
@@ -326,7 +326,7 @@ async def pending_tasks(self, params=None, headers=None):
             "GET", "/_cat/pending_tasks", params=params, headers=headers
         )
 
-    @query_params("format", "h", "help", "local", "master_timeout", "s", "size", "v")
+    @query_params("format", "h", "help", "local", "master_timeout", "s", "time", "v")
     async def thread_pool(self, thread_pool_patterns=None, params=None, headers=None):
         """
         Returns cluster-wide thread pool statistics per node. By default the active,
@@ -345,8 +345,8 @@ async def thread_pool(self, thread_pool_patterns=None, params=None, headers=None):
             to master node
         :arg s: Comma-separated list of column names or column aliases
            to sort by
-        :arg size: The multiplier in which to display values Valid
-            choices: , k, m, g, t, p
+        :arg time: The unit in which to display time values Valid
+            choices: d, h, m, s, ms, micros, nanos
         :arg v: Verbose mode. Display column headers
         """
         return await self.transport.perform_request(
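
A minimal sketch of the renamed `time` parameter in use; the `AsyncElasticsearch` class name, the local host, and the `asyncio.run` driver are illustrative assumptions, not part of this patch:

    import asyncio
    from elasticsearch import AsyncElasticsearch  # assumed async client export

    async def main():
        client = AsyncElasticsearch(hosts=["http://localhost:9200"])  # assumed host
        # 'time' selects the unit for time columns: d, h, m, s, ms, micros, nanos
        print(await client.cat.thread_pool(time="ms", v=True))
        await client.close()

    asyncio.run(main())
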
diff --git a/elasticsearch/_async/client/cluster.py b/elasticsearch/_async/client/cluster.py
index 73a551f049..d58c567d13 100644
--- a/elasticsearch/_async/client/cluster.py
+++ b/elasticsearch/_async/client/cluster.py
@@ -137,7 +137,7 @@ async def stats(self, node_id=None, params=None, headers=None):
             "GET",
             "/_cluster/stats"
             if node_id in SKIP_IN_PATH
-            else _make_path("_cluster/stats/nodes", node_id),
+            else _make_path("_cluster", "stats", "nodes", node_id),
             params=params,
             headers=headers,
         )
diff --git a/elasticsearch/_async/client/indices.py b/elasticsearch/_async/client/indices.py
index 4c1d0169f6..b37ed877b8 100644
--- a/elasticsearch/_async/client/indices.py
+++ b/elasticsearch/_async/client/indices.py
@@ -1301,7 +1301,7 @@ async def get_index_template(self, name=None, params=None, headers=None):
         "GET", _make_path("_index_template", name), params=params, headers=headers
     )
 
-    @query_params("create", "master_timeout", "order")
+    @query_params("cause", "create", "master_timeout")
     async def put_index_template(self, name, body, params=None, headers=None):
         """
         Creates or updates an index template.
@@ -1309,12 +1309,11 @@ async def put_index_template(self, name, body, params=None, headers=None):
         :arg name: The name of the template
         :arg body: The template definition
+        :arg cause: User defined reason for creating/updating the index
+            template
         :arg create: Whether the index template should only be added if new
             or can also replace an existing one
         :arg master_timeout: Specify timeout for connection to master
-        :arg order: The order for this template when merging multiple
-            matching ones (higher numbers are merged later, overriding the lower
-            numbers)
         """
         for param in (name, body):
             if param in SKIP_IN_PATH:
@@ -1349,7 +1348,7 @@ async def exists_index_template(self, name, params=None, headers=None):
         "HEAD", _make_path("_index_template", name), params=params, headers=headers
     )
 
-    @query_params("master_timeout")
+    @query_params("cause", "create", "master_timeout")
     async def simulate_index_template(self, name, body=None, params=None, headers=None):
         """
         Simulate matching the given index name against the index templates in the
@@ -1360,6 +1359,11 @@ async def simulate_index_template(self, name, body=None, params=None, headers=No
             name)
         :arg body: New index template definition, which will be included in the
             simulation, as if it already exists in the system
+        :arg cause: User defined reason for dry-run creating the new
+            template for simulation purposes
+        :arg create: Whether the index template we optionally defined in
+            the body should only be dry-run added if new or can also replace an
+            existing one
         :arg master_timeout: Specify timeout for connection to master
         """
         if name in SKIP_IN_PATH:
diff --git a/elasticsearch/client/__init__.py b/elasticsearch/client/__init__.py
index 17c0e01bf7..44cd5db953 100644
--- a/elasticsearch/client/__init__.py
+++ b/elasticsearch/client/__init__.py
@@ -1297,6 +1297,7 @@ def rank_eval(self, body, index=None, params=None, headers=None):
     @query_params(
         "max_docs",
+        "prefer_v2_templates",
         "refresh",
         "requests_per_second",
         "scroll",
@@ -1316,6 +1317,8 @@ def reindex(self, body, params=None, headers=None):
             prototype for the index request.
         :arg max_docs: Maximum number of documents to process (default:
             all documents)
+        :arg prefer_v2_templates: favor V2 templates instead of V1
+            templates during index creation
         :arg refresh: Should the affected indexes be refreshed?
         :arg requests_per_second: The throttle to set on this request in
             sub-requests per second. -1 means no throttle.
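
A short sketch of the new `prefer_v2_templates` flag on `reindex()`; the index names and a reachable local cluster are assumptions for illustration:

    from elasticsearch import Elasticsearch

    client = Elasticsearch()  # assumed local cluster
    body = {"source": {"index": "old-logs"}, "dest": {"index": "new-logs"}}
    # Ask index auto-creation on the destination to consult V2 (composable)
    # templates rather than legacy V1 templates.
    client.reindex(body=body, prefer_v2_templates=True, refresh=True)
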
diff --git a/elasticsearch/client/cat.py b/elasticsearch/client/cat.py
index 87422c7fc2..84282850e3 100644
--- a/elasticsearch/client/cat.py
+++ b/elasticsearch/client/cat.py
@@ -326,7 +326,7 @@ def pending_tasks(self, params=None, headers=None):
             "GET", "/_cat/pending_tasks", params=params, headers=headers
         )
 
-    @query_params("format", "h", "help", "local", "master_timeout", "s", "size", "v")
+    @query_params("format", "h", "help", "local", "master_timeout", "s", "time", "v")
     def thread_pool(self, thread_pool_patterns=None, params=None, headers=None):
         """
         Returns cluster-wide thread pool statistics per node. By default the active,
@@ -345,8 +345,8 @@ def thread_pool(self, thread_pool_patterns=None, params=None, headers=None):
             to master node
         :arg s: Comma-separated list of column names or column aliases
             to sort by
-        :arg size: The multiplier in which to display values Valid
-            choices: , k, m, g, t, p
+        :arg time: The unit in which to display time values Valid
+            choices: d, h, m, s, ms, micros, nanos
         :arg v: Verbose mode. Display column headers
         """
         return self.transport.perform_request(
diff --git a/elasticsearch/client/indices.py b/elasticsearch/client/indices.py
index 12c8c4a4a2..1aaf522aaf 100644
--- a/elasticsearch/client/indices.py
+++ b/elasticsearch/client/indices.py
@@ -1299,7 +1299,7 @@ def get_index_template(self, name=None, params=None, headers=None):
         "GET", _make_path("_index_template", name), params=params, headers=headers
     )
 
-    @query_params("create", "master_timeout", "order")
+    @query_params("cause", "create", "master_timeout")
     def put_index_template(self, name, body, params=None, headers=None):
         """
         Creates or updates an index template.
@@ -1307,12 +1307,11 @@ def put_index_template(self, name, body, params=None, headers=None):
         :arg name: The name of the template
         :arg body: The template definition
+        :arg cause: User defined reason for creating/updating the index
+            template
         :arg create: Whether the index template should only be added if new
             or can also replace an existing one
         :arg master_timeout: Specify timeout for connection to master
-        :arg order: The order for this template when merging multiple
-            matching ones (higher numbers are merged later, overriding the lower
-            numbers)
         """
         for param in (name, body):
             if param in SKIP_IN_PATH:
@@ -1347,7 +1346,7 @@ def exists_index_template(self, name, params=None, headers=None):
         "HEAD", _make_path("_index_template", name), params=params, headers=headers
     )
 
-    @query_params("master_timeout")
+    @query_params("cause", "create", "master_timeout")
     def simulate_index_template(self, name, body=None, params=None, headers=None):
         """
         Simulate matching the given index name against the index templates in the
@@ -1358,6 +1357,11 @@ def simulate_index_template(self, name, body=None, params=None, headers=None):
             name)
         :arg body: New index template definition, which will be included in the
             simulation, as if it already exists in the system
+        :arg cause: User defined reason for dry-run creating the new
+            template for simulation purposes
+        :arg create: Whether the index template we optionally defined in
+            the body should only be dry-run added if new or can also replace an
+            existing one
         :arg master_timeout: Specify timeout for connection to master
         """
         if name in SKIP_IN_PATH:
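
A sketch of the new composable index template parameters on `put_index_template()`; the template name and body are illustrative only:

    client.indices.put_index_template(
        name="logs-template",  # illustrative name
        body={
            "index_patterns": ["logs-*"],
            "template": {"settings": {"number_of_shards": 1}},
        },
        create=True,             # fail rather than overwrite an existing template
        cause="initial rollout"  # free-form audit string
    )
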
diff --git a/setup.py b/setup.py
index 3bcdf13609..979542def8 100644
--- a/setup.py
+++ b/setup.py
@@ -62,6 +62,7 @@
     ],
     python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4",
     install_requires=install_requires,
+    test_suite="test_elasticsearch.run_tests.run_all",
     tests_require=tests_require,
     extras_require={
         "develop": tests_require + docs_require + generate_require,
diff --git a/test_elasticsearch/run_tests.py b/test_elasticsearch/run_tests.py
index b19bae8151..b133f3bf36 100755
--- a/test_elasticsearch/run_tests.py
+++ b/test_elasticsearch/run_tests.py
@@ -78,9 +78,14 @@ def run_all(argv=None):
         "--log-level=DEBUG",
         "--cache-clear",
         "-vv",
-        abspath(dirname(__file__)),
     ]
 
+    # Skip all async tests unless Python 3.6+
+    if sys.version_info < (3, 6):
+        argv.append("--ignore=test_elasticsearch/test_async/")
+
+    argv.append(abspath(dirname(__file__)))
+
     exit_code = 0
     try:
         subprocess.check_call(argv, stdout=sys.stdout, stderr=sys.stderr)
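
The version gate above matters because `async def` is a syntax error before Python 3.6, so the files must be excluded from collection entirely, not merely skipped at run time. A minimal sketch of the equivalent argv construction (paths as in this repository):

    import sys

    argv = ["python", "-m", "pytest", "--cache-clear", "-vv"]
    # Exclude async test modules from collection on old interpreters.
    if sys.version_info < (3, 6):
        argv.append("--ignore=test_elasticsearch/test_async/")
    argv.append("test_elasticsearch")
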
diff --git a/test_elasticsearch/test_async/conftest.py b/test_elasticsearch/test_async/conftest.py
new file mode 100644
index 0000000000..46e5120b44
--- /dev/null
+++ b/test_elasticsearch/test_async/conftest.py
@@ -0,0 +1,10 @@
+# Licensed to Elasticsearch B.V under one or more agreements.
+# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
+# See the LICENSE file in the project root for more information
+
+import sys
+import pytest
+
+pytestmark = pytest.mark.skipif(
+    sys.version_info < (3, 6), reason="'test_async' is only run on Python 3.6+"
+)
diff --git a/test_elasticsearch/test_async/test_connection.py b/test_elasticsearch/test_async/test_connection.py
index 9df8923a31..9a69468d4f 100644
--- a/test_elasticsearch/test_async/test_connection.py
+++ b/test_elasticsearch/test_async/test_connection.py
@@ -4,14 +4,25 @@
 # See the LICENSE file in the project root for more information
 
 import ssl
+import gzip
+import io
 from mock import Mock, patch
 import warnings
 from platform import python_version
+import aiohttp
+import pytest
 
 from elasticsearch import AIOHttpConnection
 from elasticsearch import __versionstr__
 from ..test_cases import TestCase, SkipTest
 
+pytestmark = pytest.mark.asyncio
+
+
+def gzip_decompress(data):
+    buf = gzip.GzipFile(fileobj=io.BytesIO(data), mode="rb")
+    return buf.read()
+
 
 class TestAIOHttpConnection(TestCase):
     async def _get_mock_connection(self, connection_params={}, response_body=b"{}"):
@@ -232,14 +243,15 @@ def test_uses_https_if_verify_certs_is_off(self):
         self.assertEqual(con.scheme, "https")
         self.assertEqual(con.host, "https://localhost:9200")
 
-    def nowarn_when_test_uses_https_if_verify_certs_is_off(self):
+    async def test_nowarn_when_test_uses_https_if_verify_certs_is_off(self):
         with warnings.catch_warnings(record=True) as w:
-            con = Urllib3HttpConnection(
+            con = AIOHttpConnection(
                 use_ssl=True, verify_certs=False, ssl_show_warn=False
             )
+            con._create_aiohttp_session()
             self.assertEqual(0, len(w))
 
-        self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool)
+        self.assertIsInstance(con.session, aiohttp.ClientSession)
 
     def test_doesnt_use_https_if_not_specified(self):
         con = AIOHttpConnection()
diff --git a/test_elasticsearch/test_async/test_connection_pool.py b/test_elasticsearch/test_async/test_connection_pool.py
index ce77c984ce..7237166872 100644
--- a/test_elasticsearch/test_async/test_connection_pool.py
+++ b/test_elasticsearch/test_async/test_connection_pool.py
@@ -3,6 +3,7 @@
 # See the LICENSE file in the project root for more information
 
 import time
+import pytest
 
 from elasticsearch import (
     AsyncConnectionPool,
@@ -15,6 +16,9 @@
 
 from ..test_cases import TestCase
 
+pytestmark = pytest.mark.asyncio
+
+
 class TestConnectionPool(TestCase):
     def test_dummy_cp_raises_exception_on_more_connections(self):
         self.assertRaises(ImproperlyConfigured, AsyncDummyConnectionPool, [])
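
How the module-level mark used above behaves, assuming the pytest-asyncio plugin is installed: every coroutine test in a module that sets `pytestmark = pytest.mark.asyncio` is collected and run on an event loop. A self-contained sketch:

    import pytest

    pytestmark = pytest.mark.asyncio  # applies to every test in this module

    async def test_addition():
        assert 1 + 1 == 2
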
diff --git a/test_elasticsearch/test_async/test_helpers.py b/test_elasticsearch/test_async/test_helpers.py
deleted file mode 100644
index 020f02724b..0000000000
--- a/test_elasticsearch/test_async/test_helpers.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# -*- coding: utf-8 -*-
-# Licensed to Elasticsearch B.V under one or more agreements.
-# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
-# See the LICENSE file in the project root for more information
-
-import mock
-import time
-import threading
-from nose.plugins.skip import SkipTest
-from elasticsearch import helpers, Elasticsearch
-from elasticsearch.serializer import JSONSerializer
-
-from ..test_cases import TestCase
-
-lock_side_effect = threading.Lock()
-
-
-def mock_process_bulk_chunk(*args, **kwargs):
-    """
-    Threadsafe way of mocking process bulk chunk:
-    https://stackoverflow.com/questions/39332139/thread-safe-version-of-mock-call-count
-    """
-
-    with lock_side_effect:
-        mock_process_bulk_chunk.call_count += 1
-    time.sleep(0.1)
-    return []
-
-
-mock_process_bulk_chunk.call_count = 0
-
-
-class TestParallelBulk(TestCase):
-    @mock.patch(
-        "elasticsearch.helpers.actions._process_bulk_chunk",
-        side_effect=mock_process_bulk_chunk,
-    )
-    def test_all_chunks_sent(self, _process_bulk_chunk):
-        actions = ({"x": i} for i in range(100))
-        list(helpers.parallel_bulk(Elasticsearch(), actions, chunk_size=2))
-
-        self.assertEqual(50, mock_process_bulk_chunk.call_count)
-
-    @SkipTest
-    @mock.patch(
-        "elasticsearch.helpers.actions._process_bulk_chunk",
-        # make sure we spend some time in the thread
-        side_effect=lambda *a: [
-            (True, time.sleep(0.001) or threading.current_thread().ident)
-        ],
-    )
-    def test_chunk_sent_from_different_threads(self, _process_bulk_chunk):
-        actions = ({"x": i} for i in range(100))
-        results = list(
-            helpers.parallel_bulk(
-                Elasticsearch(), actions, thread_count=10, chunk_size=2
-            )
-        )
-        self.assertTrue(len(set([r[1] for r in results])) > 1)
-
-
-class TestChunkActions(TestCase):
-    def setUp(self):
-        super(TestChunkActions, self).setUp()
-        self.actions = [({"index": {}}, {"some": u"datá", "i": i}) for i in range(100)]
-
-    def test_chunks_are_chopped_by_byte_size(self):
-        self.assertEqual(
-            100,
-            len(
-                list(helpers._chunk_actions(self.actions, 100000, 1, JSONSerializer()))
-            ),
-        )
-
-    def test_chunks_are_chopped_by_chunk_size(self):
-        self.assertEqual(
-            10,
-            len(
-                list(
-                    helpers._chunk_actions(self.actions, 10, 99999999, JSONSerializer())
-                )
-            ),
-        )
-
-    def test_chunks_are_chopped_by_byte_size_properly(self):
-        max_byte_size = 170
-        chunks = list(
-            helpers._chunk_actions(
-                self.actions, 100000, max_byte_size, JSONSerializer()
-            )
-        )
-        self.assertEqual(25, len(chunks))
-        for chunk_data, chunk_actions in chunks:
-            chunk = u"".join(chunk_actions)
-            chunk = chunk if isinstance(chunk, str) else chunk.encode("utf-8")
-            self.assertLessEqual(len(chunk), max_byte_size)
diff --git a/test_elasticsearch/test_async/test_server/test_helpers.py b/test_elasticsearch/test_async/test_server/test_helpers.py
index 6fa2361b44..5b3e8aa127 100644
--- a/test_elasticsearch/test_async/test_server/test_helpers.py
+++ b/test_elasticsearch/test_async/test_server/test_helpers.py
@@ -444,7 +444,7 @@ async def clear_scroll(*_, **__):
         ]
         assert data == [{"search_data": 1}, {"scroll_data": 42}]
 
-        client_mock.scroll = MockScroll().scroll
+        client_mock.scroll = Mock()
         with pytest.raises(ScanError):
             data = [
                 doc
@@ -455,7 +455,7 @@ async def clear_scroll(*_, **__):
                 )
             ]
         assert data == [{"search_data": 1}]
-        scroll_mock.assert_not_called()
+        client_mock.scroll.assert_not_called()
 
     async def test_no_scroll_id_fast_route(self):
         client_mock = Mock()
@@ -585,7 +585,7 @@ async def test_reindex_passes_kwargs_to_scan_and_bulk(
         == (await async_client.count(index="prod_index", q="type:answers"))["count"]
     )
 
-    assert {"answer": 42, "correct": True, "type": "answers",} == (
+    assert {"answer": 42, "correct": True, "type": "answers"} == (
         await async_client.get(index="prod_index", id=42)
     )["_source"]
@@ -604,7 +604,7 @@ async def test_reindex_accepts_a_query(self, async_client, reindex_fixture):
         == (await async_client.count(index="prod_index", q="type:answers"))["count"]
     )
 
-    assert {"answer": 42, "correct": True, "type": "answers",} == (
+    assert {"answer": 42, "correct": True, "type": "answers"} == (
         await async_client.get(index="prod_index", id=42)
     )["_source"]
@@ -624,7 +624,7 @@ async def test_all_documents_get_moved(self, async_client, reindex_fixture):
         == (await async_client.count(index="prod_index", q="type:answers"))["count"]
     )
 
-    assert {"answer": 42, "correct": True, "type": "answers",} == (
+    assert {"answer": 42, "correct": True, "type": "answers"} == (
         await async_client.get(index="prod_index", id=42)
     )["_source"]
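
The mock change above follows a common pattern: replacing the method with a bare `Mock` lets the test assert it was never invoked. A generic sketch (names are illustrative):

    from mock import Mock

    client = Mock()
    client.scroll = Mock()
    # ... exercise code that is expected to abort before ever scrolling ...
    client.scroll.assert_not_called()  # raises AssertionError if it was used
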
diff --git a/test_elasticsearch/test_async/test_server/test_rest_api_spec.py b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py
new file mode 100644
index 0000000000..2664415c06
--- /dev/null
+++ b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py
@@ -0,0 +1,432 @@
+# Licensed to Elasticsearch B.V under one or more agreements.
+# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
+# See the LICENSE file in the project root for more information
+
+"""
+Dynamically generated set of TestCases based on set of yaml files describing
+some integration tests. These files are shared among all official Elasticsearch
+clients.
+"""
+import pytest
+import sys
+import re
+from os import walk, environ
+from os.path import exists, join, dirname, pardir, relpath
+import yaml
+from shutil import rmtree
+import warnings
+import inspect
+
+from elasticsearch import TransportError, RequestError, ElasticsearchDeprecationWarning
+from elasticsearch.compat import string_types
+from elasticsearch.helpers.test import _get_version
+
+pytestmark = pytest.mark.asyncio
+
+# some params had to be changed in python, keep track of them so we can rename
+# those in the tests accordingly
+PARAMS_RENAMES = {"type": "doc_type", "from": "from_"}
+
+# mapping from catch values to http status codes
+CATCH_CODES = {"missing": 404, "conflict": 409, "unauthorized": 401}
+
+# test features we have implemented
+IMPLEMENTED_FEATURES = {
+    "gtelte",
+    "stash_in_path",
+    "headers",
+    "catch_unauthorized",
+    "default_shards",
+    "warnings",
+}
+
+# broken YAML tests on some releases
+SKIP_TESTS = {
+    "*": {
+        # Can't figure out the get_alias(expand_wildcards=open) failure.
+        "TestIndicesGetAlias10Basic",
+        # Disallowing expensive queries is 7.7+
+        "TestSearch320DisallowQueries",
+    }
+}
+
+# Test is inconsistent due to dictionaries not being ordered.
+if sys.version_info < (3, 6):
+    SKIP_TESTS["*"].add("TestSearchAggregation250MovingFn")
+
+
+XPACK_FEATURES = None
+ES_VERSION = None
+
+YAML_DIR = environ.get(
+    "TEST_ES_YAML_DIR",
+    join(
+        dirname(__file__),
+        pardir,
+        pardir,
+        pardir,
+        pardir,
+        "elasticsearch",
+        "rest-api-spec",
+        "src",
+        "main",
+        "resources",
+        "rest-api-spec",
+        "test",
+    ),
+)
+
+
+YAML_TEST_SPECS = []
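
For orientation, a tiny spec document of the assumed shape (not taken from the real suite) and how it parses into the structures this runner consumes:

    import yaml

    DOC = """
    "Get a document":
      - do:
          get:
            index: test-index
            id: 1
      - match: {_source.value: 1}
    """
    for test_name, actions in yaml.safe_load(DOC).items():
        print(test_name, actions)  # -> 'Get a document', a list of action dicts
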
+
+if exists(YAML_DIR):
+    # find all the test definitions in yaml files ...
+    for path, _, files in walk(YAML_DIR):
+        for filename in files:
+            if not filename.endswith((".yaml", ".yml")):
+                continue
+
+            filepath = join(path, filename)
+            with open(filepath) as f:
+                tests = list(yaml.load_all(f))
+
+            setup_code = None
+            teardown_code = None
+            run_codes = []
+            for i, test in enumerate(tests):
+                for test_name, definition in test.items():
+                    if test_name == "setup":
+                        setup_code = definition
+                    elif test_name == "teardown":
+                        teardown_code = definition
+                    else:
+                        run_codes.append((i, definition))
+
+            for i, run_code in run_codes:
+                src = {"setup": setup_code, "run": run_code, "teardown": teardown_code}
+                # Pytest replaces '.' with '_' in test IDs, so we do the
+                # same here so the UI and 'SKIP_TESTS' entries match.
+                pytest_param_id = (
+                    "%s[%d]" % (relpath(filepath, YAML_DIR).rpartition(".")[0], i)
+                ).replace(".", "_")
+
+                if pytest_param_id in SKIP_TESTS.get("*", ()):
+                    src["skip"] = True
+
+                YAML_TEST_SPECS.append(pytest.param(src, id=pytest_param_id))
+
+
+async def await_if_coro(x):
+    if inspect.iscoroutine(x):
+        return await x
+    return x
+
+
+class YamlRunner:
+    def __init__(self, client):
+        self.client = client
+        self.last_response = None
+
+        self._run_code = None
+        self._setup_code = None
+        self._teardown_code = None
+        self._state = {}
+
+    def use_spec(self, test_spec):
+        self._setup_code = test_spec.pop("setup", None)
+        self._run_code = test_spec.pop("run", None)
+        self._teardown_code = test_spec.pop("teardown")
+
+    async def setup(self):
+        if self._setup_code:
+            await self.run_code(self._setup_code)
+
+    async def teardown(self):
+        if self._teardown_code:
+            await self.run_code(self._teardown_code)
+
+        for repo, definition in (
+            await self.client.snapshot.get_repository(repository="_all")
+        ).items():
+            await self.client.snapshot.delete_repository(repository=repo)
+            if definition["type"] == "fs":
+                rmtree(
+                    "/tmp/%s" % definition["settings"]["location"], ignore_errors=True
+                )
+
+        # stop and remove all ML stuff
+        if await self._feature_enabled("ml"):
+            await self.client.ml.stop_datafeed(datafeed_id="*", force=True)
+            for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[
+                "datafeeds"
+            ]:
+                await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"])
+
+            await self.client.ml.close_job(job_id="*", force=True)
+            for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]:
+                await self.client.ml.delete_job(
+                    job_id=job["job_id"], wait_for_completion=True, force=True
+                )
+
+        # stop and remove all Rollup jobs
+        if await self._feature_enabled("rollup"):
+            for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]:
+                await self.client.rollup.stop_job(
+                    id=rollup["config"]["id"], wait_for_completion=True
+                )
+                await self.client.rollup.delete_job(id=rollup["config"]["id"])
+
+        expand_wildcards = ["open", "closed"]
+        if (await self.es_version()) >= (7, 7):
+            expand_wildcards.append("hidden")
+
+        await self.client.indices.delete(
+            index="*", ignore=404, expand_wildcards=expand_wildcards
+        )
+        await self.client.indices.delete_template(name="*", ignore=404)
+
+    async def es_version(self):
+        global ES_VERSION
+        if ES_VERSION is None:
+            version_string = (await self.client.info())["version"]["number"]
+            if "." not in version_string:
+                return ()
+            version = version_string.strip().split(".")
+            ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version)
+        return ES_VERSION
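
The `await_if_coro` helper defined above is what lets `run_code()` dispatch to both plain and coroutine action handlers. A standalone demonstration with stand-in handler names:

    import asyncio
    import inspect

    async def await_if_coro(x):  # same helper as in the patch
        return await x if inspect.iscoroutine(x) else x

    def run_match(action):       # stand-in for a sync handler
        return "sync handler"

    async def run_do(action):    # stand-in for an async handler
        return "async handler"

    async def main():
        print(await await_if_coro(run_match({})))
        print(await await_if_coro(run_do({})))

    asyncio.run(main())
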
+
+    async def run(self):
+        await self.setup()
+        try:
+            await self.run_code(self._run_code)
+        finally:
+            await self.teardown()
+
+    async def run_code(self, test):
+        """Execute an instruction based on its type."""
+        print(test)
+        for action in test:
+            assert len(action) == 1
+            action_type, action = list(action.items())[0]
+
+            if hasattr(self, "run_" + action_type):
+                await await_if_coro(getattr(self, "run_" + action_type)(action))
+            else:
+                raise InvalidActionType(action_type)
+
+    async def run_do(self, action):
+        api = self.client
+        headers = action.pop("headers", None)
+        catch = action.pop("catch", None)
+        warn = action.pop("warnings", ())
+        assert len(action) == 1
+
+        method, args = list(action.items())[0]
+        args["headers"] = headers
+
+        # locate api endpoint
+        for m in method.split("."):
+            assert hasattr(api, m)
+            api = getattr(api, m)
+
+        # some parameters had to be renamed to not clash with python builtins,
+        # compensate
+        for k in PARAMS_RENAMES:
+            if k in args:
+                args[PARAMS_RENAMES[k]] = args.pop(k)
+
+        # resolve vars
+        for k in args:
+            args[k] = self._resolve(args[k])
+
+        warnings.simplefilter("always", category=ElasticsearchDeprecationWarning)
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            try:
+                self.last_response = await api(**args)
+            except Exception as e:
+                if not catch:
+                    raise
+                self.run_catch(catch, e)
+            else:
+                if catch:
+                    raise AssertionError(
+                        "Failed to catch %r in %r." % (catch, self.last_response)
+                    )
+
+        # Filter out warnings raised by other components.
+        caught_warnings = [
+            str(w.message)
+            for w in caught_warnings
+            if w.category == ElasticsearchDeprecationWarning
+        ]
+
+        # Sorting removes any ordering issues; we only care that all
+        # expected warnings were raised during this single API call.
+        if sorted(warn) != sorted(caught_warnings):
+            raise AssertionError(
+                "Expected warnings not equal to actual warnings: expected=%r actual=%r"
+                % (warn, caught_warnings)
+            )
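
The warning check above in miniature, using the generic `DeprecationWarning` instead of the client's warning class:

    import warnings

    def deprecated_call():
        warnings.warn("this option is deprecated", DeprecationWarning)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", category=DeprecationWarning)
        deprecated_call()

    expected = ["this option is deprecated"]
    assert sorted(expected) == sorted(str(w.message) for w in caught)
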
+
+    def run_catch(self, catch, exception):
+        if catch == "param":
+            assert isinstance(exception, TypeError)
+            return
+
+        assert isinstance(exception, TransportError)
+        if catch in CATCH_CODES:
+            assert CATCH_CODES[catch] == exception.status_code
+        elif catch[0] == "/" and catch[-1] == "/":
+            assert re.search(
+                catch[1:-1], exception.error + " " + repr(exception.info)
+            ), "%s not in %r" % (catch, exception.info)
+        self.last_response = exception.info
+
+    async def run_skip(self, skip):
+        global IMPLEMENTED_FEATURES
+
+        if "features" in skip:
+            features = skip["features"]
+            if not isinstance(features, (tuple, list)):
+                features = [features]
+            for feature in features:
+                if feature in IMPLEMENTED_FEATURES:
+                    continue
+                pytest.skip("feature '%s' is not supported" % feature)
+
+        if "version" in skip:
+            version, reason = skip["version"], skip["reason"]
+            if version == "all":
+                pytest.skip(reason)
+            min_version, max_version = version.split("-")
+            min_version = _get_version(min_version) or (0,)
+            max_version = _get_version(max_version) or (999,)
+            if min_version <= (await self.es_version()) <= max_version:
+                pytest.skip(reason)
+
+    def run_gt(self, action):
+        for key, value in action.items():
+            value = self._resolve(value)
+            assert self._lookup(key) > value
+
+    def run_gte(self, action):
+        for key, value in action.items():
+            value = self._resolve(value)
+            assert self._lookup(key) >= value
+
+    def run_lt(self, action):
+        for key, value in action.items():
+            value = self._resolve(value)
+            assert self._lookup(key) < value
+
+    def run_lte(self, action):
+        for key, value in action.items():
+            value = self._resolve(value)
+            assert self._lookup(key) <= value
+
+    def run_set(self, action):
+        for key, value in action.items():
+            value = self._resolve(value)
+            self._state[value] = self._lookup(key)
+
+    def run_is_false(self, action):
+        try:
+            value = self._lookup(action)
+        except AssertionError:
+            pass
+        else:
+            assert value in ("", None, False, 0)
+
+    def run_is_true(self, action):
+        value = self._lookup(action)
+        assert value not in ("", None, False, 0)
+
+    def run_length(self, action):
+        for path, expected in action.items():
+            value = self._lookup(path)
+            expected = self._resolve(expected)
+            assert expected == len(value)
+
+    def run_match(self, action):
+        for path, expected in action.items():
+            value = self._lookup(path)
+            expected = self._resolve(expected)
+
+            if (
+                isinstance(expected, string_types)
+                and expected.startswith("/")
+                and expected.endswith("/")
+            ):
+                expected = re.compile(expected[1:-1], re.VERBOSE | re.MULTILINE)
+                assert expected.search(value), "%r does not match %r" % (
+                    value,
+                    expected,
+                )
+            else:
+                assert expected == value, "%r does not match %r" % (value, expected)
+
+    def _resolve(self, value):
+        # resolve variables
+        if isinstance(value, string_types) and value.startswith("$"):
+            value = value[1:]
+            assert value in self._state
+            value = self._state[value]
+        if isinstance(value, string_types):
+            value = value.strip()
+        elif isinstance(value, dict):
+            value = dict((k, self._resolve(v)) for (k, v) in value.items())
+        elif isinstance(value, list):
+            value = list(map(self._resolve, value))
+        return value
+
+    def _lookup(self, path):
+        # fetch the possibly nested value from last_response
+        value = self.last_response
+        if path == "$body":
+            return value
+        path = path.replace(r"\.", "\1")
+        for step in path.split("."):
+            if not step:
+                continue
+            step = step.replace("\1", ".")
+            step = self._resolve(step)
+            if step.isdigit() and step not in value:
+                step = int(step)
+                assert isinstance(value, list)
+                assert len(value) > step
+            else:
+                assert step in value
+            value = value[step]
+        return value
+
+    async def _feature_enabled(self, name):
+        global XPACK_FEATURES, IMPLEMENTED_FEATURES
+        if XPACK_FEATURES is None:
+            try:
+                xinfo = await self.client.xpack.info()
+                XPACK_FEATURES = set(
+                    f for f in xinfo["features"] if xinfo["features"][f]["enabled"]
+                )
+                IMPLEMENTED_FEATURES.add("xpack")
+            except RequestError:
+                XPACK_FEATURES = set()
+                IMPLEMENTED_FEATURES.add("no_xpack")
+        return name in XPACK_FEATURES
+
+
+@pytest.fixture(scope="function")
+def runner(async_client):
+    return YamlRunner(async_client)
+
+
+@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
+async def test_rest_api_spec(test_spec, runner):
+    if test_spec.get("skip", False):
+        pytest.skip("Manually skipped in 'SKIP_TESTS'")
+    runner.use_spec(test_spec)
+    await runner.run()
+
+
+class InvalidActionType(Exception):
+    pass
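
The path semantics of `_lookup` above, reduced to a runnable sketch: dots descend into dicts and lists, escaped dots (`\.`) stay literal, and `$body` returns the whole response.

    resp = {"hits": {"hits": [{"_id": "1"}]}}
    value = resp
    for step in "hits.hits.0._id".split("."):
        if step.isdigit() and isinstance(value, list):
            step = int(step)  # numeric steps index into lists
        value = value[step]
    assert value == "1"
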
diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py
index 20ca8b7a31..4d4a855c37 100644
--- a/test_elasticsearch/test_async/test_transport.py
+++ b/test_elasticsearch/test_async/test_transport.py
@@ -6,6 +6,7 @@ from __future__ import unicode_literals
 
 import time
 from mock import patch
+import pytest
 
 from elasticsearch import AsyncTransport
 from elasticsearch.connection import Connection
@@ -15,6 +16,9 @@
 
 from ..test_cases import TestCase
 
+pytestmark = pytest.mark.asyncio
+
+
 class DummyConnection(Connection):
     def __init__(self, **kwargs):
         self.exception = kwargs.pop("exception", None)
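
Finally, an end-to-end sketch of driving the runner by hand against a live cluster. The client class name, host, and import path are assumptions (they presume the repository root is on sys.path); note that `teardown()` also wipes indices and templates on the target cluster.

    import asyncio
    from elasticsearch import AsyncElasticsearch  # assumed export
    from test_elasticsearch.test_async.test_server.test_rest_api_spec import YamlRunner

    async def main():
        client = AsyncElasticsearch(hosts=["http://localhost:9200"])  # assumed host
        runner = YamlRunner(client)
        runner.use_spec({
            "setup": None,
            "teardown": None,
            "run": [{"do": {"cluster.health": {}}}, {"is_true": "cluster_name"}],
        })
        await runner.run()  # setup -> actions -> teardown (cluster cleanup)
        await client.close()

    asyncio.run(main())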