Skip to content

Commit

Permalink
Rename initiate_authentication to get_access_token
Browse files Browse the repository at this point in the history
Enhanced token authorization tests for DVC Studio client with improved error checking and exception handling. Introduced `get_access_token` method to retrieve the Access Token from DVC Server. Added test cases for authorization expiry and error handling. Also added functionality to validate required scopes, which prominently include 'EXPERIMENTS', 'DATASETS', and 'MODELS'. Updated README with relevant changes.
  • Loading branch information
amritghimire committed Nov 27, 2023
1 parent f041a6b commit 6778b91
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 47 deletions.
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Features
- Live Experiments
- `post_live_metrics`_: Post updates to `api/live`.

- Studio authorization
- `initiate_authorization`_: Initiates the authorization process for a client application

Installation
------------

Expand Down Expand Up @@ -73,3 +76,4 @@ please `file an issue`_ along with a detailed description.
.. _DVC Studio: https://dvc.org/doc/studio
.. _get_download_uris: https://docs.iterative.ai/dvc-studio-client/reference/dvc_studio_client/model_registry/
.. _post_live_metrics: https://docs.iterative.ai/dvc-studio-client/reference/dvc_studio_client/post_live_metrics/
.. _initiate_authorization: https://docs.iterative.ai/dvc-studio-client/reference/dvc_studio_client/auth/
25 changes: 14 additions & 11 deletions src/dvc_studio_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import logger

AVAILABLE_SCOPES = ["live", "dvc_experiment", "view_url", "dql", "download_model"]
AVAILABLE_SCOPES = ["EXPERIMENTS", "DATASETS", "MODELS"]


class DeviceLoginResponse(TypedDict):
Expand All @@ -30,19 +30,20 @@ class AuthorizationExpired(StudioAuthError):
pass


