diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4b1e2449..2a41d0fa4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -95,6 +95,7 @@ repos: | util_text | url_util | version + | config_parser/parser ).py$ additional_dependencies: - types-requests diff --git a/MANIFEST.in b/MANIFEST.in index 232f95b66..6729fc860 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,7 +3,7 @@ include *.rst include LICENSE.txt include NOTICE include pyproject.toml -recursive-include src/snowflake/connector py.typed *.py *.pyx +recursive-include src/snowflake/connector py.typed *.py *.pyx *.toml recursive-include src/snowflake/connector/vendored LICENSE* recursive-include src/snowflake/connector/cpp *.cpp *.hpp diff --git a/setup.cfg b/setup.cfg index 3da04a4c2..918fee5cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,8 @@ install_requires = certifi>=2017.4.17 typing_extensions>=4.3,<5 filelock>=3.5,<4 + platformdirs + tomlkit include_package_data = True package_dir = =src diff --git a/src/snowflake/connector/config_parser/__init__.py b/src/snowflake/connector/config_parser/__init__.py new file mode 100644 index 000000000..aa36183aa --- /dev/null +++ b/src/snowflake/connector/config_parser/__init__.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+
+from __future__ import annotations
+
+from tomlkit import parse
+
+from ..constants import config_file
+from .parser import ConfigParser
+
+ROOT_PARSER = ConfigParser(
+    name="ROOT_PARSER",
+    file_path=config_file,
+)
+ROOT_PARSER.add_option(
+    name="connections",
+    _type=parse,
+)
+
+__all__ = [
+    "ConfigParser",
+    "ROOT_PARSER",
+]
diff --git a/src/snowflake/connector/config_parser/default_config.toml b/src/snowflake/connector/config_parser/default_config.toml
new file mode 100644
index 000000000..4995b52d0
--- /dev/null
+++ b/src/snowflake/connector/config_parser/default_config.toml
@@ -0,0 +1,4 @@
+[connections.default]
+account = "accountname"
+user = "username"
+password = "password"
diff --git a/src/snowflake/connector/config_parser/parser.py b/src/snowflake/connector/config_parser/parser.py
new file mode 100644
index 000000000..477619903
--- /dev/null
+++ b/src/snowflake/connector/config_parser/parser.py
@@ -0,0 +1,243 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import os
+from collections.abc import Iterable
+from functools import wraps
+from pathlib import Path
+from typing import Callable, Literal, TypeVar
+
+import tomlkit
+from tomlkit import TOMLDocument
+from tomlkit.items import Table
+
+from ..errors import ConfigParserError, ConfigSourceError
+
+_T = TypeVar("_T")
+
+
+class ConfigOption:
+    """A single value that is read from the environment, or the config file."""
+
+    def __init__(
+        self,
+        name: str,
+        _root_parser: ConfigParser,
+        _nest_path: list[str],
+        _type: Callable[[str], _T] | None = None,
+        choices: Iterable[_T] | None = None,
+        env_name: str | None | Literal[False] = None,
+    ) -> None:
+        """Create a config option that can read values from different locations.
+
+        Args:
+            name: The name of the ConfigOption
+            _root_parser: The parser whose parsed config file is consulted
+            _nest_path: Names of the parsers leading to this option, root first
+            _type: A function that can turn str to the desired type, useful
+              for reading value from environmental variable
+            choices: If supplied, the only values the option is allowed to take
+            env_name: Environmental variable value should be read from, if not
+              supplied, we'll construct this, False disables the lookup
+        """
+        self.name = name
+        self.type = _type
+        self.choices = choices
+        self._nest_path = _nest_path + [name]
+        self._root_parser: ConfigParser = _root_parser
+        self.env_name = env_name
+
+    def get(self):
+        """Retrieve a value of option.
+
+        The environment variable takes precedence over the configuration file.
+
+        Raises:
+            ConfigSourceError: If the value is not part of choices, or it can't
+              be found in the configuration file.
+        """
+        source = "environment variable"
+        value = self._get_env()
+        if value is None:
+            source = "configuration file"
+            value = self._get_config()
+        if self.choices and value not in self.choices:
+            raise ConfigSourceError(
+                f"The value of {self.generate_name()} read from "
+                f"{source} is not part of {self.choices}"
+            )
+        return value
+
+    def generate_name(self) -> str:
+        """Return the dotted path of the option without the root parser's name."""
+        return ".".join(self._nest_path[1:])
+
+    def generate_env_name(self) -> str:
+        """Return the default environment variable name, like SF_PARSER_OPTION."""
+        return "SF_" + "_".join(piece.upper() for piece in self._nest_path[1:])
+
+    def _get_env(self) -> str | _T | None:
+        """Read the value from the environment, None when unset, or disabled."""
+        if self.env_name is False:
+            return None
+        if self.env_name is not None:
+            env_name = self.env_name
+        else:
+            # Generate environment name if it wasn't explicitly supplied
+            env_name = self.generate_env_name()
+        env_var = os.environ.get(env_name)
+        if env_var is None:
+            return None
+        # An empty value is handed back as is, without converting it
+        if env_var and self.type is not None:
+            return self.type(env_var)
+        return env_var
+
+    def _get_config(self):
+        """Read the value from the configuration file of the root parser.
+
+        Raises:
+            ConfigSourceError: If no configuration file has been read, or the
+              option can't be found in it.
+        """
+        e = self._root_parser._conf
+        if e is None:
+            raise ConfigSourceError(
+                f"The value of {self.generate_name()} can't be read, because "
+                "no configuration file has been loaded"
+            )
+        try:
+            for k in self._nest_path[1:]:
+                e = e[k]
+        except (KeyError, TypeError) as err:
+            raise ConfigSourceError(
+                f"The value of {self.generate_name()} can't be found in "
+                "the configuration file"
+            ) from err
+        if isinstance(e, Table):
+            # If we got a TOML table we probably want it in dictionary form
+            return e.value
+        return e
+
+
+class ConfigParser:
+    """A nestable group of ConfigOptions, optionally backed by a TOML file."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        file_path: Path | None = None,
+    ):
+        """Create a parser.
+
+        Args:
+            name: The name of the parser, its key when nested in another parser
+            file_path: The TOML file values are read from, if there is one
+        """
+        self.name = name
+        self.file_path = file_path
+        # Objects holding subparsers and options
+        self._options: dict[str, ConfigOption] = {}
+        self._sub_parsers: dict[str, ConfigParser] = {}
+        # Dictionary to cache read in config file
+        self._conf: TOMLDocument | None = None
+        # Information necessary to be able to nest elements
+        #  and add options in O(1)
+        self._root_parser: ConfigParser = self
+        self._nest_path = [name]
+
+    def read_config(
+        self,
+    ) -> None:
+        """Read and parse config file.
+
+        Raises:
+            ConfigParserError: If the parser doesn't have a file_path.
+            ConfigSourceError: If the file can't be read, or parsed.
+        """
+        if self.file_path is None:
+            raise ConfigParserError(
+                "ConfigParser is trying to read config file, but it doesn't have one"
+            )
+        try:
+            self._conf = tomlkit.parse(self.file_path.read_text())
+        except Exception as e:
+            # Chain the original exception so that its traceback isn't lost
+            raise ConfigSourceError(
+                f'An unknown error happened while loading "{str(self.file_path)}'
+                f'", please see the error: {e}'
+            ) from e
+
+    # add_option takes over the signature and docstring of ConfigOption.__init__
+    @wraps(ConfigOption.__init__)
+    def add_option(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        kwargs["_root_parser"] = self._root_parser
+        kwargs["_nest_path"] = self._nest_path
+        new_option = ConfigOption(
+            *args,
+            **kwargs,
+        )
+        self._check_child_conflict(new_option.name)
+        self._options[new_option.name] = new_option
+
+    def _check_child_conflict(self, name: str) -> None:
+        """Raise ConfigParserError if there is a child called name already."""
+        if name in self._options or name in self._sub_parsers:
+            raise ConfigParserError(
+                f"'{name}' subparser, or option conflicts with a child element of '{self.name}'"
+            )
+
+    def add_subparser(self, other: ConfigParser) -> None:
+        """Nest other, with everything registered on it, under this parser.
+
+        Raises:
+            ConfigParserError: If there is a child called other.name already.
+        """
+        self._check_child_conflict(other.name)
+        self._sub_parsers[other.name] = other
+
+        def _root_setter_helper(node: ConfigParser):
+            # Deal with ConfigParsers
+            node._root_parser = self._root_parser
+            node._nest_path = self._nest_path + node._nest_path
+            for sub_parser in node._sub_parsers.values():
+                _root_setter_helper(sub_parser)
+            # Deal with ConfigOptions
+            for option in node._options.values():
+                option._root_parser = self._root_parser
+                option._nest_path = self._nest_path + option._nest_path
+
+        _root_setter_helper(other)
+
+    def __getitem__(self, item: str) -> ConfigOption | ConfigParser:
+        """Look up a child by its name.
+
+        Options resolve to their value, subparsers are returned as they are.
+        If nothing is cached yet, the config file is read first when it exists.
+
+        Raises:
+            ConfigSourceError: If there is no child called item.
+        """
+        if self._conf is None and (
+            self.file_path is not None
+            and self.file_path.exists()
+            and self.file_path.is_file()
+        ):
+            self.read_config()
+        if item in self._options:
+            return self._options[item].get()
+        if item not in self._sub_parsers:
+            raise ConfigSourceError(
+                "No ConfigParser, or ConfigOption can be found"
+                f" with the name '{item}'"
+            )
+        return self._sub_parsers[item]
diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py
index d567137c1..3d9372878 100644
--- a/src/snowflake/connector/constants.py
+++ b/src/snowflake/connector/constants.py
@@ -14,6 +14,21 @@
 if TYPE_CHECKING:
     from pyarrow import DataType
 
+from pathlib import Path
+from shutil import copyfile
+
+from platformdirs import PlatformDirs
+
+dirs = PlatformDirs(
+    appname="snowflake",
+    appauthor=False,
+    ensure_exists=True,
+)
+config_file = dirs.user_config_path / "config.toml"
+if not config_file.exists():
+    # Create default config file
+    default_config = Path(__file__).absolute().parent / "default_config.toml"
+    copyfile(default_config, config_file)
 
 DBAPI_TYPE_STRING = 0
 DBAPI_TYPE_BINARY = 1
diff --git a/src/snowflake/connector/errors.py b/src/snowflake/connector/errors.py
index 4adf0e14e..7436616a4 100644
--- a/src/snowflake/connector/errors.py
+++ b/src/snowflake/connector/errors.py
@@ -579,3 +579,17 @@ class PresignedUrlExpiredError(Error):
     """Exception for REST call to remote storage API failed because of expired presigned URL."""
 
     pass
+
+
+class ConfigSourceError(Error):
+    """Configuration source related errors.
+
+    Examples are environmental variable and configuration file.
+    """
+
+
+class ConfigParserError(Error):
+    """Configuration parser related errors.
+
+    These mean that ConfigParser is misused by a developer.
+    """
diff --git a/test/unit/test_configparser.py b/test/unit/test_configparser.py
new file mode 100644
index 000000000..e43b1effa
--- /dev/null
+++ b/test/unit/test_configparser.py
@@ -0,0 +1,237 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import re
+from pathlib import Path
+from test.randomize import random_string
+from textwrap import dedent
+from typing import Callable, Dict, Union
+
+import pytest
+from pytest import raises
+
+from snowflake.connector.config_parser import ConfigParser
+from snowflake.connector.errors import ConfigParserError, ConfigSourceError
+
+
+def tmp_files_helper(cwd: Path, to_create: files) -> None:
+    """Create the files and folders described by to_create inside of cwd."""
+    for k, v in to_create.items():
+        new_file = cwd / k
+        if isinstance(v, str):
+            new_file.write_text(v)
+        else:
+            new_file.mkdir()
+            tmp_files_helper(new_file, v)
+
+
+# Maps a name to file contents, or to the description of a nested folder
+files = Dict[str, Union[str, "files"]]
+
+
+@pytest.fixture
+def tmp_files(tmp_path: Path) -> Callable[[files], Path]:
+    """Return a function that populates tmp_path, and hands it back."""
+
+    def create_tmp_files(to_create: files) -> Path:
+        tmp_files_helper(tmp_path, to_create)
+        return tmp_path
+
+    return create_tmp_files
+
+
+def test_incorrect_config_read(tmp_files):
+    """A truncated TOML file is reported as a ConfigSourceError."""
+    tmp_folder = tmp_files(
+        {
+            "config.toml": dedent(
+                """
+                [connections.defa
+                """
+            )
+        }
+    )
+    config_file = tmp_folder / "config.toml"
+    with raises(ConfigSourceError) as ex:
+        ConfigParser(name="test", file_path=config_file).read_config()
+    # Escape the path, it can contain regex metacharacters
+    assert ex.match(
+        f'An unknown error happened while loading "{re.escape(str(config_file))}"'
+    )
+
+
+def test_simple_config_read(tmp_files):
+    """A table in the config file is returned in dictionary form."""
+    tmp_folder = tmp_files(
+        {
+            "config.toml": dedent(
+                """\
+                [connections.snowflake]
+                account = "snowflake"
+                user = "snowball"
+                password = "password"
+                """
+            )
+        }
+    )
+    config_file = tmp_folder / "config.toml"
+    TEST_PARSER = ConfigParser(
+        name="test",
+        file_path=config_file,
+    )
+    from tomlkit import parse
+
+    TEST_PARSER.add_option(
+        "connections",
+        _type=parse,
+    )
+    assert TEST_PARSER["connections"] == {
+        "snowflake": {
+            "account": "snowflake",
+            "user": "snowball",
+            "password": "password",
+        }
+    }
+
+
+def test_simple_nesting(monkeypatch, tmp_path):
+    """Options of nested parsers derive their environment variable's name."""
+    c1 = ConfigParser(name="test", file_path=tmp_path / "config.toml")
+    c2 = ConfigParser(name="sb")
+    c3 = ConfigParser(name="sb")
+    c3.add_option(name="b", _type=lambda e: e.lower() == "true")
+    c2.add_subparser(c3)
+    c1.add_subparser(c2)
+    with monkeypatch.context() as m:
+        m.setenv("SF_SB_SB_B", "TrUe")
+        assert c1["sb"]["sb"]["b"] is True
+
+
+def test_complicated_nesting(monkeypatch, tmp_path):
+    """Options of subparsers are read from the root parser's config file."""
+    c_file = tmp_path / "config.toml"
+    c1 = ConfigParser(file_path=c_file, name="root_parser")
+    c2 = ConfigParser(file_path=tmp_path / "config2.toml", name="sp")
+    c2.add_option(name="b", _type=lambda e: e.lower() == "true")
+    c1.add_subparser(c2)
+    c_file.write_text(
+        dedent(
+            """\
+            [connections.default]
+            user="testuser"
+            account="testaccount"
+            password="testpassword"
+
+            [sp]
+            b = true
+            """
+        )
+    )
+    assert c1["sp"]["b"] is True
+
+
+def test_error_missing_file_path():
+    """Reading without a file_path is a ConfigParserError."""
+    with pytest.raises(
+        ConfigParserError,
+        match="ConfigParser is trying to read config file, but it doesn't have one",
+    ):
+        ConfigParser(name="test_parser").read_config()
+
+
+def test_error_invalid_toml(tmp_path):
+    """A file that isn't valid TOML is reported as a ConfigSourceError."""
+    c_file = tmp_path / "c.toml"
+    c_file.write_text(
+        dedent(
+            """\
+            invalid toml file
+            """
+        )
+    )
+    with pytest.raises(
+        ConfigSourceError,
+        # Escape the path, it can contain regex metacharacters
+        match=f'An unknown error happened while loading "{re.escape(str(c_file))}"',
+    ):
+        ConfigParser(
+            name="test_parser",
+            file_path=c_file,
+        ).read_config()
+
+
+def test_error_child_conflict():
+    """An option can't share its name with a subparser of the same parser."""
+    cp = ConfigParser(name="test_parser")
+    cp.add_subparser(ConfigParser(name="b"))
+    with pytest.raises(
+        ConfigParserError,
+        match="'b' subparser, or option conflicts with a child element of 'test_parser'",
+    ):
+        cp.add_option("b")
+
+
+def test_explicit_env_name(monkeypatch):
+    """An explicit env_name replaces the generated environment variable name."""
+    rnd_string = random_string(5)
+    toml_value = dedent(
+        f"""\
+        text = "{rnd_string}"
+        """
+    )
+    TEST_PARSER = ConfigParser(
+        name="test_parser",
+    )
+
+    from tomlkit import parse
+
+    TEST_PARSER.add_option("connections", _type=parse, env_name="CONNECTIONS")
+    with monkeypatch.context() as m:
+        m.setenv("CONNECTIONS", toml_value)
+        assert TEST_PARSER["connections"] == {"text": rnd_string}
+
+
+def test_error_contains(monkeypatch):
+    """A value outside of choices is reported as a ConfigSourceError."""
+    tp = ConfigParser(
+        name="test_parser",
+    )
+    tp.add_option("output_format", choices=("json", "csv"))
+    with monkeypatch.context() as m:
+        m.setenv("SF_OUTPUT_FORMAT", "toml")
+        with pytest.raises(
+            ConfigSourceError,
+            match="The value of output_format read from environment variable "
+            "is not part of",
+        ):
+            tp["output_format"]
+
+
+def test_missing_item():
+    """Looking up an unknown name is a ConfigSourceError."""
+    tp = ConfigParser(
+        name="test_parser",
+    )
+    with pytest.raises(
+        ConfigSourceError,
+        match="No ConfigParser, or ConfigOption can be found with the name 'asd'",
+    ):
+        tp["asd"]
+
+
+def test_error_no_config_loaded(monkeypatch):
+    """An option without any source is reported as a ConfigSourceError."""
+    tp = ConfigParser(
+        name="test_parser",
+    )
+    tp.add_option("asd")
+    with monkeypatch.context() as m:
+        m.delenv("SF_ASD", raising=False)
+        with pytest.raises(
+            ConfigSourceError,
+            match="no configuration file has been loaded",
+        ):
+            tp["asd"]