diff --git a/Makefile b/Makefile index bca2676..eea0b97 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -ELASTICMOCK_VERSION='1.9.0' +ELASTICMOCK_VERSION='1.10.0' install: pip3 install -r requirements.txt diff --git a/elasticmock/fake_elasticsearch.py b/elasticmock/fake_elasticsearch.py index cf21343..dab9b74 100644 --- a/elasticmock/fake_elasticsearch.py +++ b/elasticmock/fake_elasticsearch.py @@ -36,6 +36,7 @@ class QueryType: MULTI_MATCH = 'MULTI_MATCH' MUST_NOT = 'MUST_NOT' WILDCARD = 'WILDCARD' + PREFIX = 'PREFIX' @staticmethod def get_query_type(type_str): @@ -65,6 +66,8 @@ def get_query_type(type_str): return QueryType.MUST_NOT elif type_str == 'wildcard': return QueryType.WILDCARD + elif type_str == 'prefix': + return QueryType.PREFIX else: raise NotImplementedError(f'type {type_str} is not implemented for QueryType') @@ -157,6 +160,8 @@ def _evaluate_for_query_type(self, document): return self._evaluate_for_terms_query_type(document) elif self.type == QueryType.WILDCARD: return self._evaluate_for_wildcard_query_type(document) + elif self.type == QueryType.PREFIX: + return self._evaluate_for_prefix_query_type(document) elif self.type == QueryType.RANGE: return self._evaluate_for_range_query_type(document) elif self.type == QueryType.BOOL: @@ -191,12 +196,21 @@ def _evaluate_for_wildcard_query_type(self, document): return_val = False if isinstance(self.condition, dict): for _, sub_query in self.condition.items(): - return_val = self._evaluate_for_field(document, True, True) + return_val = self._evaluate_for_field(document, True, is_wildcard=True) if not return_val: return False return return_val - def _evaluate_for_field(self, document, ignore_case=True, is_wildcard=False): + def _evaluate_for_prefix_query_type(self, document): + return_val = False + if isinstance(self.condition, dict): + for _, sub_query in self.condition.items(): + return_val = self._evaluate_for_field(document, ignore_case=False, is_prefix=True) + if not return_val: + return False + return return_val + + def _evaluate_for_field(self, document, ignore_case=True, is_wildcard=False, is_prefix=False): doc_source = document['_source'] return_val = False for field, value in self.condition.items(): @@ -205,7 +219,8 @@ def _evaluate_for_field(self, document, ignore_case=True, is_wildcard=False): field, value, ignore_case, - is_wildcard + is_wildcard=is_wildcard, + is_prefix=is_prefix ) if return_val: break @@ -320,8 +335,8 @@ def _evaluate_for_should_query_type(self, document): def _evaluate_for_multi_match_query_type(self, document): return self._evaluate_for_fields(document) - def _compare_value_for_field(self, doc_source, field, value, ignore_case, is_wildcard=False): - if is_wildcard: + def _compare_value_for_field(self, doc_source, field, value, ignore_case, is_wildcard=False, is_prefix=False): + if is_wildcard or is_prefix: value = value['value'] if ignore_case and isinstance(value, str): value = value.lower() @@ -350,6 +365,8 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case, is_wil val = val.lower() if is_wildcard: return re.search(value.replace('*', '.*'), val) + if is_prefix: + return val.startswith(value) if value == val: return True if isinstance(val, str) and str(value) in val: diff --git a/setup.py b/setup.py index 007a24f..2f122ab 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import setuptools -__version__ = '1.9.0' +__version__ = '1.10.0' # read the contents of your readme file from os import path diff --git a/tests/fake_elasticsearch/test_search.py b/tests/fake_elasticsearch/test_search.py index c972027..c154add 100644 --- a/tests/fake_elasticsearch/test_search.py +++ b/tests/fake_elasticsearch/test_search.py @@ -112,6 +112,22 @@ def test_search_with_wildcard_query(self): hits = response['hits']['hits'] self.assertEqual(len(hits), 3) + def test_search_with_prefix_query(self): + self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_20221010'}) + self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_20221011'}) + self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_20221012'}) + response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, + body={'query': {'prefix': {'data': {'value': 'test_1'}}}}) + self.assertEqual(response['hits']['total']['value'], 0) + hits = response['hits']['hits'] + self.assertEqual(len(hits), 0) + + response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, + body={'query': {'prefix': {'data': {'value': 'test_2'}}}}) + self.assertEqual(response['hits']['total']['value'], 3) + hits = response['hits']['hits'] + self.assertEqual(len(hits), 3) + def test_search_with_match_query_in_int_list(self): for i in range(0, 10): self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': [i, 11, 13]})