Skip to content

Commit

Permalink
Prefer httpx.Auth Instead of api_token
Browse files Browse the repository at this point in the history
It is still possible to use the api_token parameter.
In this case assume an API token was copied from the
user profile of a Dataverse instance.

If both are specified, an api_token and an explicit auth method,
warn the user and use the auth method.

Closes #192.
  • Loading branch information
shoeffner committed Jul 19, 2024
1 parent 0702ee8 commit 3e73fbf
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 34 deletions.
125 changes: 91 additions & 34 deletions pyDataverse/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import Any, Dict, Optional
import httpx
import subprocess as sp
from warnings import warn

from httpx import ConnectError, Response

from pyDataverse.auth import ApiTokenAuth
from pyDataverse.exceptions import (
ApiAuthorizationError,
ApiUrlError,
Expand Down Expand Up @@ -41,6 +43,8 @@ def __init__(
base_url: str,
api_token: Optional[str] = None,
api_version: str = "latest",
*,
auth: Optional[httpx.Auth] = None,
):
"""Init an Api() class.
Expand All @@ -51,17 +55,55 @@ def __init__(
----------
base_url : str
Base url for Dataverse api.
api_token : str
Api token for Dataverse api.
api_token : str | None
API token for Dataverse API. If you provide an :code:`api_token`, we
assume it is an API token as retrieved via your Dataverse instance
user profile.
We recommend using the :code:`auth` argument instead.
To retain the current behaviour with the :code:`auth` argument, change
.. code-block:: python
Api("https://demo.dataverse.org", "my_token")
to
.. code-block:: python
from pyDataverse.auth import ApiTokenAuth
Api("https://demo.dataverse.org", auth=ApiTokenAuth("my_token"))
If you are using an OIDC/OAuth 2.0 Bearer token, please use the :code:`auth`
parameter with the :py:class:`.auth.BearerTokenAuth`.
api_version : str
The version string of the Dataverse API or :code:`latest`, e.g.,
:code:`v1`. Defaults to :code:`latest`, which drops the version from
the API urls.
auth : httpx.Auth | None
You can provide any authentication mechanism you like to connect to
your Dataverse instance. The most common mechanisms are implemented
in :py:mod:`.auth`, but if one is missing, you can use your own
`httpx.Auth`-compatible class. For more information, have a look at
`httpx' Authentication docs
<https://www.python-httpx.org/advanced/authentication/>`_.
Examples
-------
Create an Api connection::
.. code-block::
>>> from pyDataverse.api import Api
>>> base_url = 'http://demo.dataverse.org'
>>> api = Api(base_url)
.. code-block::
>>> from pyDataverse.api import Api
>>> from pyDataverse.auth import ApiTokenAuth
>>> base_url = 'http://demo.dataverse.org'
>>> api = Api(base_url, ApiTokenAuth('my_api_token'))
"""
if not isinstance(base_url, str):
raise ApiUrlError("base_url {0} is not a string.".format(base_url))
Expand All @@ -73,10 +115,19 @@ def __init__(
raise ApiUrlError("api_version {0} is not a string.".format(api_version))
self.api_version = api_version

if api_token:
if not isinstance(api_token, str):
raise ApiAuthorizationError("Api token passed is not a string.")
self.auth = auth
self.api_token = api_token
if api_token is not None:
if auth is None:
self.auth = ApiTokenAuth(api_token)
else:
self.api_token = None
warn(
UserWarning(
"You provided both, an api_token and a custom auth "
"method. We will only use the auth method."
)
)

if self.base_url:
if self.api_version == "latest":
Expand Down Expand Up @@ -119,21 +170,21 @@ def get_request(self, url, params=None, auth=False):
Response object of requests library.
"""
params = {}
params["User-Agent"] = "pydataverse"
if self.api_token:
params["key"] = str(self.api_token)
headers = {}
headers["User-Agent"] = "pydataverse"

if self.client is None:
return self._sync_request(
method=httpx.get,
url=url,
headers=headers,
params=params,
)
else:
return self._async_request(
method=self.client.get,
url=url,
headers=headers,
params=params,
)

Expand Down Expand Up @@ -162,10 +213,8 @@ def post_request(self, url, data=None, auth=False, params=None, files=None):
Response object of requests library.
"""
params = {}
params["User-Agent"] = "pydataverse"
if self.api_token:
params["key"] = self.api_token
headers = {}
headers["User-Agent"] = "pydataverse"

if isinstance(data, str):
data = json.loads(data)
Expand All @@ -175,6 +224,7 @@ def post_request(self, url, data=None, auth=False, params=None, files=None):
method=httpx.post,
url=url,
json=data,
headers=headers,
params=params,
files=files,
)
Expand All @@ -183,6 +233,7 @@ def post_request(self, url, data=None, auth=False, params=None, files=None):
method=self.client.post,
url=url,
json=data,
headers=headers,
params=params,
files=files,
)
Expand All @@ -208,10 +259,8 @@ def put_request(self, url, data=None, auth=False, params=None):
Response object of requests library.
"""
params = {}
params["User-Agent"] = "pydataverse"
if self.api_token:
params["key"] = self.api_token
headers = {}
headers["User-Agent"] = "pydataverse"

if isinstance(data, str):
data = json.loads(data)
Expand All @@ -221,13 +270,15 @@ def put_request(self, url, data=None, auth=False, params=None):
method=httpx.put,
url=url,
json=data,
headers=headers,
params=params,
)
else:
return self._async_request(
method=self.client.put,
url=url,
json=data,
headers=headers,
params=params,
)

Expand All @@ -250,21 +301,21 @@ def delete_request(self, url, auth=False, params=None):
Response object of requests library.
"""
params = {}
params["User-Agent"] = "pydataverse"
if self.api_token:
params["key"] = self.api_token
headers = {}
headers["User-Agent"] = "pydataverse"

