Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mget #64

Merged
merged 9 commits into from
Mar 2, 2021
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ bin/
share/
pyvenv.cfg

### Visual Studio ###
.vscode/

### Intellij ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
Expand Down
66 changes: 49 additions & 17 deletions elasticmock/fake_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import dateutil.parser
from elasticsearch import Elasticsearch
from elasticsearch.client.utils import query_params
from elasticsearch.client import _normalize_hosts
from elasticsearch.transport import Transport
from elasticsearch.exceptions import NotFoundError, RequestError

from elasticmock.behaviour.server_failure import server_failure
from elasticmock.fake_cluster import FakeClusterClient
from elasticmock.fake_indices import FakeIndicesClient
from elasticmock.utilities import extract_ignore_as_iterable, get_random_id, get_random_scroll_id
from elasticmock.utilities import (extract_ignore_as_iterable, get_random_id,
get_random_scroll_id)
from elasticmock.utilities.decorator import for_all_methods

PY3 = sys.version_info[0] == 3
Expand Down Expand Up @@ -228,11 +231,15 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case):
value = value.lower()

doc_val = doc_source
# Remove boosting
field, *_ = field.split("*")
for k in field.split("."):
if hasattr(doc_val, k):
doc_val = getattr(doc_val, k)
break
elif k in doc_val:
doc_val = doc_val[k]
break
else:
return False

Expand All @@ -247,7 +254,7 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case):

if value == val:
return True
if isinstance(val, str) and value in val:
if isinstance(val, str) and str(value) in val:
return True

return False
Expand All @@ -260,6 +267,7 @@ class FakeElasticsearch(Elasticsearch):
def __init__(self, hosts=None, transport_class=None, **kwargs):
self.__documents_dict = {}
self.__scrolls = {}
self.transport = Transport(_normalize_hosts(hosts), **kwargs)

@property
def indices(self):
Expand Down Expand Up @@ -309,10 +317,10 @@ def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None

if id is None:
id = get_random_id()
elif self.exists(index, doc_type, id, params=params):
doc = self.get(index, id, doc_type, params=params)
elif self.exists(index, id, doc_type=doc_type, params=params):
doc = self.get(index, id, doc_type=doc_type, params=params)
version = doc['_version'] + 1
self.delete(index, doc_type, id)
self.delete(index, id, doc_type=doc_type)

self.__documents_dict[index].append({
'_type': doc_type,
Expand Down Expand Up @@ -344,10 +352,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
action = next(iter(line.keys()))

version = 1
index = line[action]['_index']
index = line[action].get('_index') or index
doc_type = line[action].get('_type', "_doc") # _type is deprecated in 7.x

if action in ['delete', 'updated'] and not line[action].get("_id"):
if action in ['delete', 'update'] and not line[action].get("_id"):
raise RequestError(400, 'action_request_validation_exception', 'missing id')

document_id = line[action].get('_id', get_random_id())
Expand All @@ -367,7 +375,7 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
errors = True
item[action]["error"] = result
else:
self.delete(index, doc_type, document_id, params=params)
self.delete(index, document_id, doc_type=doc_type, params=params)
item[action]["result"] = result
items.append(item)

Expand All @@ -392,10 +400,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
}
if not error:
item[action]["result"] = result
if self.exists(index, doc_type, document_id, params=params):
doc = self.get(index, document_id, doc_type, params=params)
if self.exists(index, document_id, doc_type=doc_type, params=params):
doc = self.get(index, document_id, doc_type=doc_type, params=params)
version = doc['_version'] + 1
self.delete(index, doc_type, document_id, params=params)
self.delete(index, document_id, doc_type=doc_type, params=params)

self.__documents_dict[index].append({
'_type': doc_type,
Expand Down Expand Up @@ -430,7 +438,7 @@ def _validate_action(self, action, index, document_id, doc_type, params=None):
raise NotImplementedError(f"{action} behaviour hasn't been implemented")

@query_params('parent', 'preference', 'realtime', 'refresh', 'routing')
def exists(self, index, doc_type, id, params=None, headers=None):
def exists(self, index, id, doc_type=None, params=None, headers=None):
result = False
if index in self.__documents_dict:
for document in self.__documents_dict[index]:
Expand Down Expand Up @@ -471,6 +479,26 @@ def get(self, index, id, doc_type='_all', params=None, headers=None):
}
raise NotFoundError(404, json.dumps(error_data))

@query_params('_source', '_source_exclude', '_source_include',
'preference', 'realtime', 'refresh', 'routing',
'stored_fields')
def mget(self, body, index, doc_type='_all', params=None, headers=None):
ids = body.get('ids')
results = []
for id in ids:
try:
results.append(self.get(index, id, doc_type=doc_type,
params=params, headers=headers))
except:
pass
if not results:
raise RequestError(
400,
'action_request_validation_exception',
'Validation Failed: 1: no documents to get;'
)
return {'docs': results}

@query_params('_source', '_source_exclude', '_source_include', 'parent',
'preference', 'realtime', 'refresh', 'routing', 'version',
'version_type')
Expand Down Expand Up @@ -646,17 +674,20 @@ def scroll(self, scroll_id, params=None, headers=None):

@query_params('consistency', 'parent', 'refresh', 'replication', 'routing',
'timeout', 'version', 'version_type')
def delete(self, index, doc_type, id, params=None, headers=None):
def delete(self, index, id, doc_type=None, params=None, headers=None):

found = False
ignore = extract_ignore_as_iterable(params)

if index in self.__documents_dict:
for document in self.__documents_dict[index]:
if document.get('_type') == doc_type and document.get('_id') == id:
if document.get('_id') == id:
found = True
self.__documents_dict[index].remove(document)
break
if doc_type and document.get('_type') != doc_type:
found = False
if found:
self.__documents_dict[index].remove(document)
break

result_dict = {
'found': found,
Expand All @@ -665,12 +696,13 @@ def delete(self, index, doc_type, id, params=None, headers=None):
'_id': id,
'_version': 1,
}

if found:
return result_dict
elif params and 404 in ignore:
return {'found': False}
else:
raise NotFoundError(404, json.dumps(result_dict, default=str))
raise NotFoundError(404, json.dumps(result_dict))

@query_params('allow_no_indices', 'expand_wildcards', 'ignore_unavailable',
'preference', 'routing')
Expand Down
8 changes: 8 additions & 0 deletions tests/fake_elasticsearch/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,11 @@ def test_should_get_only_document_source_with_id(self):
target_doc_source = self.es.get_source(index=INDEX_NAME, doc_type=DOC_TYPE, id=document_id)

self.assertEqual(target_doc_source, BODY)

def test_mget_get_several_documents_by_id(self):
ids = []
for _ in range(0, 10):
data = self.es.index(index=INDEX_NAME, doc_type=DOC_TYPE, body=BODY)
ids.append(data.get('_id'))
results = self.es.mget(index=INDEX_NAME, body={'ids': ids})
self.assertEqual(len(results['docs']), 10)
2 changes: 1 addition & 1 deletion tests/fake_elasticsearch/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_search_bool_should_match_query(self):
}
}
})
self.assertEqual(response['hits']['total'], 3)
self.assertEqual(response['hits']['total']['value'], 3)
hits = response['hits']['hits']
self.assertEqual(len(hits), 3)
self.assertEqual(hits[0]['_source'], {'data': 'test_0'})
Expand Down