diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index ddc6f6121fe7d..d192d93db7544 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed not being able to run multiple lightning apps locally due to port collision ([#15819](https://github.com/Lightning-AI/lightning/pull/15819)) +- Fixed a bug where `lightning login` with env variables would not correctly save the credentials ([#16339](https://github.com/Lightning-AI/lightning/pull/16339)) + ## [1.8.6] - 2022-12-21 diff --git a/src/lightning_app/utilities/login.py b/src/lightning_app/utilities/login.py index ff04e3b865d41..31087be9ee0a5 100644 --- a/src/lightning_app/utilities/login.py +++ b/src/lightning_app/utilities/login.py @@ -40,20 +40,6 @@ class Auth: secrets_file = pathlib.Path(LIGHTNING_CREDENTIAL_PATH) - def __post_init__(self): - for key in Keys: - setattr(self, key.suffix, os.environ.get(key.value, None)) - - self._with_env_var = bool(self.user_id and self.api_key) # used by authenticate method - if self._with_env_var: - self.save("", self.user_id, self.api_key, self.user_id) - logger.info("Credentials loaded from environment variables") - elif self.api_key or self.user_id: - raise ValueError( - "To use env vars for authentication both " - f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set." - ) - def load(self) -> bool: """Load credentials from disk and update properties with credentials. @@ -88,13 +74,12 @@ def save(self, token: str = "", user_id: str = "", api_key: str = "", username: self.api_key = api_key logger.debug("credentials saved successfully") - @classmethod - def clear(cls) -> None: - """remove credentials from disk and env variables.""" - if cls.secrets_file.exists(): - cls.secrets_file.unlink() + def clear(self) -> None: + """Remove credentials from disk.""" + if self.secrets_file.exists(): + self.secrets_file.unlink() for key in Keys: - os.environ.pop(key.value, None) + setattr(self, key.suffix, None) logger.debug("credentials removed successfully") @property @@ -119,11 +104,21 @@ def authenticate(self) -> Optional[str]: ---------- authorization header to use when authentication completes. """ - if self._with_env_var: - logger.debug("successfully loaded credentials from env") - return self.auth_header - if not self.load(): + # First try to authenticate from env + for key in Keys: + setattr(self, key.suffix, os.environ.get(key.value, None)) + + if self.user_id and self.api_key: + self.save("", self.user_id, self.api_key, self.user_id) + logger.info("Credentials loaded from environment variables") + return self.auth_header + elif self.api_key or self.user_id: + raise ValueError( + "To use env vars for authentication both " + f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set." + ) + logger.debug("failed to load credentials, opening browser to get new.") self._run_server() return self.auth_header diff --git a/tests/tests_app/utilities/test_login.py b/tests/tests_app/utilities/test_login.py index e0ad4b110c868..08e65454cc09e 100644 --- a/tests/tests_app/utilities/test_login.py +++ b/tests/tests_app/utilities/test_login.py @@ -11,7 +11,9 @@ @pytest.fixture(autouse=True) def before_each(): - login.Auth.clear() + for key in login.Keys: + os.environ.pop(key.value, None) + login.Auth().clear() class TestAuthentication: @@ -25,7 +27,6 @@ def test_can_store_credentials(self): def test_e2e(self): auth = login.Auth() - assert auth._with_env_var is False auth.save(username="superman", user_id="kr-1234") assert auth.secrets_file.exists() @@ -46,6 +47,9 @@ def test_auth_header(self): os.environ.setdefault("LIGHTNING_USER_ID", "7c8455e3-7c5f-4697-8a6d-105971d6b9bd") os.environ.setdefault("LIGHTNING_API_KEY", "e63fae57-2b50-498b-bc46-d6204cbf330e") auth = login.Auth() + auth.clear() + auth.authenticate() + assert "Basic" in auth.auth_header assert ( auth.auth_header @@ -57,7 +61,9 @@ def test_authentication_with_invalid_environment_vars(): # if api key is passed without user id os.environ.setdefault("LIGHTNING_API_KEY", "123") with pytest.raises(ValueError): - login.Auth() + auth = login.Auth() + auth.clear() + auth.authenticate() @mock.patch("lightning_app.utilities.login.AuthServer.login_with_browser") @@ -66,13 +72,19 @@ def test_authentication_with_environment_vars(browser_login: mock.MagicMock): os.environ.setdefault("LIGHTNING_API_KEY", "abc") auth = login.Auth() + auth.clear() + auth.authenticate() + assert auth.user_id == "abc" assert auth.auth_header == "Basic YWJjOmFiYw==" - assert auth._with_env_var is True assert auth.authenticate() == auth.auth_header # should not run login flow when env vars are passed browser_login.assert_not_called() + # Check credentials file + assert auth.secrets_file.exists() + assert auth.load() is True + def test_get_auth_url(): auth_url = login.AuthServer().get_auth_url(1234) @@ -103,13 +115,16 @@ def test_login_with_browser( def test_authenticate(click_launch: mock.MagicMock, head: mock.MagicMock, run: mock.MagicMock, port: mock.MagicMock): port.return_value = 1234 auth = login.Auth() - auth.user_id = "user_id" - auth.api_key = "api_key" + auth.clear() + + click_launch.side_effect = lambda _: auth.save("", "user_id", "api_key", "user_id") + auth.authenticate() url = f"{LIGHTNING_CLOUD_URL}/sign-in?redirectTo=http%3A%2F%2Flocalhost%3A1234%2Flogin-complete" # E501 head.assert_called_with(url) click_launch.assert_called_with(url) run.assert_called() + assert auth.auth_header == "Basic dXNlcl9pZDphcGlfa2V5" auth.authenticate()