diff --git a/.gitignore b/.gitignore index dd2017e..6e593e4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ bin/ share/ pyvenv.cfg +### Visual Studio ### +.vscode/ + ### Intellij ### # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 diff --git a/elasticmock/fake_elasticsearch.py b/elasticmock/fake_elasticsearch.py index 97c6347..6b580ff 100644 --- a/elasticmock/fake_elasticsearch.py +++ b/elasticmock/fake_elasticsearch.py @@ -7,12 +7,15 @@ import dateutil.parser from elasticsearch import Elasticsearch from elasticsearch.client.utils import query_params +from elasticsearch.client import _normalize_hosts +from elasticsearch.transport import Transport from elasticsearch.exceptions import NotFoundError, RequestError from elasticmock.behaviour.server_failure import server_failure from elasticmock.fake_cluster import FakeClusterClient from elasticmock.fake_indices import FakeIndicesClient -from elasticmock.utilities import extract_ignore_as_iterable, get_random_id, get_random_scroll_id +from elasticmock.utilities import (extract_ignore_as_iterable, get_random_id, + get_random_scroll_id) from elasticmock.utilities.decorator import for_all_methods PY3 = sys.version_info[0] == 3 @@ -228,11 +231,15 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case): value = value.lower() doc_val = doc_source + # Remove boosting + field, *_ = field.split("*") for k in field.split("."): if hasattr(doc_val, k): doc_val = getattr(doc_val, k) + break elif k in doc_val: doc_val = doc_val[k] + break else: return False @@ -247,7 +254,7 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case): if value == val: return True - if isinstance(val, str) and value in val: + if isinstance(val, str) and str(value) in val: return True return False @@ -260,6 +267,7 @@ class FakeElasticsearch(Elasticsearch): def __init__(self, hosts=None, transport_class=None, **kwargs): self.__documents_dict = {} self.__scrolls = {} + self.transport = Transport(_normalize_hosts(hosts), **kwargs) @property def indices(self): @@ -309,10 +317,10 @@ def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None if id is None: id = get_random_id() - elif self.exists(index, doc_type, id, params=params): - doc = self.get(index, id, doc_type, params=params) + elif self.exists(index, id, doc_type=doc_type, params=params): + doc = self.get(index, id, doc_type=doc_type, params=params) version = doc['_version'] + 1 - self.delete(index, doc_type, id) + self.delete(index, id, doc_type=doc_type) self.__documents_dict[index].append({ '_type': doc_type, @@ -344,10 +352,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None): action = next(iter(line.keys())) version = 1 - index = line[action]['_index'] + index = line[action].get('_index') or index doc_type = line[action].get('_type', "_doc") # _type is deprecated in 7.x - if action in ['delete', 'updated'] and not line[action].get("_id"): + if action in ['delete', 'update'] and not line[action].get("_id"): raise RequestError(400, 'action_request_validation_exception', 'missing id') document_id = line[action].get('_id', get_random_id()) @@ -367,7 +375,7 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None): errors = True item[action]["error"] = result else: - self.delete(index, doc_type, document_id, params=params) + self.delete(index, document_id, doc_type=doc_type, params=params) item[action]["result"] = result items.append(item) @@ -392,10 +400,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None): } if not error: item[action]["result"] = result - if self.exists(index, doc_type, document_id, params=params): - doc = self.get(index, document_id, doc_type, params=params) + if self.exists(index, document_id, doc_type=doc_type, params=params): + doc = self.get(index, document_id, doc_type=doc_type, params=params) version = doc['_version'] + 1 - self.delete(index, doc_type, document_id, params=params) + self.delete(index, document_id, doc_type=doc_type, params=params) self.__documents_dict[index].append({ '_type': doc_type, @@ -430,7 +438,7 @@ def _validate_action(self, action, index, document_id, doc_type, params=None): raise NotImplementedError(f"{action} behaviour hasn't been implemented") @query_params('parent', 'preference', 'realtime', 'refresh', 'routing') - def exists(self, index, doc_type, id, params=None, headers=None): + def exists(self, index, id, doc_type=None, params=None, headers=None): result = False if index in self.__documents_dict: for document in self.__documents_dict[index]: @@ -471,6 +479,26 @@ def get(self, index, id, doc_type='_all', params=None, headers=None): } raise NotFoundError(404, json.dumps(error_data)) + @query_params('_source', '_source_exclude', '_source_include', + 'preference', 'realtime', 'refresh', 'routing', + 'stored_fields') + def mget(self, body, index, doc_type='_all', params=None, headers=None): + ids = body.get('ids') + results = [] + for id in ids: + try: + results.append(self.get(index, id, doc_type=doc_type, + params=params, headers=headers)) + except: + pass + if not results: + raise RequestError( + 400, + 'action_request_validation_exception', + 'Validation Failed: 1: no documents to get;' + ) + return {'docs': results} + @query_params('_source', '_source_exclude', '_source_include', 'parent', 'preference', 'realtime', 'refresh', 'routing', 'version', 'version_type') @@ -646,17 +674,20 @@ def scroll(self, scroll_id, params=None, headers=None): @query_params('consistency', 'parent', 'refresh', 'replication', 'routing', 'timeout', 'version', 'version_type') - def delete(self, index, doc_type, id, params=None, headers=None): + def delete(self, index, id, doc_type=None, params=None, headers=None): found = False ignore = extract_ignore_as_iterable(params) if index in self.__documents_dict: for document in self.__documents_dict[index]: - if document.get('_type') == doc_type and document.get('_id') == id: + if document.get('_id') == id: found = True - self.__documents_dict[index].remove(document) - break + if doc_type and document.get('_type') != doc_type: + found = False + if found: + self.__documents_dict[index].remove(document) + break result_dict = { 'found': found, @@ -665,12 +696,13 @@ def delete(self, index, doc_type, id, params=None, headers=None): '_id': id, '_version': 1, } + if found: return result_dict elif params and 404 in ignore: return {'found': False} else: - raise NotFoundError(404, json.dumps(result_dict, default=str)) + raise NotFoundError(404, json.dumps(result_dict)) @query_params('allow_no_indices', 'expand_wildcards', 'ignore_unavailable', 'preference', 'routing') diff --git a/tests/fake_elasticsearch/test_get.py b/tests/fake_elasticsearch/test_get.py index 10220ec..d895370 100644 --- a/tests/fake_elasticsearch/test_get.py +++ b/tests/fake_elasticsearch/test_get.py @@ -60,3 +60,11 @@ def test_should_get_only_document_source_with_id(self): target_doc_source = self.es.get_source(index=INDEX_NAME, doc_type=DOC_TYPE, id=document_id) self.assertEqual(target_doc_source, BODY) + + def test_mget_get_several_documents_by_id(self): + ids = [] + for _ in range(0, 10): + data = self.es.index(index=INDEX_NAME, doc_type=DOC_TYPE, body=BODY) + ids.append(data.get('_id')) + results = self.es.mget(index=INDEX_NAME, body={'ids': ids}) + self.assertEqual(len(results['docs']), 10) diff --git a/tests/fake_elasticsearch/test_search.py b/tests/fake_elasticsearch/test_search.py index 98eb9b2..7b71070 100644 --- a/tests/fake_elasticsearch/test_search.py +++ b/tests/fake_elasticsearch/test_search.py @@ -205,7 +205,7 @@ def test_search_bool_should_match_query(self): } } }) - self.assertEqual(response['hits']['total'], 3) + self.assertEqual(response['hits']['total']['value'], 3) hits = response['hits']['hits'] self.assertEqual(len(hits), 3) self.assertEqual(hits[0]['_source'], {'data': 'test_0'})