diff --git a/elasticmock/fake_elasticsearch.py b/elasticmock/fake_elasticsearch.py index ed2c4d3..284e2e6 100644 --- a/elasticmock/fake_elasticsearch.py +++ b/elasticmock/fake_elasticsearch.py @@ -7,7 +7,7 @@ from elasticsearch.client.utils import query_params from elasticsearch.exceptions import NotFoundError -from elasticmock.utilities import get_random_id +from elasticmock.utilities import get_random_id, get_random_scroll_id PY3 = sys.version_info[0] == 3 @@ -20,6 +20,7 @@ class FakeElasticsearch(Elasticsearch): def __init__(self, hosts=None, transport_class=None, **kwargs): self.__documents_dict = {} + self.__scrolls = {} @query_params() def ping(self, params=None): @@ -183,10 +184,33 @@ def search(self, index=None, doc_type=None, body=None, params=None): for match in matches: match['_score'] = 1.0 hits.append(match) - result['hits']['hits'] = hits + if 'scroll' in params: + result['_scroll_id'] = str(get_random_scroll_id()) + params['size'] = int(params.get('size') if 'size' in params else 10) + params['from'] = int(params.get('from') + params.get('size') if 'from' in params else 0) + self.__scrolls[result.get('_scroll_id')] = { + 'index' : index, + 'doc_type' : doc_type, + 'body' : body, + 'params' : params + } + hits = hits[params.get('from'):params.get('from') + params.get('size')] + + result['hits']['hits'] = hits return result + @query_params('scroll') + def scroll(self, scroll_id, params=None): + scroll = self.__scrolls.pop(scroll_id) + result = self.search( + index = scroll.get('index'), + doc_type = scroll.get('doc_type'), + body = scroll.get('body'), + params = scroll.get('params') + ) + return result + @query_params('consistency', 'parent', 'refresh', 'replication', 'routing', 'timeout', 'version', 'version_type') def delete(self, index, doc_type, id, params=None): diff --git a/elasticmock/utilities/__init__.py b/elasticmock/utilities/__init__.py index 6a7b543..27f129e 100644 --- a/elasticmock/utilities/__init__.py +++ b/elasticmock/utilities/__init__.py @@ -2,10 +2,15 @@ import random import string +import base64 DEFAULT_ELASTICSEARCH_ID_SIZE = 20 CHARSET_FOR_ELASTICSEARCH_ID = string.ascii_letters + string.digits +DEFAULT_ELASTICSEARCH_SEARCHRESULTPHASE_COUNT = 6 def get_random_id(size=DEFAULT_ELASTICSEARCH_ID_SIZE): return ''.join(random.choice(CHARSET_FOR_ELASTICSEARCH_ID) for _ in range(size)) + +def get_random_scroll_id(size=DEFAULT_ELASTICSEARCH_SEARCHRESULTPHASE_COUNT): + return base64.b64encode(''.join(get_random_id() for _ in range(size)).encode()) diff --git a/tests/test_elasticmock.py b/tests/test_elasticmock.py index c668433..7840b15 100644 --- a/tests/test_elasticmock.py +++ b/tests/test_elasticmock.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import unittest - import elasticsearch from elasticsearch.exceptions import NotFoundError @@ -242,6 +241,34 @@ def test_doc_type_can_be_list(self): result = self.es.search(doc_type=doc_types[:2]) self.assertEqual(count_per_doc_type * 2, result.get('hits').get('total')) + def test_search_with_scroll_param(self): + for _ in range(100): + self.es.index(index='groups', doc_type='groups', body={'budget': 1000}) + + result = self.es.search(index='groups', params={'scroll' : '1m', 'size' : 30}) + self.assertNotEqual(None, result.get('_scroll_id', None)) + self.assertEqual(30, len(result.get('hits').get('hits'))) + self.assertEqual(100, result.get('hits').get('total')) + + def test_scrolling(self): + for _ in range(100): + self.es.index(index='groups', doc_type='groups', body={'budget': 1000}) + + result = self.es.search(index='groups', params={'scroll' : '1m', 'size' : 30}) + self.assertNotEqual(None, result.get('_scroll_id', None)) + self.assertEqual(30, len(result.get('hits').get('hits'))) + self.assertEqual(100, result.get('hits').get('total')) + + for _ in range(2): + result = self.es.scroll(scroll_id = result.get('_scroll_id'), scroll = '1m') + self.assertNotEqual(None, result.get('_scroll_id', None)) + self.assertEqual(30, len(result.get('hits').get('hits'))) + self.assertEqual(100, result.get('hits').get('total')) + + result = self.es.scroll(scroll_id = result.get('_scroll_id'), scroll = '1m') + self.assertNotEqual(None, result.get('_scroll_id', None)) + self.assertEqual(10, len(result.get('hits').get('hits'))) + self.assertEqual(100, result.get('hits').get('total')) if __name__ == '__main__': unittest.main()