Skip to content

Commit

Permalink
PYTHON-4584 Add length option to Cursor.to_list for motor compat (#1791)
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored Aug 14, 2024
1 parent f2f75fc commit adf8817
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 28 deletions.
12 changes: 10 additions & 2 deletions gridfs/asynchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,8 +1892,16 @@ async def next(self) -> AsyncGridOut:
next_file = await super().next()
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)

async def to_list(self) -> list[AsyncGridOut]:
return [x async for x in self] # noqa: C416,RUF100
async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
    """Convert the cursor to a list of files.

    :param length: Maximum number of files to return. When ``None`` (the
        default) the cursor is drained completely.
    :return: The files read from the cursor. The list is shorter than
        ``length`` (possibly empty) when the cursor runs out first.
    :raises ValueError: If ``length`` is given and is less than 1.
    """
    if length is None:
        return [x async for x in self]  # noqa: C416,RUF100
    if length < 1:
        raise ValueError("to_list() length must be greater than 0")
    ret = []
    for _ in range(length):
        try:
            ret.append(await self.next())
        except StopAsyncIteration:
            # Fewer than ``length`` files remain: return what was collected
            # instead of leaking the iteration-protocol exception to the caller.
            break
    return ret

__anext__ = next

Expand Down
12 changes: 10 additions & 2 deletions gridfs/synchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,8 +1878,16 @@ def next(self) -> GridOut:
next_file = super().next()
return GridOut(self._root_collection, file_document=next_file, session=self.session)

def to_list(self) -> list[GridOut]:
return [x for x in self] # noqa: C416,RUF100
def to_list(self, length: Optional[int] = None) -> list[GridOut]:
    """Convert the cursor to a list of files.

    :param length: Maximum number of files to return. When ``None`` (the
        default) the cursor is drained completely.
    :return: The files read from the cursor. The list is shorter than
        ``length`` (possibly empty) when the cursor runs out first.
    :raises ValueError: If ``length`` is given and is less than 1.
    """
    if length is None:
        return [x for x in self]  # noqa: C416,RUF100
    if length < 1:
        raise ValueError("to_list() length must be greater than 0")
    ret = []
    for _ in range(length):
        try:
            ret.append(self.next())
        except StopIteration:
            # Fewer than ``length`` files remain: return what was collected
            # instead of leaking the iteration-protocol exception to the caller.
            break
    return ret

__next__ = next

Expand Down
27 changes: 21 additions & 6 deletions pymongo/asynchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,17 @@ async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
else:
return None

async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
await self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
Expand Down Expand Up @@ -381,21 +385,32 @@ async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()

async def to_list(self) -> list[_DocumentType]:
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res):
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res


Expand Down
27 changes: 21 additions & 6 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,16 +1260,20 @@ async def next(self) -> _DocumentType:
else:
raise StopAsyncIteration

async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
await self._supports_exhaust()
if self._empty:
return False
if len(self._data) or await self._refresh():
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
Expand All @@ -1286,21 +1290,32 @@ async def __aenter__(self) -> AsyncCursor[_DocumentType]:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()

async def to_list(self) -> list[_DocumentType]:
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res):
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res


Expand Down
27 changes: 21 additions & 6 deletions pymongo/synchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,17 @@ def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
else:
return None

def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
Expand Down Expand Up @@ -381,21 +385,32 @@ def __enter__(self) -> CommandCursor[_DocumentType]:
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()

def to_list(self) -> list[_DocumentType]:
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
To use::
>>> cursor.to_list()
Or, to read at most n items from the cursor::
>>> cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not self._next_batch(res):
if not self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res


Expand Down
27 changes: 21 additions & 6 deletions pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,16 +1258,20 @@ def next(self) -> _DocumentType:
else:
raise StopIteration

def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
self._supports_exhaust()
if self._empty:
return False
if len(self._data) or self._refresh():
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
Expand All @@ -1284,21 +1288,32 @@ def __enter__(self) -> Cursor[_DocumentType]:
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()

def to_list(self) -> list[_DocumentType]:
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
To use::
>>> cursor.to_list()
Or, to read at most n items from the cursor::
>>> cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not self._next_batch(res):
if not self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res


Expand Down
27 changes: 27 additions & 0 deletions test/asynchronous/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,20 @@ async def test_to_list_empty(self):
docs = await c.to_list()
self.assertEqual([], docs)

