From 889299817004e2d23f609e3c37612e11ffc9890f Mon Sep 17 00:00:00 2001 From: Joscha Feth Date: Thu, 21 Nov 2024 01:49:08 +0000 Subject: [PATCH] feat(rest_api): custom client for specific resources --- dlt/sources/rest_api/__init__.py | 29 +++++++++++++----- dlt/sources/rest_api/typing.py | 9 ++++-- .../verified-sources/rest_api/basic.md | 30 +++++++++++++++++++ .../rest_api/configurations/source_configs.py | 13 ++++++++ .../integration/test_response_actions.py | 27 +++++++++++++++++ 5 files changed, 98 insertions(+), 10 deletions(-) diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index ed55f71e10..a634dee0e8 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -263,12 +263,17 @@ def create_resources( incremental_cursor_transform, ) = setup_incremental_object(request_params, endpoint_config.get("incremental")) + merged_client_config: ClientConfig = { + **client_config, + **endpoint_resource.pop("client", {}), + } + client = RESTClient( - base_url=client_config["base_url"], - headers=client_config.get("headers"), - auth=create_auth(client_config.get("auth")), - paginator=create_paginator(client_config.get("paginator")), - session=client_config.get("session"), + base_url=merged_client_config["base_url"], + headers=merged_client_config.get("headers"), + auth=create_auth(merged_client_config.get("auth")), + paginator=create_paginator(merged_client_config.get("paginator")), + session=merged_client_config.get("session"), ) hooks = create_response_hooks(endpoint_config.get("response_actions")) @@ -405,14 +410,22 @@ def paginate_dependent_resource( def _validate_config(config: RESTAPIConfig) -> None: c = deepcopy(config) - client_config = c.get("client") + _mask_client_config_auth(c.get("client")) + resources = c.get("resources") + if resources: + for resource in resources: + if isinstance(resource, str) or isinstance(resource, DltResource): + continue + _mask_client_config_auth(resource.get("client")) + validate_dict(RESTAPIConfig, c, path=".") + + +def _mask_client_config_auth(client_config: Optional[ClientConfig]) -> None: if client_config: auth = client_config.get("auth") if auth: auth = _mask_secrets(auth) - validate_dict(RESTAPIConfig, c, path=".") - def _mask_secrets(auth_config: AuthConfig) -> AuthConfig: # skip AuthBase (derived from requests lib) or shorthand notation diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py index d4cea892a3..d6516289d8 100644 --- a/dlt/sources/rest_api/typing.py +++ b/dlt/sources/rest_api/typing.py @@ -200,14 +200,18 @@ class OAuth2ClientCredentialsConfig(AuthTypeConfig, total=False): ] -class ClientConfig(TypedDict, total=False): - base_url: str +class BaseClientConfig(TypedDict, total=False): + base_url: Optional[str] headers: Optional[Dict[str, str]] auth: Optional[AuthConfig] paginator: Optional[PaginatorConfig] session: Optional[Session] +class ClientConfig(BaseClientConfig, total=False): + base_url: str # type: ignore[misc] + + class IncrementalRESTArgs(IncrementalArgs, total=False): convert: Optional[Callable[..., Any]] @@ -279,6 +283,7 @@ class ResourceBase(TResourceHintsBase, total=False): selected: Optional[bool] parallelized: Optional[bool] processing_steps: Optional[List[ProcessingSteps]] + client: Optional[BaseClientConfig] class EndpointResourceBase(ResourceBase, total=False): diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md index d23f3f139e..5666430ab3 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md @@ -308,6 +308,7 @@ A resource configuration is used to define a [dlt resource](../../../general-usa - `include_from_parent`: A list of fields from the parent resource to be included in the resource output. See the [resource relationships](#include-fields-from-the-parent-resource) section for more details. - `processing_steps`: A list of [processing steps](#processing-steps-filter-and-transform-data) to filter and transform the data. - `selected`: A flag to indicate if the resource is selected for loading. This could be useful when you want to load data only from child resources and not from the parent resource. +- `client`: An optional `ClientConfig`. A config passed here is merged with the one given in the [dlt resource](../../../general-usage/resource.md) definition. You can also pass additional resource parameters that will be used to configure the dlt resource. See [dlt resource API reference](../../../api_reference/extract/decorators#resource) for more details. @@ -549,6 +550,35 @@ config = { } ``` +You can also overwrite the client configuration for a specific endpoint only. +For example to change the auth method used for the resource endpoint. +The configurations will be merged. + +```py +from dlt.sources.helpers.rest_client.auth import HttpBasicAuth + +config = { + "client": { + "auth": { + "type": "bearer", + "token": dlt.secrets["your_api_token"], + } + }, + "resources": [ + "resource-using-bearer-auth", + { + "name": "my-resource-with-special-auth", + "client": { + "auth": HttpBasicAuth("user", dlt.secrets["your_basic_auth_password"]) + }, + # ... + } + ] + # ... +} +``` +This would use `Bearer` auth as defined in the `client` for `resource-using-bearer-auth` and `Http Basic` auth for `my-resource-with-special-auth`. + :::warning Make sure to store your access tokens and other sensitive information in the `secrets.toml` file and never commit it to the version control system. ::: diff --git a/tests/sources/rest_api/configurations/source_configs.py b/tests/sources/rest_api/configurations/source_configs.py index 705a42637c..d05dd6d803 100644 --- a/tests/sources/rest_api/configurations/source_configs.py +++ b/tests/sources/rest_api/configurations/source_configs.py @@ -395,6 +395,19 @@ def repositories(): repositories(), ], }, + { + "client": { + "base_url": "https://test", + }, + "resources": [ + { + "name": "test", + "client": { + "auth": HttpBasicAuth("", "BASIC_AUTH_TOKEN"), + }, + } + ], + }, ] diff --git a/tests/sources/rest_api/integration/test_response_actions.py b/tests/sources/rest_api/integration/test_response_actions.py index 1ec8058a86..79a55903ec 100644 --- a/tests/sources/rest_api/integration/test_response_actions.py +++ b/tests/sources/rest_api/integration/test_response_actions.py @@ -316,3 +316,30 @@ def add_field(response: Response, *args, **kwargs) -> Response: mock_response_hook_2.assert_called_once() assert all(record["custom_field"] == "foobar" for record in data) + + +def test_auth_overwrites_for_specific_endpoints(mock_api_server, mocker): + def custom_hook(response: Response, *args, **kwargs) -> Response: + assert response.request.headers["my_header"] == "overwritten" + return response + + mock_response_hook = mocker.Mock(side_effect=custom_hook) + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com", "headers": {"my_header": "original"}}, + "resources": [ + { + "name": "posts", + "client": {"headers": {"my_header": "overwritten"}}, + "endpoint": { + "response_actions": [ + mock_response_hook, + ], + }, + }, + ], + } + ) + + list(mock_source.with_resources("posts").add_limit(1)) + mock_response_hook.assert_called_once()