Skip to content

Commit

Permalink
feat(rest_api): custom client for specific resources
Browse files Browse the repository at this point in the history
  • Loading branch information
joscha committed Nov 21, 2024
1 parent 9a49868 commit 8892998
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 10 deletions.
29 changes: 21 additions & 8 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,17 @@ def create_resources(
incremental_cursor_transform,
) = setup_incremental_object(request_params, endpoint_config.get("incremental"))

merged_client_config: ClientConfig = {
**client_config,
**endpoint_resource.pop("client", {}),
}

client = RESTClient(
base_url=client_config["base_url"],
headers=client_config.get("headers"),
auth=create_auth(client_config.get("auth")),
paginator=create_paginator(client_config.get("paginator")),
session=client_config.get("session"),
base_url=merged_client_config["base_url"],
headers=merged_client_config.get("headers"),
auth=create_auth(merged_client_config.get("auth")),
paginator=create_paginator(merged_client_config.get("paginator")),
session=merged_client_config.get("session"),
)

hooks = create_response_hooks(endpoint_config.get("response_actions"))
Expand Down Expand Up @@ -405,14 +410,22 @@ def paginate_dependent_resource(

def _validate_config(config: RESTAPIConfig) -> None:
    """Validate ``config`` against the ``RESTAPIConfig`` schema.

    Validation runs on a deep copy, so the caller's configuration is never
    modified. Before ``validate_dict`` is called, the top-level ``client``
    config and the ``client`` config of every dict-style resource are passed
    through ``_mask_client_config_auth``.

    Args:
        config: The REST API source configuration to validate.

    Raises:
        Whatever ``validate_dict`` raises for a configuration that does not
        match ``RESTAPIConfig``.
    """
    c = deepcopy(config)
    _mask_client_config_auth(c.get("client"))
    # ``resources`` may be missing or empty; iterate over nothing in that case.
    for resource in c.get("resources") or ():
        # Resources given as a plain name or as a ready-made DltResource are
        # not dicts and therefore have no per-resource ``client`` entry.
        if isinstance(resource, (str, DltResource)):
            continue
        _mask_client_config_auth(resource.get("client"))
    validate_dict(RESTAPIConfig, c, path=".")


def _mask_client_config_auth(client_config: Optional[ClientConfig]) -> None:
    """Replace ``client_config["auth"]`` with its secret-masked form.

    Does nothing when ``client_config`` is ``None``/empty or has no truthy
    ``auth`` entry. The dict is modified in place, so callers should pass a
    copy they own (``_validate_config`` passes parts of a deep copy).

    Args:
        client_config: A client configuration, or ``None``.
    """
    if not client_config:
        return
    auth = client_config.get("auth")
    if auth:
        # Store the result back instead of rebinding a local name: if
        # ``_mask_secrets`` returns a new object rather than mutating its
        # argument, a bare ``auth = _mask_secrets(auth)`` would discard the
        # masked value. When it masks in place and returns the same object,
        # this assignment is a harmless no-op.
        client_config["auth"] = _mask_secrets(auth)

validate_dict(RESTAPIConfig, c, path=".")


def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
# skip AuthBase (derived from requests lib) or shorthand notation
Expand Down
9 changes: 7 additions & 2 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,18 @@ class OAuth2ClientCredentialsConfig(AuthTypeConfig, total=False):
]


class ClientConfig(TypedDict, total=False):
base_url: str
# Client settings in which every key is optional (total=False).
# Used as the type of the per-resource `client` entry on ResourceBase.
class BaseClientConfig(TypedDict, total=False):
# Passed to RESTClient as `base_url`.
base_url: Optional[str]
# Passed to RESTClient as `headers`.
headers: Optional[Dict[str, str]]
# Converted with create_auth() before being passed to RESTClient as `auth`.
auth: Optional[AuthConfig]
# Converted with create_paginator() before being passed to RESTClient as `paginator`.
paginator: Optional[PaginatorConfig]
# Passed to RESTClient as `session`.
session: Optional[Session]


# Client configuration that narrows `base_url` from Optional[str] to str;
# all other keys are inherited from BaseClientConfig.
class ClientConfig(BaseClientConfig, total=False):
# Redeclares the inherited key with a narrower type; the ignore presumably
# silences the type checker's complaint about overriding a TypedDict field.
base_url: str # type: ignore[misc]