def initiate_authorization(*, name, hostname, scopes, use_device_code=False):
def get_access_token(*, hostname, token_name=None, scopes="", use_device_code=False, client_name="client"):
"""Initiate Authorization
This method initiates the authorization process for a client application.
It generates a user code and a verification URI that the user needs to
access in order to authorize the application.
Parameters:
name (str): The name of the client application.
token_name (str): The name of the client application.
hostname (str): The base URL of the application.
scopes (str): A comma-separated string of scopes that the application requires.
scopes (str, optional): A comma-separated string of scopes that the application requires. Default is empty.
use_device_code (bool, optional): Whether to use the device code
flow for authorization. Default is False.
client_name (str, optional): Client name
Returns:
tuple: A tuple containing the token name and the access token.
Expand All @@ -53,10 +54,10 @@ def initiate_authorization(*, name, hostname, scopes, use_device_code=False):
import webbrowser

response = start_device_login(
client_name="dvc",
client_name=client_name,
base_url=hostname,
token_name=name,
scopes=scopes.split(","),
token_name=token_name,
scopes=scopes.split(",") if scopes else [],
)
verification_uri = response["verification_uri"]
user_code = response["user_code"]
Expand All @@ -66,16 +67,18 @@ def initiate_authorization(*, name, hostname, scopes, use_device_code=False):

opened = False
if not use_device_code:
url = f"{verification_uri}?code={user_code}"
opened = webbrowser.open(url)

if opened:
print(
f"A web browser has been opened at \n{verification_uri}.\n"
f"Please continue the login in the web browser.\n"
f"If no web browser is available or if the web browser fails to open,\n"
f"use device code flow with `dvc studio login --use-device-code`."
)
url = f"{verification_uri}?code={user_code}"
opened = webbrowser.open(url)

if not opened:
else:
print(f"Please open the following url in your browser.\n{verification_uri}")
print(f"And enter the user code below {user_code} to authorize.")

Expand Down Expand Up @@ -116,7 +119,7 @@ def start_device_login(
"Starting device login for Studio%s",
f" ({base_url})" if base_url else "",
)
if invalid_scopes := list(filter(lambda s: s not in AVAILABLE_SCOPES, scopes)):
if invalid_scopes := list(filter(lambda s: s.upper() not in AVAILABLE_SCOPES, scopes)):
raise InvalidScopesError(
f"Following scopes are not valid: {', '.join(invalid_scopes)}"
)
Expand Down
205 changes: 169 additions & 36 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,142 @@
AuthorizationExpired,
DeviceLoginResponse,
check_token_authorization,
start_device_login,
start_device_login, get_access_token, InvalidScopesError,
)

MOCK_RESPONSE = {
"verification_uri": "https://studio.example.com/auth/device-login",
"user_code": "MOCKCODE",
"device_code": "random-value",
"token_uri": "https://studio.example.com/api/device-login/token",
"token_name": "random-name",
}

def test_start_device_login(mocker):

@pytest.fixture
def mock_response(mocker):
def _mock_response(status_code, json):
response = Response()
response.status_code = status_code
mocker.patch.object(response, "json", return_value=json)
return response

return _mock_response


@pytest.fixture
def mock_post(mocker, mock_response):
def _mock_post(method, side_effect):
return mocker.patch(
method,
side_effect=[
mock_response(status, resp) for status, resp in side_effect
],
)

return _mock_post


def test_auth_expired(mocker, mock_post):
mocker.patch("webbrowser.open")

mock_login_post = mock_post(
"requests.post", [(200, MOCK_RESPONSE)]
)

mock_poll_post = mock_post(
"requests.Session.post",
[(400, {"detail": "authorization_expired"})]
)

with pytest.raises(AuthorizationExpired):
get_access_token(client_name="client", hostname="https://studio.example.com")

assert mock_login_post.call_args == mocker.call(
url="https://studio.example.com/api/device-login",
json={
"client_name": "client",
},
headers={"Content-type": "application/json"},
timeout=5,
)

assert mock_poll_post.call_args_list == [
mocker.call(
"https://studio.example.com/api/device-login/token",
json={"code": "random-value"},
timeout=5,
allow_redirects=False,
),
]


def test_auth_success(mocker, mock_post, capfd):
mocker.patch("time.sleep")
mocker.patch("webbrowser.open")
mock_login_post = mock_post(
"requests.post", [(200, MOCK_RESPONSE)]
)
mock_poll_post = mock_post(
"requests.Session.post",
[
(400, {"detail": "authorization_pending"}),
(200, {"access_token": "isat_access_token"}),
],
)

assert get_access_token(hostname="https://example.com", scopes="experiments", token_name="random-name") == (
"random-name", "isat_access_token")

assert mock_login_post.call_args_list == [
mocker.call(
url="https://example.com/api/device-login",
json={"client_name": "client", "token_name": "random-name", "scopes": ["experiments"]},
headers={"Content-type": "application/json"},
timeout=5,
)
]
assert mock_poll_post.call_count == 2
assert mock_poll_post.call_args_list == [
mocker.call(
f"https://studio.example.com/api/device-login/token",
json={"code": "random-value"},
timeout=5,
allow_redirects=False,
),
mocker.call(
f"https://studio.example.com/api/device-login/token",
json={"code": "random-value"},
timeout=5,
allow_redirects=False,
),
]
assert "Please continue the login in the web browser" in capfd.readouterr().out


def test_webbrowser_open_fails(mocker, mock_post, capfd):
mock_open = mocker.patch("webbrowser.open")
mock_open.return_value = False

mocker.patch("time.sleep")
mock_post(
"requests.post", [(200, MOCK_RESPONSE)]
)
mock_post(
"requests.Session.post",
[
(400, {"detail": "authorization_pending"}),
(200, {"access_token": "isat_access_token"}),
],
)

assert get_access_token(
hostname="https://example.com", scopes="experiments", token_name="random-name"
) == ("random-name", "isat_access_token")
assert "Please open the following url in your browser" in capfd.readouterr().out


def test_start_device_login(mocker, mock_post):
example_response = {
"device_code": "random-device-code",
"user_code": "MOCKCODE",
Expand All @@ -19,35 +150,45 @@ def test_start_device_login(mocker):
"token_name": "token_name",
"expires_in": 1500,
}
mock_post = mocker.patch(
mock_post = mock_post(
"requests.post",
return_value=mock_response(mocker, 200, example_response),
[(200, example_response)]
)

response: DeviceLoginResponse = start_device_login(
base_url="https://example.com",
client_name="dvc",
token_name="token_name",
scopes=["live"],
scopes=["EXPERIMENTS"],
)

assert mock_post.called
assert mock_post.call_args == mocker.call(
url="https://example.com/api/device-login",
json={"client_name": "dvc", "token_name": "token_name", "scopes": ["live"]},
json={"client_name": "dvc", "token_name": "token_name", "scopes": ["EXPERIMENTS"]},
headers={"Content-type": "application/json"},
timeout=5,
)
assert response == example_response


def test_check_token_authorization_expired(mocker):
def test_start_device_login_invalid_scopes(mock_post):
with pytest.raises(InvalidScopesError):
start_device_login(
base_url="https://example.com",
client_name="dvc",
token_name="token_name",
scopes=["INVALID!"],
)


def test_check_token_authorization_expired(mocker, mock_post):
mocker.patch("time.sleep")
mock_post = mocker.patch(
mock_post = mock_post(
"requests.Session.post",
side_effect=[
mock_response(mocker, 400, {"detail": "authorization_pending"}),
mock_response(mocker, 400, {"detail": "authorization_expired"}),
[
(400, {"detail": "authorization_pending"}),
(400, {"detail": "authorization_expired"}),
],
)

Expand All @@ -65,13 +206,13 @@ def test_check_token_authorization_expired(mocker):
)


def test_check_token_authorization_error(mocker):
def test_check_token_authorization_error(mocker, mock_post):
mocker.patch("time.sleep")
mock_post = mocker.patch(
mock_post = mock_post(
"requests.Session.post",
side_effect=[
mock_response(mocker, 400, {"detail": "authorization_pending"}),
mock_response(mocker, 500, {"detail": "unexpected_error"}),
[
(400, {"detail": "authorization_pending"}),
(500, {"detail": "unexpected_error"}),
],
)

Expand All @@ -89,36 +230,28 @@ def test_check_token_authorization_error(mocker):
)


def test_check_token_authorization_success(mocker):
def test_check_token_authorization_success(mocker, mock_post):
mocker.patch("time.sleep")
mock_post = mocker.patch(
mock_post_call = mock_post(
"requests.Session.post",
side_effect=[
mock_response(mocker, 400, {"detail": "authorization_pending"}),
mock_response(mocker, 400, {"detail": "authorization_pending"}),
mock_response(mocker, 200, {"access_token": "isat_token"}),
[
(400, {"detail": "authorization_pending"}),
(400, {"detail": "authorization_pending"}),
(200, {"access_token": "isat_token"}),
],
)

assert (
check_token_authorization(
uri="https://example.com/token_uri", device_code="random_device_code"
)
== "isat_token"
check_token_authorization(
uri="https://example.com/token_uri", device_code="random_device_code"
)
== "isat_token"
)

assert mock_post.call_count == 3
assert mock_post.call_args == mocker.call(
assert mock_post_call.call_count == 3
assert mock_post_call.call_args == mocker.call(
"https://example.com/token_uri",
json={"code": "random_device_code"},
timeout=5,
allow_redirects=False,
)


def mock_response(mocker, status_code, json):
response = Response()
response.status_code = status_code
mocker.patch.object(response, "json", side_effect=[json])

return response

0 comments on commit 6778b91

Please sign in to comment.