diff --git a/dataprep/connector/connector.py b/dataprep/connector/connector.py index b909cb4a7..c8ccee6a9 100644 --- a/dataprep/connector/connector.py +++ b/dataprep/connector/connector.py @@ -6,8 +6,9 @@ import sys from asyncio import as_completed from pathlib import Path -from typing import Any, Awaitable, Dict, List, Optional, Union, cast - +from typing import Any, Awaitable, Dict, List, Optional, Union, Tuple +from aiohttp.client_reqrep import ClientResponse +from jsonpath_ng import parse as jparse import pandas as pd from aiohttp import ClientSession from jinja2 import Environment, StrictUndefined, Template, UndefinedError @@ -16,7 +17,15 @@ from .errors import InvalidParameterError, RequestError, UniversalParameterOverridden from .implicit_database import ImplicitDatabase, ImplicitTable from .ref import Ref -from .schema import ConfigDef, FieldDefUnion +from .schema import ( + ConfigDef, + FieldDefUnion, + OffsetPaginationDef, + SeekPaginationDef, + PagePaginationDef, + TokenPaginationDef, + TokenLocation, +) from .throttler import OrderedThrottler, ThrottleSession INFO_TEMPLATE = Template( @@ -98,6 +107,7 @@ def __init__( async def query( # pylint: disable=too-many-locals self, table: str, + *, _auth: Optional[Dict[str, Any]] = None, _count: Optional[int] = None, **where: Any, @@ -199,7 +209,7 @@ def show_schema(self, table_name: str) -> pd.DataFrame: new_schema_dict["data_type"].append(schema[k].type) return pd.DataFrame.from_dict(new_schema_dict) - async def _query_imp( # pylint: disable=too-many-locals,too-many-branches + async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, table: str, kwargs: Dict[str, Any], @@ -239,7 +249,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches total = _count n_page = math.ceil(total / max_per_page) - if pagdef.type == "seek": + if isinstance(pagdef, SeekPaginationDef): last_id = 0 dfs = [] # No way to parallelize for seek type @@ -251,6 +261,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches kwargs, _client=client, _throttler=throttler, + _page=i, _auth=_auth, _limit=count, _anchor=last_id - 1, @@ -263,10 +274,36 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches # The API returns empty for this page, maybe we've reached the end break - last_id = int(df[pagdef.seek_id][len(df) - 1]) - 1 + last_id = int(df.iloc[-1, df.columns.get_loc(pagdef.seek_id)]) - 1 # type: ignore dfs.append(df) + elif isinstance(pagdef, TokenPaginationDef): + next_token = None + dfs = [] + # No way to parallelize for seek type + for i in range(n_page): + count = min(total - i * max_per_page, max_per_page) + df, resp = await self._fetch( # type: ignore + itable, + kwargs, + _client=client, + _throttler=throttler, + _page=i, + _auth=_auth, + _limit=count, + _anchor=next_token, + _raw=True, + ) - elif pagdef.type in {"offset", "page"}: + if pagdef.token_location == TokenLocation.Header: + next_token = resp.headers[pagdef.token_accessor] + elif pagdef.token_location == TokenLocation.Body: + # only json body implemented + token_expr = jparse(pagdef.token_accessor) + (token_elem,) = token_expr.find(await resp.json()) + next_token = token_elem.value + + dfs.append(df) + elif isinstance(pagdef, (OffsetPaginationDef, PagePaginationDef)): resps_coros = [] allowed_page = Ref(n_page) for i in range(n_page): @@ -314,11 +351,10 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many- _page: int = 0, _allowed_page: Optional[Ref[int]] = None, _limit: Optional[int] = None, - _anchor: Optional[int] = None, + _anchor: Optional[Any] = None, _auth: Optional[Dict[str, Any]] = None, - ) -> Optional[pd.DataFrame]: - if (_limit is None) != (_anchor is None): - raise ValueError("_limit and _offset should both be None or not None") + _raw: bool = False, + ) -> Union[Optional[pd.DataFrame], Tuple[Optional[pd.DataFrame], ClientResponse]]: reqdef = table.config.request method = reqdef.method @@ -353,17 +389,18 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many- if reqdef.pagination is not None and _limit is not None: pagdef = reqdef.pagination - pag_type = pagdef.type limit_key = pagdef.limit_key - if pag_type == "seek": - anchor = cast(str, pagdef.seek_key) - elif pag_type == "offset": - anchor = cast(str, pagdef.offset_key) - elif pag_type == "page": - anchor = cast(str, pagdef.page_key) + if isinstance(pagdef, SeekPaginationDef): + anchor = pagdef.seek_key + elif isinstance(pagdef, OffsetPaginationDef): + anchor = pagdef.offset_key + elif isinstance(pagdef, PagePaginationDef): + anchor = pagdef.page_key + elif isinstance(pagdef, TokenPaginationDef): + anchor = pagdef.token_key else: - raise ValueError(f"Unknown pagination type {pag_type}.") + raise ValueError(f"Unknown pagination type {pagdef.type}.") if limit_key in req_data["params"]: raise UniversalParameterOverridden(limit_key, "_limit") @@ -371,7 +408,9 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many- if anchor in req_data["params"]: raise UniversalParameterOverridden(anchor, "_offset") - req_data["params"][anchor] = _anchor + + if _anchor is not None: + req_data["params"][anchor] = _anchor await _throttler.acquire(_page) @@ -396,7 +435,10 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many- if len(df) == 0 and _allowed_page is not None and _page is not None: _allowed_page.set(_page) - return None + df = None + + if _raw: + return df, resp else: return df diff --git a/dataprep/connector/implicit_database.py b/dataprep/connector/implicit_database.py index d9c2afc37..52f77fcf7 100644 --- a/dataprep/connector/implicit_database.py +++ b/dataprep/connector/implicit_database.py @@ -23,6 +23,7 @@ "string": str, "float": float, "boolean": bool, + "list": list, } diff --git a/dataprep/connector/ref.py b/dataprep/connector/ref.py index f359f216b..a5b725967 100644 --- a/dataprep/connector/ref.py +++ b/dataprep/connector/ref.py @@ -2,7 +2,7 @@ from typing import TypeVar, Generic -T = TypeVar("T") +T = TypeVar("T") # pylint: disable=invalid-name class Ref(Generic[T]): @@ -16,7 +16,7 @@ def __init__(self, val: T) -> None: self.val = val def __int__(self) -> int: - return int(self.val) + return int(self.val) # type: ignore def __bool__(self) -> bool: return bool(self.val) diff --git a/dataprep/connector/schema/defs.py b/dataprep/connector/schema/defs.py index 8d7e8c732..4f1e2ed83 100644 --- a/dataprep/connector/schema/defs.py +++ b/dataprep/connector/schema/defs.py @@ -7,31 +7,53 @@ from typing import Any, Dict, Optional, Union import requests -from pydantic import Field, root_validator +from pydantic import Field from .base import BaseDef, BaseDefT # pylint: disable=missing-class-docstring,missing-function-docstring -class PaginationDef(BaseDef): - type: str = Field(regex=r"^(offset|seek|page)$") + + +class OffsetPaginationDef(BaseDef): + type: str = Field("offset", const=True) + max_count: int + limit_key: str + offset_key: str + + +class SeekPaginationDef(BaseDef): + type: str = Field("seek", const=True) + max_count: int + limit_key: str + seek_id: str + seek_key: str + + +class PagePaginationDef(BaseDef): + type: str = Field("page", const=True) max_count: int - offset_key: Optional[str] limit_key: str - seek_id: Optional[str] - seek_key: Optional[str] - page_key: Optional[str] - - @root_validator(pre=True) - def check_key_provided(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values["type"] == "offset" and "offsetKey" not in values: - raise ValueError("Pagination type is 'offset' but no offsetKey set.") - elif values["type"] == "seek" and "seekKey" not in values: - raise ValueError("Pagination type is seek but no seekKey set.") - elif values["type"] == "page" and "pageKey" not in values: - raise ValueError("Pagination type is page but no pageKey set.") - else: - return values + page_key: str + + +class TokenLocation(str, Enum): + Header = "header" + Body = "body" + + +class TokenPaginationDef(BaseDef): + type: str = Field("token", const=True) + max_count: int + limit_key: str + token_location: TokenLocation + token_accessor: str + token_key: str + + +PaginationDef = Union[ + OffsetPaginationDef, SeekPaginationDef, PagePaginationDef, TokenPaginationDef +] class FieldDef(BaseDef):