class IncrementalRESTArgs(IncrementalArgs, total=False):
convert: Optional[Callable[..., Any]]

Expand Down Expand Up @@ -279,6 +283,7 @@ class ResourceBase(TResourceHintsBase, total=False):
selected: Optional[bool]
parallelized: Optional[bool]
processing_steps: Optional[List[ProcessingSteps]]
client: Optional[BaseClientConfig]


class EndpointResourceBase(ResourceBase, total=False):
Expand Down
30 changes: 30 additions & 0 deletions docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ A resource configuration is used to define a [dlt resource](../../../general-usa
- `include_from_parent`: A list of fields from the parent resource to be included in the resource output. See the [resource relationships](#include-fields-from-the-parent-resource) section for more details.
- `processing_steps`: A list of [processing steps](#processing-steps-filter-and-transform-data) to filter and transform the data.
- `selected`: A flag to indicate if the resource is selected for loading. This could be useful when you want to load data only from child resources and not from the parent resource.
- `client`: An optional `ClientConfig` that applies to this [dlt resource](../../../general-usage/resource.md) only. It is merged with the top-level `client` configuration; values given here take precedence over the top-level ones.

You can also pass additional resource parameters that will be used to configure the dlt resource. See [dlt resource API reference](../../../api_reference/extract/decorators#resource) for more details.

Expand Down Expand Up @@ -549,6 +550,35 @@ config = {
}
```

You can also override the client configuration for a specific resource only,
for example, to change the auth method used for that resource's endpoint.
The resource-level configuration is merged with the top-level `client` configuration, and resource-level values take precedence.

```py
from dlt.sources.helpers.rest_client.auth import HttpBasicAuth

config = {
"client": {
"auth": {
"type": "bearer",
"token": dlt.secrets["your_api_token"],
}
},
"resources": [
"resource-using-bearer-auth",
{
"name": "my-resource-with-special-auth",
"client": {
"auth": HttpBasicAuth("user", dlt.secrets["your_basic_auth_password"])
},
# ...
}
]
# ...
}
```
This uses `Bearer` auth, as defined in the top-level `client`, for `resource-using-bearer-auth`, and HTTP Basic auth for `my-resource-with-special-auth`.

:::warning
Make sure to store your access tokens and other sensitive information in the `secrets.toml` file and never commit it to the version control system.
:::
Expand Down
13 changes: 13 additions & 0 deletions tests/sources/rest_api/configurations/source_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,19 @@ def repositories():
repositories(),
],
},
{
"client": {
"base_url": "https://test",
},
"resources": [
{
"name": "test",
"client": {
"auth": HttpBasicAuth("", "BASIC_AUTH_TOKEN"),
},
}
],
},
]


Expand Down
27 changes: 27 additions & 0 deletions tests/sources/rest_api/integration/test_response_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,30 @@ def add_field(response: Response, *args, **kwargs) -> Response:
mock_response_hook_2.assert_called_once()

assert all(record["custom_field"] == "foobar" for record in data)


def test_auth_overwrites_for_specific_endpoints(mock_api_server, mocker):
    """A resource-level ``client`` config overrides the top-level one.

    The top-level client sets ``my_header`` to ``"original"`` and the
    ``posts`` resource sets it to ``"overwritten"``. A response hook checks
    the header on the outgoing request, and the test confirms the hook ran
    exactly once.
    """

    def custom_hook(response: Response, *args, **kwargs) -> Response:
        # The request must carry the resource-level header value.
        assert response.request.headers["my_header"] == "overwritten"
        return response

    mock_response_hook = mocker.Mock(side_effect=custom_hook)

    posts_resource = {
        "name": "posts",
        "client": {"headers": {"my_header": "overwritten"}},
        "endpoint": {
            "response_actions": [
                mock_response_hook,
            ],
        },
    }
    source_config = {
        "client": {
            "base_url": "https://api.example.com",
            "headers": {"my_header": "original"},
        },
        "resources": [posts_resource],
    }
    mock_source = rest_api_source(source_config)

    # Pull a single page so the hook fires exactly once.
    list(mock_source.with_resources("posts").add_limit(1))
    mock_response_hook.assert_called_once()

0 comments on commit 8892998

Please sign in to comment.