Skip to content

Commit

Permalink
refactor(cursor_pagination): encapsulate encode/decode cursor logic
Browse files Browse the repository at this point in the history
  • Loading branch information
gromsterus committed Dec 27, 2023
1 parent 0c425b2 commit ce1e8e9
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions src/paginate_any/cursor_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ def _get_cursor_value(self, row: RowT, cursor: CurrentCursor) -> str:
logger.error(msg)
raise CursorValueErr(msg)
cursor_values.append(value)
return self._encode_cursor(cursor_values)

values = _cursor_encode(cursor_values)
return urlsafe_b64encode(values).decode('utf-8')
@staticmethod
def _encode_cursor(cursor_values: list[Any]) -> str:
return urlsafe_b64encode(_cursor_encode(cursor_values)).decode('utf-8')

def _get_field_val(self, row: RowT, field: str) -> Any:
err: Exception | None
Expand Down Expand Up @@ -196,29 +198,26 @@ def _get_after_and_before(
if after_raw is not None and before_raw is not None:
raise MultipleCursorsErr()

before = self._decode_cursor(before_raw)
after = self._decode_cursor(after_raw)
value = after or before
if value is None:
if after_raw:
cursor_values = self._decode_cursor(after_raw)
elif before_raw:
cursor_values = self._decode_cursor(before_raw)
else:
return None, None

try:
cursor_values = _cursor_decode(value)
except _cursor_decode_err as exc:
msg = 'Invalid cursor value'
raise CursorValueErr(detail=msg) from exc
if len(cursor_values) != len(sort_fields):
raise CursorValueErr()
return (cursor_values, None) if after else (None, cursor_values)
return (cursor_values, None) if after_raw else (None, cursor_values)

@staticmethod
def _decode_cursor(s: CursorRawT) -> bytes | None:
if s is None:
return None
def _decode_cursor(s: str | bytes) -> CursorValuesT:
try:
return urlsafe_b64decode(s)
return _cursor_decode(urlsafe_b64decode(s))
except binascii.Error as exc:
raise CursorValueErr(detail='Invalid base64 value') from exc
except _cursor_decode_err as exc:
msg = 'Invalid cursor value'
raise CursorValueErr(detail=msg) from exc

async def _get_rows(
self,
Expand Down

0 comments on commit ce1e8e9

Please sign in to comment.