Skip to content

Commit

Permalink
Merge pull request #42 from KyKoPho/master
Browse files Browse the repository at this point in the history
Implements several basic search types, bumps major version
  • Loading branch information
vrcmarcos authored Sep 19, 2020
2 parents 646b214 + 977270d commit c7d2909
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 19 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
dist: xenial
language: python
python:
- "2.7"
- "3.6"
- "3.7"
- "3.8"
Expand Down
24 changes: 12 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
ELASTICMOCK_VERSION='1.5.1'
ELASTICMOCK_VERSION='2.0'

install:
@pip install -r requirements.txt
pip install -r requirements.txt

test_install: install
@pip install -r requirements_test.txt
pip install -r requirements_test.txt

test: test_install
@tox -p 20 --parallel--safe-build
python3 setup.py test

upload: create_dist
@pip install twine
@twine upload dist/*
@git push
pip install twine
twine upload dist/*
git push

create_dist: create_dist_commit update_pip
@rm -rf dist
@python setup.py sdist
rm -rf dist
python3 setup.py sdist

create_dist_commit:
@git commit --all -m "Bump version ${ELASTICMOCK_VERSION}"
@git tag ${ELASTICMOCK_VERSION}
git commit --all -m "Bump version ${ELASTICMOCK_VERSION}"
git tag ${ELASTICMOCK_VERSION}

update_pip:
@pip install --upgrade pip
pip install --upgrade pip
160 changes: 157 additions & 3 deletions elasticmock/fake_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,136 @@
unicode = str


class QueryType:

BOOL = 'BOOL'
FILTER = 'FILTER'
MATCH = 'MATCH'
TERM = 'TERM'
TERMS = 'TERMS'

@staticmethod
def get_query_type(type_str):
if type_str == 'bool':
return QueryType.BOOL
elif type_str == 'filter':
return QueryType.FILTER
elif type_str == 'match':
return QueryType.MATCH
elif type_str == 'term':
return QueryType.TERM
elif type_str == 'terms':
return QueryType.TERMS
else:
raise NotImplementedError(f'type {type_str} is not implemented for QueryType')


class FakeQueryCondition:
type = None
condition = None

def __init__(self, type, condition):
self.type = type
self.condition = condition

def evaluate(self, document):
return self._evaluate_for_query_type(document)

def _evaluate_for_query_type(self, document):
if self.type == QueryType.MATCH:
return self._evaluate_for_match_query_type(document)
elif self.type == QueryType.TERM:
return self._evaluate_for_term_query_type(document)
elif self.type == QueryType.TERMS:
return self._evaluate_for_terms_query_type(document)
elif self.type == QueryType.BOOL:
return self._evaluate_for_compound_query_type(document)
elif self.type == QueryType.FILTER:
return self._evaluate_for_compound_query_type(document)
else:
raise NotImplementedError('Fake query evaluation not implemented for query type: %s' % self.type)

def _evaluate_for_match_query_type(self, document):
return self._evaluate_for_field(document, True)

def _evaluate_for_term_query_type(self, document):
return self._evaluate_for_field(document, False)

def _evaluate_for_terms_query_type(self, document):
for field in self.condition:
for term in self.condition[field]:
if FakeQueryCondition(QueryType.TERM, {field: term}).evaluate(document):
return True
return False

def _evaluate_for_field(self, document, ignore_case):
doc_source = document['_source']
return_val = False
for field, value in self.condition.items():
return_val = self._compare_value_for_field(
doc_source,
field,
value,
ignore_case
)
if return_val:
break
return return_val

def _evaluate_for_compound_query_type(self, document):
return_val = False
if isinstance(self.condition, dict):
for query_type, sub_query in self.condition.items():
return_val = FakeQueryCondition(
QueryType.get_query_type(query_type),
sub_query
).evaluate(document)
if not return_val:
return False
elif isinstance(self.condition, list):
for sub_condition in self.condition:
for sub_condition_key in sub_condition:
return_val = FakeQueryCondition(
QueryType.get_query_type(sub_condition_key),
sub_condition[sub_condition_key]
).evaluate(document)
if not return_val:
return False

return return_val

def _compare_value_for_field(self, doc_source, field, value, ignore_case):
value = str(value).lower() if ignore_case and isinstance(value, str) \
else value
doc_val = None
if hasattr(doc_source, field):
doc_val = getattr(doc_source, field)
elif field in doc_source:
doc_val = doc_source[field]

if isinstance(doc_val, list):
for val in doc_val:
val = val if isinstance(val, (int, float, complex)) \
else str(val)
if ignore_case and isinstance(val, str):
val = val.lower()
if isinstance(val, str) and value in val:
return True
if value == val:
return True
else:
doc_val = doc_val if isinstance(doc_val, (int, float, complex)) \
else str(doc_val)
if ignore_case and isinstance(doc_val, str):
doc_val = doc_val.lower()
if isinstance(doc_val, str) and value in doc_val:
return True
if value == doc_val:
return True

return False


@for_all_methods([server_failure])
class FakeElasticsearch(Elasticsearch):
__documents_dict = None
Expand Down Expand Up @@ -55,8 +185,17 @@ def info(self, params=None, headers=None):
'tagline': 'You Know, for Search'
}

@query_params('consistency', 'op_type', 'parent', 'refresh', 'replication',
'routing', 'timeout', 'timestamp', 'ttl', 'version', 'version_type')
@query_params('consistency',
'op_type',
'parent',
'refresh',
'replication',
'routing',
'timeout',
'timestamp',
'ttl',
'version',
'version_type')
def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None):
if index not in self.__documents_dict:
self.__documents_dict[index] = list()
Expand Down Expand Up @@ -201,6 +340,9 @@ def count(self, index=None, doc_type=None, body=None, params=None, headers=None)

return result

def _get_fake_query_condition(self, query_type_str, condition):
return FakeQueryCondition(QueryType.get_query_type(query_type_str), condition)

@query_params('_source', '_source_exclude', '_source_include',
'allow_no_indices', 'analyze_wildcard', 'analyzer', 'default_operator',
'df', 'expand_wildcards', 'explain', 'fielddata_fields', 'fields',
Expand All @@ -213,14 +355,26 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
searchable_indexes = self._normalize_index_to_list(index)

matches = []
conditions = []

if body and 'query' in body:
query = body['query']
for query_type_str, condition in query.items():
conditions.append(self._get_fake_query_condition(query_type_str, condition))
for searchable_index in searchable_indexes:
for document in self.__documents_dict[searchable_index]:
if doc_type:
if isinstance(doc_type, list) and document.get('_type') not in doc_type:
continue
if isinstance(doc_type, str) and document.get('_type') != doc_type:
continue
matches.append(document)
if conditions:
for condition in conditions:
if condition.evaluate(document):
matches.append(document)
break
else:
matches.append(document)

result = {
'hits': {
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
elasticsearch>=1.9.0,<8.0.0
mock==3.0.5
mock==3.0.5
ipdb
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import setuptools

__version__ = '1.5.1'
__version__ = '2.0'

# read the contents of your readme file
from os import path
Expand Down
67 changes: 67 additions & 0 deletions tests/fake_elasticsearch/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,70 @@ def test_search_with_scroll_param(self):
self.assertNotEqual(None, result.get('_scroll_id', None))
self.assertEqual(30, len(result.get('hits').get('hits')))
self.assertEqual(100, result.get('hits').get('total'))

def test_search_with_match_query(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_{0}'.format(i)})

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': 'TEST' } } })
self.assertEqual(response['hits']['total'], 10)
hits = response['hits']['hits']
self.assertEqual(len(hits), 10)

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': '3' } } })
self.assertEqual(response['hits']['total'], 1)
hits = response['hits']['hits']
self.assertEqual(len(hits), 1)
self.assertEqual(hits[0]['_source'], {'data': 'test_3'})

def test_search_with_match_query_in_int_list(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': [i, 11, 13]})
response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': 1 } } })
self.assertEqual(response['hits']['total'], 1)
hits = response['hits']['hits']
self.assertEqual(len(hits), 1)
self.assertEqual(hits[0]['_source'], {'data': [1, 11, 13] })

def test_search_with_match_query_in_string_list(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': [str(i), 'two', 'three']})

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'match': {'data': '1' } } })
self.assertEqual(response['hits']['total'], 1)
hits = response['hits']['hits']
self.assertEqual(len(hits), 1)
self.assertEqual(hits[0]['_source'], {'data': ['1', 'two', 'three']})

def test_search_with_term_query(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'data': 'test_{0}'.format(i)})

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'term': {'data': 'TEST' } } })
self.assertEqual(response['hits']['total'], 0)
hits = response['hits']['hits']
self.assertEqual(len(hits), 0)

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'term': {'data': '3' } } })
self.assertEqual(response['hits']['total'], 1)
hits = response['hits']['hits']
self.assertEqual(len(hits), 1)
self.assertEqual(hits[0]['_source'], {'data': 'test_3'})

def test_search_with_bool_query(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'id': i})

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'bool': {'filter': [{'term': {'id': 1}}]}}})
self.assertEqual(response['hits']['total'], 1)
hits = response['hits']['hits']
self.assertEqual(len(hits), 1)

def test_search_with_terms_query(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={'id': i})

response = self.es.search(index='index_for_search', doc_type=DOC_TYPE, body={'query': {'terms': {'id': [1, 2, 3]}}})
self.assertEqual(response['hits']['total'], 3)
hits = response['hits']['hits']
self.assertEqual(len(hits), 3)
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# content of: tox.ini , put in same dir as setup.py
[tox]
envlist =
py27-elasticsearch{1,2,5,6,7}
py36-elasticsearch{1,2,5,6,7}
py37-elasticsearch{1,2,5,6,7}
py38-elasticsearch{1,2,5,6,7}
Expand Down

0 comments on commit c7d2909

Please sign in to comment.