Skip to content

Commit

Permalink
SNOW-630142 Custom Auth (#1215)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Ling <[email protected]>
Co-authored-by: Mark Keller <[email protected]>
  • Loading branch information
3 people authored Nov 24, 2022
1 parent bd39dc7 commit 9e04925
Show file tree
Hide file tree
Showing 21 changed files with 948 additions and 430 deletions.
40 changes: 40 additions & 0 deletions src/snowflake/connector/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

from ._auth import Auth, get_public_key_fingerprint, get_token_from_private_key
from .by_plugin import AuthByPlugin, AuthType
from .default import AuthByDefault
from .keypair import AuthByKeyPair
from .oauth import AuthByOAuth
from .okta import AuthByOkta
from .usrpwdmfa import AuthByUsrPwdMfa
from .webbrowser import AuthByWebBrowser

FIRST_PARTY_AUTHENTICATORS = frozenset(
(
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
AuthByOkta,
AuthByUsrPwdMfa,
AuthByWebBrowser,
)
)

__all__ = [
"AuthByPlugin",
"AuthByDefault",
"AuthByKeyPair",
"AuthByOAuth",
"AuthByOkta",
"AuthByUsrPwdMfa",
"AuthByWebBrowser",
"Auth",
"AuthType",
"FIRST_PARTY_AUTHENTICATORS",
"get_public_key_fingerprint",
"get_token_from_private_key",
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
#
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
#
Expand All @@ -16,6 +15,7 @@
from os import getenv, makedirs, mkdir, path, remove, removedirs, rmdir
from os.path import expanduser
from threading import Lock, Thread
from typing import TYPE_CHECKING, Any, Callable

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
Expand All @@ -26,44 +26,45 @@
load_pem_private_key,
)

from .auth_keypair import AuthByKeyPair
from .auth_usrpwdmfa import AuthByUsrPwdMfa
from .compat import IS_LINUX, IS_MACOS, IS_WINDOWS, urlencode
from .constants import (
from ..compat import IS_LINUX, IS_MACOS, IS_WINDOWS, urlencode
from ..constants import (
DAY_IN_SECONDS,
HTTP_HEADER_ACCEPT,
HTTP_HEADER_CONTENT_TYPE,
HTTP_HEADER_SERVICE_NAME,
HTTP_HEADER_USER_AGENT,
PARAMETER_CLIENT_REQUEST_MFA_TOKEN,
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL,
)
from .description import (
from ..description import (
COMPILER,
IMPLEMENTATION,
OPERATING_SYSTEM,
PLATFORM,
PYTHON_VERSION,
)
from .errorcode import ER_FAILED_TO_CONNECT_TO_DB
from .errors import (
from ..errorcode import ER_FAILED_TO_CONNECT_TO_DB
from ..errors import (
BadGatewayError,
DatabaseError,
Error,
ForbiddenError,
ProgrammingError,
ServiceUnavailableError,
)
from .network import (
from ..network import (
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
CONTENT_TYPE_APPLICATION_JSON,
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
KEY_PAIR_AUTHENTICATOR,
PYTHON_CONNECTOR_USER_AGENT,
ReauthenticationRequest,
)
from .options import installed_keyring, keyring
from .sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
from .version import VERSION
from ..options import installed_keyring, keyring
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
from ..version import VERSION

if TYPE_CHECKING:
from . import AuthByPlugin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -153,19 +154,19 @@ def base_auth_data(

def authenticate(
self,
auth_instance,
account,
user,
database=None,
schema=None,
warehouse=None,
role=None,
passcode=None,
passcode_in_password=False,
mfa_callback=None,
password_callback=None,
session_parameters=None,
timeout=120,
auth_instance: AuthByPlugin,
account: str,
user: str,
database: str | None = None,
schema: str | None = None,
warehouse: str | None = None,
role: str | None = None,
passcode: str | None = None,
passcode_in_password: bool = False,
mfa_callback: Callable[[], None] | None = None,
password_callback: Callable[[], str] | None = None,
session_parameters: dict[Any, Any] | None = None,
timeout: int = 120,
) -> dict[str, str | int | bool]:
logger.debug("authenticate")

Expand Down Expand Up @@ -242,15 +243,7 @@ def authenticate(
# login_timeout comes from user configuration.
# Between login timeout and auth specific
# timeout use whichever value is smaller
if hasattr(auth_instance, "get_timeout"):
logger.debug(
f"Authenticator, {type(auth_instance).__name__}, implements get_timeout"
)
auth_timeout = min(
self._rest._connection.login_timeout, auth_instance.get_timeout()
)
else:
auth_timeout = self._rest._connection.login_timeout
auth_timeout = min(self._rest._connection.login_timeout, auth_instance.timeout)
logger.debug(f"Timeout set to {auth_timeout}")

try:
Expand Down Expand Up @@ -386,15 +379,19 @@ def post_request_wrapper(self, url, headers, body):
)
)

if type(auth_instance) is AuthByKeyPair:
from . import AuthByKeyPair

if isinstance(auth_instance, AuthByKeyPair):
logger.debug(
"JWT Token authentication failed. "
"Token expires at: %s. "
"Current Time: %s",
str(auth_instance._jwt_token_exp),
str(datetime.utcnow()),
)
if type(auth_instance) is AuthByUsrPwdMfa:
from . import AuthByUsrPwdMfa

if isinstance(auth_instance, AuthByUsrPwdMfa):
delete_temporary_credential(self._rest._host, user, MFA_TOKEN)
Error.errorhandler_wrapper(
self._rest._connection,
Expand Down Expand Up @@ -483,16 +480,33 @@ def _read_temporary_credential(self, host, user, cred_type):
logger.debug("OS not supported for Local Secure Storage")
return cred

def read_temporary_credentials(self, host, user, session_parameters):
def read_temporary_credentials(
self,
host: str,
user: str,
session_parameters: dict[str, Any],
) -> None:
if session_parameters.get(PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False):
self._rest.id_token = self._read_temporary_credential(host, user, ID_TOKEN)
self._rest.id_token = self._read_temporary_credential(
host,
user,
ID_TOKEN,
)

if session_parameters.get(PARAMETER_CLIENT_REQUEST_MFA_TOKEN, False):
self._rest.mfa_token = self._read_temporary_credential(
host, user, MFA_TOKEN
host,
user,
MFA_TOKEN,
)

def _write_temporary_credential(self, host, user, cred_type, cred):
def _write_temporary_credential(
self,
host: str,
user: str,
cred_type: str,
cred: str | None,
) -> None:
if not cred:
logger.debug(
"no credential is given when try to store temporary credential"
Expand Down Expand Up @@ -522,9 +536,18 @@ def _write_temporary_credential(self, host, user, cred_type, cred):
else:
logger.debug("OS not supported for Local Secure Storage")

def write_temporary_credentials(self, host, user, session_parameters, response):
if self._rest._connection.consent_cache_id_token and session_parameters.get(
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False
def write_temporary_credentials(
self,
host: str,
user: str,
session_parameters: dict[str, Any],
response: dict[str, Any],
) -> None:
if (
self._rest._connection.auth_class.consent_cache_id_token
and session_parameters.get(
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False
)
):
self._write_temporary_credential(
host, user, ID_TOKEN, response["data"].get("idToken")
Expand All @@ -534,10 +557,9 @@ def write_temporary_credentials(self, host, user, session_parameters, response):
self._write_temporary_credential(
host, user, MFA_TOKEN, response["data"].get("mfaToken")
)
return


def flush_temporary_credentials():
def flush_temporary_credentials() -> None:
"""Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK."""
global TEMPORARY_CREDENTIAL
global TEMPORARY_CREDENTIAL_FILE
Expand Down Expand Up @@ -566,7 +588,7 @@ def flush_temporary_credentials():
unlock_temporary_credential_file()


def write_temporary_credential_file(host, cred_name, cred):
def write_temporary_credential_file(host, cred_name, cred) -> None:
"""Writes temporary credential file when OS is Linux."""
if not CACHE_DIR:
# no cache is enabled
Expand All @@ -581,7 +603,7 @@ def write_temporary_credential_file(host, cred_name, cred):
flush_temporary_credentials()


def read_temporary_credential_file():
def read_temporary_credential_file() -> None:
"""Reads temporary credential file when OS is Linux."""
if not CACHE_DIR:
# no cache is enabled
Expand Down Expand Up @@ -616,10 +638,9 @@ def read_temporary_credential_file():
)
finally:
unlock_temporary_credential_file()
return None


def lock_temporary_credential_file():
def lock_temporary_credential_file() -> bool:
global TEMPORARY_CREDENTIAL_FILE_LOCK
try:
mkdir(TEMPORARY_CREDENTIAL_FILE_LOCK)
Expand All @@ -632,7 +653,7 @@ def lock_temporary_credential_file():
return False


def unlock_temporary_credential_file():
def unlock_temporary_credential_file() -> bool:
global TEMPORARY_CREDENTIAL_FILE_LOCK
try:
rmdir(TEMPORARY_CREDENTIAL_FILE_LOCK)
Expand Down Expand Up @@ -709,10 +730,13 @@ def get_token_from_private_key(
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
)
auth_instance = AuthByKeyPair(private_key, 1440 * 60) # token valid for 24 hours
return auth_instance.authenticate(
KEY_PAIR_AUTHENTICATOR, None, account, user, key_password
)
from . import AuthByKeyPair

auth_instance = AuthByKeyPair(
private_key,
DAY_IN_SECONDS,
) # token valid for 24 hours
return auth_instance.prepare(account=account, user=user)


def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
Expand All @@ -729,4 +753,6 @@ def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
private_key = load_der_private_key(
data=private_key, password=None, backend=default_backend()
)
from . import AuthByKeyPair

return AuthByKeyPair.calculate_public_key_fingerprint(private_key)
Loading

0 comments on commit 9e04925

Please sign in to comment.