diff --git a/.travis.yml b/.travis.yml
index 0f67258..80f3649 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,6 @@
 dist: xenial
 language: python
 python:
-  - "2.7"
   - "3.6"
   - "3.7"
   - "3.8"
diff --git a/Makefile b/Makefile
index 99cfdf6..1627946 100644
--- a/Makefile
+++ b/Makefile
@@ -1,26 +1,26 @@
-ELASTICMOCK_VERSION='1.5.1'
+ELASTICMOCK_VERSION='2.0'
 
 install:
-	@pip install -r requirements.txt
+	pip install -r requirements.txt
 
 test_install: install
-	@pip install -r requirements_test.txt
+	pip install -r requirements_test.txt
 
 test: test_install
-	@tox -p 20 --parallel--safe-build
+	python3 setup.py test
 
 upload: create_dist
-	@pip install twine
-	@twine upload dist/*
-	@git push
+	pip install twine
+	twine upload dist/*
+	git push
 
 create_dist: create_dist_commit update_pip
-	@rm -rf dist
-	@python setup.py sdist
+	rm -rf dist
+	python3 setup.py sdist
 
 create_dist_commit:
-	@git commit --all -m "Bump version ${ELASTICMOCK_VERSION}"
-	@git tag ${ELASTICMOCK_VERSION}
+	git commit --all -m "Bump version ${ELASTICMOCK_VERSION}"
+	git tag ${ELASTICMOCK_VERSION}
 
 update_pip:
-	@pip install --upgrade pip
+	pip install --upgrade pip
diff --git a/elasticmock/fake_elasticsearch.py b/elasticmock/fake_elasticsearch.py
index a93b65f..541a17a 100644
--- a/elasticmock/fake_elasticsearch.py
+++ b/elasticmock/fake_elasticsearch.py
@@ -18,6 +18,136 @@
     unicode = str
 
 
+class QueryType:
+
+    BOOL = 'BOOL'
+    FILTER = 'FILTER'
+    MATCH = 'MATCH'
+    TERM = 'TERM'
+    TERMS = 'TERMS'
+
+    @staticmethod
+    def get_query_type(type_str):
+        if type_str == 'bool':
+            return QueryType.BOOL
+        elif type_str == 'filter':
+            return QueryType.FILTER
+        elif type_str == 'match':
+            return QueryType.MATCH
+        elif type_str == 'term':
+            return QueryType.TERM
+        elif type_str == 'terms':
+            return QueryType.TERMS
+        else:
+            raise NotImplementedError(f'type {type_str} is not implemented for QueryType')
+
+
+class FakeQueryCondition:
+    type = None
+    condition = None
+
+    def __init__(self, type, condition):
+        self.type = type
+        self.condition = condition
+
+    def evaluate(self, document):
+        return self._evaluate_for_query_type(document)
+
+    def _evaluate_for_query_type(self, document):
+        if self.type == QueryType.MATCH:
+            return self._evaluate_for_match_query_type(document)
+        elif self.type == QueryType.TERM:
+            return self._evaluate_for_term_query_type(document)
+        elif self.type == QueryType.TERMS:
+            return self._evaluate_for_terms_query_type(document)
+        elif self.type == QueryType.BOOL:
+            return self._evaluate_for_compound_query_type(document)
+        elif self.type == QueryType.FILTER:
+            return self._evaluate_for_compound_query_type(document)
+        else:
+            raise NotImplementedError('Fake query evaluation not implemented for query type: %s' % self.type)
+
+    def _evaluate_for_match_query_type(self, document):
+        return self._evaluate_for_field(document, True)
+
+    def _evaluate_for_term_query_type(self, document):
+        return self._evaluate_for_field(document, False)
+
+    def _evaluate_for_terms_query_type(self, document):
+        for field in self.condition:
+            for term in self.condition[field]:
+                if FakeQueryCondition(QueryType.TERM, {field: term}).evaluate(document):
+                    return True
+        return False
+
+    def _evaluate_for_field(self, document, ignore_case):
+        doc_source = document['_source']
+        return_val = False
+        for field, value in self.condition.items():
+            return_val = self._compare_value_for_field(
+                doc_source,
+                field,
+                value,
+                ignore_case
+            )
+            if return_val:
+                break
+        return return_val
+
+    def _evaluate_for_compound_query_type(self, document):
+        return_val = False
+        if isinstance(self.condition, dict):
+            for query_type, sub_query in self.condition.items():
+                return_val = FakeQueryCondition(
+                    QueryType.get_query_type(query_type),
+                    sub_query
+                ).evaluate(document)
+                if not return_val:
+                    return False
+        elif isinstance(self.condition, list):
+            for sub_condition in self.condition:
+                for sub_condition_key in sub_condition:
+                    return_val = FakeQueryCondition(
+                        QueryType.get_query_type(sub_condition_key),
+                        sub_condition[sub_condition_key]
+                    ).evaluate(document)
+                    if not return_val:
+                        return False
+
+        return return_val
+
+    def _compare_value_for_field(self, doc_source, field, value, ignore_case):
+        value = str(value).lower() if ignore_case and isinstance(value, str) \
+            else value
+        doc_val = None
+        if hasattr(doc_source, field):
+            doc_val = getattr(doc_source, field)
+        elif field in doc_source:
+            doc_val = doc_source[field]
+
+        if isinstance(doc_val, list):
+            for val in doc_val:
+                val = val if isinstance(val, (int, float, complex)) \
+                    else str(val)
+                if ignore_case and isinstance(val, str):
+                    val = val.lower()
+                if isinstance(val, str) and value in val:
+                    return True
+                if value == val:
+                    return True
+        else:
+            doc_val = doc_val if isinstance(doc_val, (int, float, complex)) \
+                else str(doc_val)
+            if ignore_case and isinstance(doc_val, str):
+                doc_val = doc_val.lower()
+            if isinstance(doc_val, str) and value in doc_val:
+                return True
+            if value == doc_val:
+                return True
+
+        return False
+
+
 @for_all_methods([server_failure])
 class FakeElasticsearch(Elasticsearch):
     __documents_dict = None
@@ -55,8 +185,17 @@ def info(self, params=None, headers=None):
             'tagline': 'You Know, for Search'
         }
 
-    @query_params('consistency', 'op_type', 'parent', 'refresh', 'replication',
-                  'routing', 'timeout', 'timestamp', 'ttl', 'version', 'version_type')
+    @query_params('consistency',
+                  'op_type',
+                  'parent',
+                  'refresh',
+                  'replication',
+                  'routing',
+                  'timeout',
+                  'timestamp',
+                  'ttl',
+                  'version',
+                  'version_type')
     def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None):
         if index not in self.__documents_dict:
             self.__documents_dict[index] = list()
@@ -201,6 +340,9 @@ def count(self, index=None, doc_type=None, body=None, params=None, headers=None)
 
         return result
 
+    def _get_fake_query_condition(self, query_type_str, condition):
+        return FakeQueryCondition(QueryType.get_query_type(query_type_str), condition)
+
     @query_params('_source', '_source_exclude', '_source_include', 'allow_no_indices',
                   'analyze_wildcard', 'analyzer', 'default_operator', 'df',
                   'expand_wildcards', 'explain', 'fielddata_fields', 'fields',
@@ -213,6 +355,12 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None)
         searchable_indexes = self._normalize_index_to_list(index)
 
         matches = []
+        conditions = []
+
+        if body and 'query' in body:
+            query = body['query']
+            for query_type_str, condition in query.items():
+                conditions.append(self._get_fake_query_condition(query_type_str, condition))
         for searchable_index in searchable_indexes:
             for document in self.__documents_dict[searchable_index]:
                 if doc_type:
@@ -220,7 +368,13 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None)
                         continue
                     if isinstance(doc_type, str) and document.get('_type') != doc_type:
                         continue
-                matches.append(document)
+                if conditions:
+                    for condition in conditions:
+                        if condition.evaluate(document):
+                            matches.append(document)
+                            break
+                else:
+                    matches.append(document)
 
         result = {
             'hits': {
diff --git a/requirements.txt b/requirements.txt
index 7f58b5a..822a3ed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 elasticsearch>=1.9.0,<8.0.0
-mock==3.0.5
\ No newline at end of file
+mock==3.0.5
+ipdb
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 72cb368..eafcc94 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 import setuptools
 
-__version__ = '1.5.1'
+__version__ = '2.0'
 
 # read the contents of your readme file
 from os import path
diff --git a/tests/fake_elasticsearch/test_search.py b/tests/fake_elasticsearch/test_search.py
index 098f77b..74079fe 100644
--- a/tests/fake_elasticsearch/test_search.py
+++ b/tests/fake_elasticsearch/test_search.py
@@ -65,3 +65,70 @@ def test_search_with_scroll_param(self):
         self.assertNotEqual(None, result.get('_scroll_id', None))
         self.assertEqual(30, len(result.get('hits').get('hits')))
         self.assertEqual(100, result.get('hits').get('total'))
+
+    def test_search_with_match_query(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_{0}'.format(i)})
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': 'TEST'}}})
+        self.assertEqual(response['hits']['total'], 10)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 10)
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': '3'}}})
+        self.assertEqual(response['hits']['total'], 1)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 1)
+        self.assertEqual(hits[0]['_source'], {'data': 'test_3'})
+
+    def test_search_with_match_query_in_int_list(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': [i, 11, 13]})
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': 1}}})
+        self.assertEqual(response['hits']['total'], 1)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 1)
+        self.assertEqual(hits[0]['_source'], {'data': [1, 11, 13]})
+
+    def test_search_with_match_query_in_string_list(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': [str(i), 'two', 'three']})
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': '1'}}})
+        self.assertEqual(response['hits']['total'], 1)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 1)
+        self.assertEqual(hits[0]['_source'], {'data': ['1', 'two', 'three']})
+
+    def test_search_with_term_query(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_{0}'.format(i)})
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'term': {'data': 'TEST'}}})
+        self.assertEqual(response['hits']['total'], 0)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 0)
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'term': {'data': '3'}}})
+        self.assertEqual(response['hits']['total'], 1)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 1)
+        self.assertEqual(hits[0]['_source'], {'data': 'test_3'})
+
+    def test_search_with_bool_query(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'id': i})
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'bool': {'filter': [{'term': {'id': 1}}]}}})
+        self.assertEqual(response['hits']['total'], 1)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 1)
+
+    def test_search_with_terms_query(self):
+        for i in range(0, 10):
+            self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'id': i})
+
+        response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'terms': {'id': [1, 2, 3]}}})
+        self.assertEqual(response['hits']['total'], 3)
+        hits = response['hits']['hits']
+        self.assertEqual(len(hits), 3)
diff --git a/tox.ini b/tox.ini
index be51eaf..cbc50e8 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,7 +1,6 @@
 # content of: tox.ini , put in same dir as setup.py
 [tox]
 envlist =
-    py27-elasticsearch{1,2,5,6,7}
     py36-elasticsearch{1,2,5,6,7}
     py37-elasticsearch{1,2,5,6,7}
     py38-elasticsearch{1,2,5,6,7}
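
Note (not part of the patch): a minimal usage sketch of the query support added above, based on the test cases in this diff and elasticmock's documented @elasticmock decorator; the index name, field names, and expected counts are illustrative only.

# sketch.py -- exercising the new bool/match/term/terms handling in FakeElasticsearch
import elasticsearch
from elasticmock import elasticmock


@elasticmock
def demo():
    # Under @elasticmock, Elasticsearch() returns the in-memory fake client.
    es = elasticsearch.Elasticsearch(hosts=[{'host': 'localhost', 'port': 9200}])

    # Index a few documents into the fake store (illustrative data).
    for i in range(5):
        es.index(index='index_for_search', doc_type='_doc',
                 body={'data': 'test_{0}'.format(i), 'id': i})

    # 'match' is evaluated case-insensitively against the stored field.
    matched = es.search(index='index_for_search',
                        body={'query': {'match': {'data': 'TEST'}}})
    print(matched['hits']['total'])  # expected: 5

    # 'bool' + 'filter' + 'term' narrows the result set to a single document.
    filtered = es.search(index='index_for_search',
                         body={'query': {'bool': {'filter': [{'term': {'id': 1}}]}}})
    print(filtered['hits']['total'])  # expected: 1


if __name__ == '__main__':
    demo()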