Skip to content

Commit

Permalink
[App] Fix env variable login (#16339)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 12, 2023
1 parent a28e31f commit f806f1b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed not being able to run multiple lightning apps locally due to port collision ([#15819](https://github.com/Lightning-AI/lightning/pull/15819))

- Fixed a bug where `lightning login` with env variables would not correctly save the credentials ([#16339](https://github.com/Lightning-AI/lightning/pull/16339))


## [1.8.6] - 2022-12-21

Expand Down
43 changes: 19 additions & 24 deletions src/lightning_app/utilities/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,6 @@ class Auth:

secrets_file = pathlib.Path(LIGHTNING_CREDENTIAL_PATH)

def __post_init__(self):
for key in Keys:
setattr(self, key.suffix, os.environ.get(key.value, None))

self._with_env_var = bool(self.user_id and self.api_key) # used by authenticate method
if self._with_env_var:
self.save("", self.user_id, self.api_key, self.user_id)
logger.info("Credentials loaded from environment variables")
elif self.api_key or self.user_id:
raise ValueError(
"To use env vars for authentication both "
f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set."
)

def load(self) -> bool:
"""Load credentials from disk and update properties with credentials.
Expand Down Expand Up @@ -88,13 +74,12 @@ def save(self, token: str = "", user_id: str = "", api_key: str = "", username:
self.api_key = api_key
logger.debug("credentials saved successfully")

@classmethod
def clear(cls) -> None:
"""remove credentials from disk and env variables."""
if cls.secrets_file.exists():
cls.secrets_file.unlink()
def clear(self) -> None:
"""Remove credentials from disk."""
if self.secrets_file.exists():
self.secrets_file.unlink()
for key in Keys:
os.environ.pop(key.value, None)
setattr(self, key.suffix, None)
logger.debug("credentials removed successfully")

@property
Expand All @@ -119,11 +104,21 @@ def authenticate(self) -> Optional[str]:
----------
authorization header to use when authentication completes.
"""
if self._with_env_var:
logger.debug("successfully loaded credentials from env")
return self.auth_header

if not self.load():
# First try to authenticate from env
for key in Keys:
setattr(self, key.suffix, os.environ.get(key.value, None))

if self.user_id and self.api_key:
self.save("", self.user_id, self.api_key, self.user_id)
logger.info("Credentials loaded from environment variables")
return self.auth_header
elif self.api_key or self.user_id:
raise ValueError(
"To use env vars for authentication both "
f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set."
)

logger.debug("failed to load credentials, opening browser to get new.")
self._run_server()
return self.auth_header
Expand Down
27 changes: 21 additions & 6 deletions tests/tests_app/utilities/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

@pytest.fixture(autouse=True)
def before_each():
login.Auth.clear()
for key in login.Keys:
os.environ.pop(key.value, None)
login.Auth().clear()


class TestAuthentication:
Expand All @@ -25,7 +27,6 @@ def test_can_store_credentials(self):

def test_e2e(self):
auth = login.Auth()
assert auth._with_env_var is False
auth.save(username="superman", user_id="kr-1234")
assert auth.secrets_file.exists()

Expand All @@ -46,6 +47,9 @@ def test_auth_header(self):
os.environ.setdefault("LIGHTNING_USER_ID", "7c8455e3-7c5f-4697-8a6d-105971d6b9bd")
os.environ.setdefault("LIGHTNING_API_KEY", "e63fae57-2b50-498b-bc46-d6204cbf330e")
auth = login.Auth()
auth.clear()
auth.authenticate()

assert "Basic" in auth.auth_header
assert (
auth.auth_header
Expand All @@ -57,7 +61,9 @@ def test_authentication_with_invalid_environment_vars():
# if api key is passed without user id
os.environ.setdefault("LIGHTNING_API_KEY", "123")
with pytest.raises(ValueError):
login.Auth()
auth = login.Auth()
auth.clear()
auth.authenticate()


@mock.patch("lightning_app.utilities.login.AuthServer.login_with_browser")
Expand All @@ -66,13 +72,19 @@ def test_authentication_with_environment_vars(browser_login: mock.MagicMock):
os.environ.setdefault("LIGHTNING_API_KEY", "abc")

auth = login.Auth()
auth.clear()
auth.authenticate()

assert auth.user_id == "abc"
assert auth.auth_header == "Basic YWJjOmFiYw=="
assert auth._with_env_var is True
assert auth.authenticate() == auth.auth_header
# should not run login flow when env vars are passed
browser_login.assert_not_called()

# Check credentials file
assert auth.secrets_file.exists()
assert auth.load() is True


def test_get_auth_url():
auth_url = login.AuthServer().get_auth_url(1234)
Expand Down Expand Up @@ -103,13 +115,16 @@ def test_login_with_browser(
def test_authenticate(click_launch: mock.MagicMock, head: mock.MagicMock, run: mock.MagicMock, port: mock.MagicMock):
port.return_value = 1234
auth = login.Auth()
auth.user_id = "user_id"
auth.api_key = "api_key"
auth.clear()

click_launch.side_effect = lambda _: auth.save("", "user_id", "api_key", "user_id")

auth.authenticate()
url = f"{LIGHTNING_CLOUD_URL}/sign-in?redirectTo=http%3A%2F%2Flocalhost%3A1234%2Flogin-complete" # E501
head.assert_called_with(url)
click_launch.assert_called_with(url)
run.assert_called()

assert auth.auth_header == "Basic dXNlcl9pZDphcGlfa2V5"

auth.authenticate()
Expand Down

0 comments on commit f806f1b

Please sign in to comment.