async def test_to_list_length(self):
    """``to_list(n)`` returns at most ``n`` documents and can be called again on the same cursor."""
    collection = self.db.test
    await collection.insert_many([{} for _ in range(5)])
    self.addCleanup(collection.drop)

    # Ask for fewer documents than the collection holds.
    cursor = collection.find()
    self.assertEqual(len(await cursor.to_list(3)), 3)

    # With a batch size smaller than the requested length, the same cursor
    # hands out three documents first and the remaining two afterwards.
    cursor = collection.find(batch_size=2)
    first_chunk = await cursor.to_list(3)
    self.assertEqual(len(first_chunk), 3)
    second_chunk = await cursor.to_list(3)
    self.assertEqual(len(second_chunk), 2)

@async_client_context.require_change_streams
async def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
Expand All @@ -1417,6 +1431,19 @@ async def test_command_cursor_to_list_empty(self):
docs = await c.to_list()
self.assertEqual([], docs)

@async_client_context.require_change_streams
async def test_command_cursor_to_list_length(self):
    """A command cursor honours the optional ``length`` argument of ``to_list``."""
    db = self.db
    await db.drop_collection("test")
    await db.test.insert_many([{"foo": 1}, {"foo": 2}])

    project_stage = {"$project": {"_id": False, "foo": True}}

    # Without a length, every document produced by the pipeline comes back.
    unbounded = await db.test.aggregate([project_stage])
    everything = await unbounded.to_list()
    self.assertEqual(len(everything), 2)

    # With a length, the result is capped at that many documents.
    bounded = await db.test.aggregate([project_stage])
    only_one = await bounded.to_list(1)
    self.assertEqual(len(only_one), 1)


class TestRawBatchCursor(AsyncIntegrationTest):
async def test_find_raw(self):
Expand Down
27 changes: 27 additions & 0 deletions test/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,20 @@ def test_to_list_empty(self):
docs = c.to_list()
self.assertEqual([], docs)

def test_to_list_length(self):
    """``to_list(n)`` returns at most ``n`` documents and can be called again on the same cursor."""
    collection = self.db.test
    collection.insert_many([{} for _ in range(5)])
    self.addCleanup(collection.drop)

    # Ask for fewer documents than the collection holds.
    cursor = collection.find()
    self.assertEqual(len(cursor.to_list(3)), 3)

    # With a batch size smaller than the requested length, the same cursor
    # hands out three documents first and the remaining two afterwards.
    cursor = collection.find(batch_size=2)
    first_chunk = cursor.to_list(3)
    self.assertEqual(len(first_chunk), 3)
    second_chunk = cursor.to_list(3)
    self.assertEqual(len(second_chunk), 2)

@client_context.require_change_streams
def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
Expand All @@ -1408,6 +1422,19 @@ def test_command_cursor_to_list_empty(self):
docs = c.to_list()
self.assertEqual([], docs)

@client_context.require_change_streams
def test_command_cursor_to_list_length(self):
    """A command cursor honours the optional ``length`` argument of ``to_list``."""
    db = self.db
    db.drop_collection("test")
    db.test.insert_many([{"foo": 1}, {"foo": 2}])

    project_stage = {"$project": {"_id": False, "foo": True}}

    # Without a length, every document produced by the pipeline comes back.
    unbounded = db.test.aggregate([project_stage])
    everything = unbounded.to_list()
    self.assertEqual(len(everything), 2)

    # With a length, the result is capped at that many documents.
    bounded = db.test.aggregate([project_stage])
    only_one = bounded.to_list(1)
    self.assertEqual(len(only_one), 1)


class TestRawBatchCursor(IntegrationTest):
def test_find_raw(self):
Expand Down
6 changes: 6 additions & 0 deletions test/test_gridfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ def test_gridfs_find(self):
gout = next(cursor)
self.assertEqual(b"test2+", gout.read())
self.assertRaises(StopIteration, cursor.__next__)
cursor.rewind()
items = cursor.to_list()
self.assertEqual(len(items), 2)
cursor.rewind()
items = cursor.to_list(1)
self.assertEqual(len(items), 1)
cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})

Expand Down

0 comments on commit adf8817

Please sign in to comment.