Skip to content

Commit

Permalink
Merge pull request #63 from carlosgalvez-tiendeo/master
Browse files Browse the repository at this point in the history
Add multi_match
  • Loading branch information
vrcmarcos authored Jan 21, 2021
2 parents 83e25ca + fb4f6ad commit 2ece678
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
61 changes: 61 additions & 0 deletions elasticmock/fake_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class QueryType:
RANGE = 'RANGE'
SHOULD = 'SHOULD'
MINIMUM_SHOULD_MATCH = 'MINIMUM_SHOULD_MATCH'
MULTI_MATCH = 'MULTI_MATCH'

@staticmethod
def get_query_type(type_str):
Expand All @@ -54,6 +55,8 @@ def get_query_type(type_str):
return QueryType.SHOULD
elif type_str == 'minimum_should_match':
return QueryType.MINIMUM_SHOULD_MATCH
elif type_str == 'multi_match':
return QueryType.MULTI_MATCH
else:
raise NotImplementedError(f'type {type_str} is not implemented for QueryType')

Expand Down Expand Up @@ -95,6 +98,10 @@ def _evaluate_for_query_type(self, document):
return self._evaluate_for_compound_query_type(document)
elif self.type == QueryType.FILTER:
return self._evaluate_for_compound_query_type(document)
elif self.type == QueryType.MUST:
return self._evaluate_for_compound_query_type(document)
elif self.type == QueryType.MULTI_MATCH:
return self._evaluate_for_multi_match_query_type(document)
else:
raise NotImplementedError('Fake query evaluation not implemented for query type: %s' % self.type)

Expand Down Expand Up @@ -125,6 +132,25 @@ def _evaluate_for_field(self, document, ignore_case):
break
return return_val

def _evaluate_for_fields(self, document):
doc_source = document['_source']
return_val = False
value = self.condition.get('query')
if not value:
return return_val
fields = self.condition.get('fields', [])
for field in fields:
return_val = self._compare_value_for_field(
doc_source,
field,
value,
True
)
if return_val:
break

return return_val

def _evaluate_for_range_query_type(self, document):
for field, comparisons in self.condition.items():
doc_val = document['_source']
Expand Down Expand Up @@ -180,6 +206,9 @@ def _evaluate_for_compound_query_type(self, document):

return return_val

def _evaluate_for_multi_match_query_type(self, document):
return self._evaluate_for_fields(document)

def _compare_value_for_field(self, doc_source, field, value, ignore_case):
if ignore_case and isinstance(value, str):
value = value.lower()
Expand Down Expand Up @@ -409,6 +438,38 @@ def count(self, index=None, doc_type=None, body=None, params=None, headers=None)
def _get_fake_query_condition(self, query_type_str, condition):
return FakeQueryCondition(QueryType.get_query_type(query_type_str), condition)

@query_params(
"ccs_minimize_roundtrips",
"max_concurrent_searches",
"max_concurrent_shard_requests",
"pre_filter_shard_size",
"rest_total_hits_as_int",
"search_type",
"typed_keys",
)
def msearch(self, body, index=None, doc_type=None, params=None, headers=None):
def grouped(iterable):
if len(iterable) % 2 != 0:
raise Exception('Malformed body')
iterator = iter(iterable)
while True:
try:
yield (next(iterator)['index'], next(iterator))
except StopIteration:
break

responses = []
took = 0
for ind, query in grouped(body):
response = self.search(index=ind, body=query)
took += response['took']
responses.append(response)
result = {
'took': took,
'responses': responses
}
return result

@query_params('_source', '_source_exclude', '_source_include',
'allow_no_indices', 'analyze_wildcard', 'analyzer', 'default_operator',
'df', 'expand_wildcards', 'explain', 'fielddata_fields', 'fields',
Expand Down
65 changes: 65 additions & 0 deletions tests/fake_elasticsearch/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,71 @@ def test_query_on_nested_data(self):
doc = response['hits']['hits'][0]['_source']
self.assertEqual(i, doc['id'])


def test_search_with_bool_query_and_multi_match(self):
for i in range(0, 10):
self.es.index(index='index_for_search', doc_type=DOC_TYPE, body={
'data': 'test_{0}'.format(i) if i % 2 == 0 else None,
'data2': 'test_{0}'.format(i) if (i+1) % 2 == 0 else None
})

search_body = {
"query": {
"bool": {
"must": {
"multi_match": {
"query": "test",
"fields": ["data", "data2"]
}
}
}
}
}
response = self.es.search(index='index_for_search', doc_type=DOC_TYPE,
body=search_body)
self.assertEqual(response['hits']['total'], 10)
hits = response['hits']['hits']
self.assertEqual(len(hits), 10)

def test_msearch(self):
for i in range(0, 10):
self.es.index(index='index_for_search1', doc_type=DOC_TYPE, body={
'data': 'test_{0}'.format(i) if i % 2 == 0 else None,
'data2': 'test_{0}'.format(i) if (i+1) % 2 == 0 else None
})
for i in range(0, 10):
self.es.index(index='index_for_search2', doc_type=DOC_TYPE, body={
'data': 'test_{0}'.format(i) if i % 2 == 0 else None,
'data2': 'test_{0}'.format(i) if (i+1) % 2 == 0 else None
})

search_body = {
"query": {
"bool": {
"must": {
"multi_match": {
"query": "test",
"fields": ["data", "data2"]
}
}
}
}
}
body = []
body.append({'index': 'index_for_search1'})
body.append(search_body)
body.append({'index': 'index_for_search2'})
body.append(search_body)

result = self.es.msearch(index='index_for_search', body=body)
response1, response2 = result['responses']
self.assertEqual(response1['hits']['total'], 10)
hits1 = response1['hits']['hits']
self.assertEqual(len(hits1), 10)
self.assertEqual(response2['hits']['total'], 10)
hits2 = response2['hits']['hits']
self.assertEqual(len(hits2), 10)

@parameterized.expand(
[
(
Expand Down

0 comments on commit 2ece678

Please sign in to comment.