diff --git a/lib/pbench/client/__init__.py b/lib/pbench/client/__init__.py index 5a6e9c146a..85c6390179 100644 --- a/lib/pbench/client/__init__.py +++ b/lib/pbench/client/__init__.py @@ -10,7 +10,6 @@ from pbench.client.oidc_admin import OIDCAdmin from pbench.client.types import Dataset, JSONOBJECT -from pbench.server.auth import OpenIDClientError class PbenchClientError(Exception): @@ -340,17 +339,13 @@ def login(self, user: str, password: str): user: Account username password: Account password """ - try: - response = self.oidc_admin.user_login( - client_id=self.endpoints["openid"]["client"], - username=user, - password=password, - ) - except OpenIDClientError: - self.auth_token = None - else: - self.username = user - self.auth_token = response["access_token"] + response = self.oidc_admin.user_login( + client_id=self.endpoints["openid"]["client"], + username=user, + password=password, + ) + self.username = user + self.auth_token = response["access_token"] def upload(self, tarball: Path, **kwargs) -> requests.Response: """Upload a tarball to the server. diff --git a/lib/pbench/client/oidc_admin.py b/lib/pbench/client/oidc_admin.py index a1064ddc15..99d4952185 100644 --- a/lib/pbench/client/oidc_admin.py +++ b/lib/pbench/client/oidc_admin.py @@ -64,7 +64,6 @@ def create_new_user( admin_token = self.get_admin_token().get("access_token") url_path = f"/admin/realms/{self.OIDC_REALM}/users" headers = { - "Content-Type": "application/json", "Authorization": f"Bearer {admin_token}", } data = { @@ -102,7 +101,6 @@ def user_login(self, client_id: str, username: str, password: str) -> dict: """ url_path = f"/realms/{self.OIDC_REALM}/protocol/openid-connect/token" - headers = {"Content-Type": "application/x-www-form-urlencoded"} data = { "client_id": client_id, "grant_type": "password", @@ -110,7 +108,7 @@ def user_login(self, client_id: str, username: str, password: str) -> dict: "username": username, "password": password, } - return self.post(path=url_path, data=data, headers=headers).json() + return self.post(path=url_path, data=data).json() def get_user(self, username: str, token: str) -> dict: """Get the OIDC user representation dict. @@ -138,7 +136,6 @@ def get_user(self, username: str, token: str) -> dict: response = self.get( f"admin/realms/{self.OIDC_REALM}/users", headers={ - "Content-Type": "application/json", "Authorization": f"Bearer {token}", }, username=username, diff --git a/lib/pbench/server/auth/__init__.py b/lib/pbench/server/auth/__init__.py index 8a35bf5d2c..65dfe166b4 100644 --- a/lib/pbench/server/auth/__init__.py +++ b/lib/pbench/server/auth/__init__.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from http import HTTPStatus import logging -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional from urllib.parse import urljoin import jwt @@ -38,7 +38,7 @@ class Connection: def __init__( self, server_url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, verify: bool = True, ): self.server_url = server_url @@ -50,9 +50,9 @@ def _method( self, method: str, path: str, - data: Union[Dict, str, None], - json: Optional[Dict] = None, - headers: Optional[Dict] = None, + data: Optional[Any] = None, + json: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, **kwargs, ) -> requests.Response: """Common frontend for the HTTP operations on OIDC client connection. @@ -61,7 +61,7 @@ def _method( method : The API HTTP method path : Path for the request. data : Data to send with the request in case of the POST - json: JSON data to send with the request in case of the POST + json : JSON data to send with the request in case of the POST kwargs : Additional keyword args Returns: @@ -71,7 +71,7 @@ def _method( if headers is not None: final_headers.update(headers) url = urljoin(self.server_url, path) - kwargs = dict( + request_dict = dict( params=kwargs, data=data, json=json, @@ -79,7 +79,7 @@ def _method( verify=self.verify, ) try: - response = self._connection.request(method, url, **kwargs) + response = self._connection.request(method, url, **request_dict) except requests.exceptions.ConnectionError as exc: raise OpenIDClientError( http_status=HTTPStatus.BAD_GATEWAY, @@ -113,7 +113,7 @@ def _method( return response def get( - self, path: str, headers: Optional[Dict] = None, **kwargs + self, path: str, headers: Optional[dict[str, str]] = None, **kwargs ) -> requests.Response: """GET wrapper to handle an authenticated GET operation on the Resource at a given path. @@ -126,14 +126,14 @@ def get( Returns: Response from the request. """ - return self._method("GET", path, None, None, headers=headers, **kwargs) + return self._method("GET", path, headers=headers, **kwargs) def post( self, path: str, - data: Union[Dict, str] = None, - json: Optional[Dict] = None, - headers: Optional[Dict] = None, + data: Optional[Any] = None, + json: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, **kwargs, ) -> requests.Response: """POST wrapper to handle an authenticated POST operation on the @@ -141,8 +141,8 @@ def post( Args: path : Path for the request - data : Request body to attach - json: JSON request body + data : Optional request body to attach + json : JSON request body headers : Additional headers to add to the request kwargs : Additional keyword args to be added as URL parameters diff --git a/lib/pbench/test/functional/server/test_connect.py b/lib/pbench/test/functional/server/test_connect.py index f58857fe8f..1940023563 100644 --- a/lib/pbench/test/functional/server/test_connect.py +++ b/lib/pbench/test/functional/server/test_connect.py @@ -30,4 +30,4 @@ def test_connect(self, server_client: PbenchServerClient): # verify all the required openid-connect fields are present if "openid" in endpoints: expected = {"server", "client", "realm"} - assert set(endpoints["openid"]) == expected + assert set(endpoints["openid"]) >= expected diff --git a/lib/pbench/test/unit/client/test_login.py b/lib/pbench/test/unit/client/test_login.py index 94747439e7..650b52d9a9 100644 --- a/lib/pbench/test/unit/client/test_login.py +++ b/lib/pbench/test/unit/client/test_login.py @@ -1,7 +1,10 @@ from http import HTTPStatus +import pytest import responses +from pbench.server.auth import OpenIDClientError + class TestLogin: def test_login(self, connect): @@ -38,7 +41,8 @@ def test_bad_login(self, connect): status=HTTPStatus.UNAUTHORIZED, json={"error_description": "Invalid user credentials"}, ) - connect.login("user", "password") + with pytest.raises(OpenIDClientError): + connect.login("user", "password") assert len(rsp.calls) == 1 assert rsp.calls[0].request.url == url assert rsp.calls[0].response.status_code == 401 diff --git a/lib/pbench/test/unit/server/auth/test_auth.py b/lib/pbench/test/unit/server/auth/test_auth.py index b88af85aad..c3343d8da2 100644 --- a/lib/pbench/test/unit/server/auth/test_auth.py +++ b/lib/pbench/test/unit/server/auth/test_auth.py @@ -1,7 +1,7 @@ import configparser from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from flask import current_app, Flask import jwt @@ -43,7 +43,12 @@ def fake_method(self, monkeypatch): args = {} def fake_method( - the_self, method: str, path: str, data: Dict, json: Dict, **kwargs + the_self, + method: str, + path: str, + data: Optional[Any] = None, + json: Optional[dict[str, Any]] = None, + **kwargs, ) -> requests.Response: args["method"] = method args["path"] = path @@ -176,7 +181,7 @@ def test_post(self, fake_method, conn): """ args = fake_method response = conn.post( - "foo/bar", {"one": "two", "three": "four"}, None, five="six" + "foo/bar", data={"one": "two", "three": "four"}, five="six" ) assert response is not None assert args["method"] == "POST"