if self.client is None:
return self._sync_request(
method=httpx.delete,
url=url,
headers=headers,
params=params,
)
else:
return self._async_request(
method=self.client.delete,
url=url,
headers=headers,
params=params,
)

Expand Down Expand Up @@ -292,7 +343,7 @@ def _sync_request(
kwargs = self._filter_kwargs(kwargs)

try:
resp = method(**kwargs, follow_redirects=True, timeout=None)
resp = method(**kwargs, auth=self.auth, follow_redirects=True, timeout=None)
if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
Expand Down Expand Up @@ -335,7 +386,7 @@ async def _async_request(
kwargs = self._filter_kwargs(kwargs)

try:
resp = await method(**kwargs)
resp = await method(**kwargs, auth=self.auth)

if resp.status_code == 401:
error_msg = resp.json()["message"]
Expand Down Expand Up @@ -408,9 +459,9 @@ class DataAccessApi(Api):
"""

def __init__(self, base_url, api_token=None):
def __init__(self, base_url, api_token=None, *, auth=None):
"""Init an DataAccessApi() class."""
super().__init__(base_url, api_token)
super().__init__(base_url, api_token, auth=auth)
if base_url:
self.base_url_api_data_access = "{0}/access".format(self.base_url_api)
else:
Expand Down Expand Up @@ -628,9 +679,9 @@ class MetricsApi(Api):
"""

def __init__(self, base_url, api_token=None, api_version="latest"):
def __init__(self, base_url, api_token=None, api_version="latest", *, auth=None):
"""Init an MetricsApi() class."""
super().__init__(base_url, api_token, api_version)
super().__init__(base_url, api_token, api_version, auth=auth)
if base_url:
self.base_url_api_metrics = "{0}/api/info/metrics".format(self.base_url)
else:
Expand Down Expand Up @@ -729,7 +780,7 @@ class NativeApi(Api):
"""

def __init__(self, base_url: str, api_token=None, api_version="v1"):
def __init__(self, base_url: str, api_token=None, api_version="v1", *, auth=None):
"""Init an Api() class.
Scheme, host and path combined create the base-url for the api.
Expand All @@ -741,7 +792,7 @@ def __init__(self, base_url: str, api_token=None, api_version="v1"):
Api version of Dataverse native api. Default is `v1`.
"""
super().__init__(base_url, api_token, api_version)
super().__init__(base_url, api_token, api_version, auth=auth)
self.base_url_api_native = self.base_url_api

def get_dataverse(self, identifier, auth=False):
Expand Down Expand Up @@ -2402,9 +2453,9 @@ class SearchApi(Api):
"""

def __init__(self, base_url, api_token=None, api_version="latest"):
def __init__(self, base_url, api_token=None, api_version="latest", *, auth=None):
"""Init an SearchApi() class."""
super().__init__(base_url, api_token, api_version)
super().__init__(base_url, api_token, api_version, auth=auth)
if base_url:
self.base_url_api_search = "{0}/search?q=".format(self.base_url_api)
else:
Expand Down Expand Up @@ -2479,7 +2530,13 @@ class SwordApi(Api):
"""

def __init__(
self, base_url, api_version="v1.1", api_token=None, sword_api_version="v1.1"
self,
base_url,
api_version="v1.1",
api_token=None,
sword_api_version="v1.1",
*,
auth=None,
):
"""Init a :class:`SwordApi <pyDataverse.api.SwordApi>` instance.
Expand All @@ -2489,7 +2546,7 @@ def __init__(
Api version of Dataverse SWORD API.
"""
super().__init__(base_url, api_token, api_version)
super().__init__(base_url, api_token, api_version, auth=auth)
if not isinstance(sword_api_version, ("".__class__, "".__class__)):
raise ApiUrlError(
"sword_api_version {0} is not a string.".format(sword_api_version)
Expand Down
1 change: 1 addition & 0 deletions pyDataverse/docs/source/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Access all of Dataverse APIs.

.. automodule:: pyDataverse.api
:members:
:special-members:


Models Interface
Expand Down
31 changes: 31 additions & 0 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from httpx import Response
from time import sleep
from pyDataverse.api import DataAccessApi, NativeApi
from pyDataverse.auth import ApiTokenAuth
from pyDataverse.exceptions import ApiAuthorizationError
from pyDataverse.exceptions import ApiUrlError
from pyDataverse.models import Dataset
Expand Down Expand Up @@ -34,6 +35,36 @@ def test_api_connect_base_url_wrong(self):
NativeApi(None)


class TestApiTokenAndAuthBehavior:
def test_api_token_none_and_auth_none(self):
api = NativeApi("https://demo.dataverse.org")
assert api.api_token is None
assert api.auth is None

def test_api_token_none_and_auth(self):
auth = ApiTokenAuth("mytoken")
api = NativeApi("https://demo.dataverse.org", auth=auth)
assert api.api_token is None
assert api.auth is auth

def test_api_token_and_auth(self):
auth = ApiTokenAuth("mytoken")
# Only one, api_token or auth, should be specified
with pytest.warns(UserWarning):
api = NativeApi(
"https://demo.dataverse.org", api_token="sometoken", auth=auth
)
assert api.api_token is None
assert api.auth is auth

def test_api_token_and_auth_none(self):
api_token = "mytoken"
api = NativeApi("https://demo.dataverse.org", api_token)
assert api.api_token == api_token
assert isinstance(api.auth, ApiTokenAuth)
assert api.auth.api_token == api_token


class TestApiRequests(object):
"""Test the native_api requests."""

Expand Down

0 comments on commit 3e73fbf

Please sign in to comment.