Skip to content

Commit

Permalink
studio: Update token authorization in DVC studio client (#118)
Browse files Browse the repository at this point in the history
Refactored the token authorization in DVC studio client. Introduced distinct error classes such as 'StudioAuthError', 'InvalidScopesError', 'AuthorizationExpired' for improved error handling and clearer code. The function `check_token_authorization` now raises an `AuthorizationExpired` exception if the authorization has expired instead of returning None implicitly. Further, unused imports and logging setups were removed for cleaner code. These changes are aimed to improve error debugging and the robustness of the authentication process.
  • Loading branch information
amritghimire authored Nov 23, 2023
1 parent 4b19b08 commit f041a6b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 11 deletions.
81 changes: 73 additions & 8 deletions src/dvc_studio_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import logging
from os import getenv
from typing import List, Optional, TypedDict
from urllib.parse import urljoin

import requests
from requests.adapters import HTTPAdapter

from .env import DVC_STUDIO_CLIENT_LOGLEVEL

logger = logging.getLogger(__name__)
logger.setLevel(getenv(DVC_STUDIO_CLIENT_LOGLEVEL, "INFO").upper())
from . import logger

AVAILABLE_SCOPES = ["live", "dvc_experiment", "view_url", "dql", "download_model"]

Expand All @@ -23,6 +18,72 @@ class DeviceLoginResponse(TypedDict):
expires_in: int


class StudioAuthError(Exception):
pass


class InvalidScopesError(StudioAuthError):
pass


class AuthorizationExpired(StudioAuthError):
pass


def initiate_authorization(*, name, hostname, scopes, use_device_code=False):
"""Initiate Authorization
This method initiates the authorization process for a client application.
It generates a user code and a verification URI that the user needs to
access in order to authorize the application.
Parameters:
name (str): The name of the client application.
hostname (str): The base URL of the application.
scopes (str): A comma-separated string of scopes that the application requires.
use_device_code (bool, optional): Whether to use the device code
flow for authorization. Default is False.
Returns:
tuple: A tuple containing the token name and the access token.
The token name is a string representing the token's name,
while the access token is a string representing the authorized access token.
"""

import webbrowser

response = start_device_login(
client_name="dvc",
base_url=hostname,
token_name=name,
scopes=scopes.split(","),
)
verification_uri = response["verification_uri"]
user_code = response["user_code"]
device_code = response["device_code"]
token_uri = response["token_uri"]
token_name = response["token_name"]

opened = False
if not use_device_code:
print(
f"A web browser has been opened at \n{verification_uri}.\n"
f"Please continue the login in the web browser.\n"
f"If no web browser is available or if the web browser fails to open,\n"
f"use device code flow with `dvc studio login --use-device-code`."
)
url = f"{verification_uri}?code={user_code}"
opened = webbrowser.open(url)

if not opened:
print(f"Please open the following url in your browser.\n{verification_uri}")
print(f"And enter the user code below {user_code} to authorize.")

access_token = check_token_authorization(uri=token_uri, device_code=device_code)

return token_name, access_token


def start_device_login(
*,
client_name: str,
Expand Down Expand Up @@ -56,7 +117,9 @@ def start_device_login(
f" ({base_url})" if base_url else "",
)
if invalid_scopes := list(filter(lambda s: s not in AVAILABLE_SCOPES, scopes)):
raise ValueError(f"Following scopes are not valid: {', '.join(invalid_scopes)}")
raise InvalidScopesError(
f"Following scopes are not valid: {', '.join(invalid_scopes)}"
)

body = {"client_name": client_name}

Expand Down Expand Up @@ -133,7 +196,9 @@ def check_token_authorization(*, uri: str, device_code: str) -> Optional[str]:
time.sleep(5)
continue
if detail == "authorization_expired":
return
raise AuthorizationExpired(
"failed to authenticate: This 'device_code' has expired."
)

r.raise_for_status()

Expand Down
5 changes: 2 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from requests import Response

from dvc_studio_client.auth import (
AuthorizationExpired,
DeviceLoginResponse,
check_token_authorization,
start_device_login,
Expand Down Expand Up @@ -50,12 +51,10 @@ def test_check_token_authorization_expired(mocker):
],
)

assert (
with pytest.raises(AuthorizationExpired):
check_token_authorization(
uri="https://example.com/token_uri", device_code="random_device_code"
)
is None
)

assert mock_post.call_count == 2
assert mock_post.call_args == mocker.call(
Expand Down

0 comments on commit f041a6b

Please sign in to comment.