diff --git a/CHANGELOG.md b/CHANGELOG.md index ef3bb2cca..03866f53d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # CHANGELOG ## 0.9.0dev + +* [Feature] Allow loading configuration value from a `pyproject.toml` file upon magic initialization (#689) * [Fix] Fix error that was incorrectly converted into a print message * [Fix] Modified histogram query to ensure histogram binning is done correctly (#751) * [Fix] Fix bug that caused the `COMMIT` not to work when the SQLAlchemy driver did not support `set_isolation_level` diff --git a/doc/api/configuration.md b/doc/api/configuration.md index d47b35d0b..8e57cc48e 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -241,3 +241,13 @@ print(res) res = %sql SELECT * FROM languages LIMIT 2 print(res) ``` + +## Loading configuration settings + +You can define configurations in a `pyproject.toml` file and automatically load the configurations when you run `%load_ext sql`. If the file is not found in the current or parent directories, default values will be used. A sample `pyproject.toml` could look like this: + +``` +[tool.jupysql.SqlMagic] +feedback = true +autopandas = true +``` diff --git a/src/sql/display.py b/src/sql/display.py index c0c2e7f6c..7fee72bea 100644 --- a/src/sql/display.py +++ b/src/sql/display.py @@ -4,7 +4,7 @@ import html from prettytable import PrettyTable -from IPython.display import display +from IPython.display import display, HTML class Table: @@ -90,3 +90,13 @@ def message(message): def message_success(message): """Display a success message""" display(Message(message, style="color: green")) + + +def message_html(message): + """Display a message as HTML""" + display(HTML(str(Message(message)))) + + +def table(headers, rows): + """Display a table""" + display(Table(headers, rows)) diff --git a/src/sql/exceptions.py b/src/sql/exceptions.py index 39b8c3034..350d81c98 100644 --- a/src/sql/exceptions.py +++ b/src/sql/exceptions.py @@ -44,3 +44,6 @@ def _error(message): # raised internally when the user chooses a table that doesn't exist TableNotFoundError = exception_factory("TableNotFoundError") + +# raise it when there is an error in parsing pyproject.toml file +ConfigurationError = exception_factory("ConfigurationError") diff --git a/src/sql/magic.py b/src/sql/magic.py index ad46b1fcf..5834b5b43 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -32,7 +32,7 @@ from sql.magic_plot import SqlPlotMagic from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error -from sql import query_util +from sql import query_util, util from sql.util import get_suggestions_message, pretty_print from sql.exceptions import RuntimeError from sql.error_message import detail @@ -622,6 +622,39 @@ def _persist_dataframe( display.message_success(f"Success! Persisted {table_name} to the database.") +def set_configs(ip, file_path): + """Set user defined SqlMagic configuration settings""" + sql = ip.find_cell_magic("sql").__self__ + user_configs = util.get_user_configs(file_path, ["tool", "jupysql", "SqlMagic"]) + default_configs = util.get_default_configs(sql) + table_rows = [] + for config, value in user_configs.items(): + if config in default_configs.keys(): + default_type = type(default_configs[config]) + if isinstance(value, default_type): + setattr(sql, config, value) + table_rows.append([config, value]) + else: + display.message( + f"'{value}' is an invalid value for '{config}'. " + f"Please use {default_type.__name__} value instead." + ) + else: + util.find_close_match_config(config, default_configs.keys()) + + return table_rows + + +def load_SqlMagic_configs(ip): + """Loads saved SqlMagic configs in pyproject.toml""" + file_path = util.find_path_from_root("pyproject.toml") + if file_path: + table_rows = set_configs(ip, file_path) + if table_rows: + display.message("Settings changed:") + display.table(["Config", "value"], table_rows) + + def load_ipython_extension(ip): """Load the extension in IPython.""" @@ -636,3 +669,5 @@ def load_ipython_extension(ip): ip.register_magics(SqlCmdMagic) patch_ipython_usage_error(ip) + + load_SqlMagic_configs(ip) diff --git a/src/sql/util.py b/src/sql/util.py index 1abefb7d7..424852ba8 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -5,6 +5,13 @@ from sql.store import store, _get_dependents_for_key from sql import exceptions, display import json +from pathlib import Path +from ploomber_core.dependencies import requires + +try: + import toml +except ModuleNotFoundError: + toml = None SINGLE_QUOTE = "'" DOUBLE_QUOTE = '"' @@ -365,3 +372,125 @@ def show_deprecation_warning(): "raise an exception in the next major release so please remove it.", FutureWarning, ) + + +def find_path_from_root(file_name): + """ + Recursively finds an absolute path to file_name starting + from current to root directory + """ + current = Path().resolve() + while not (current / file_name).exists(): + if current == current.parent: + return None + + current = current.parent + display.message(f"Found {file_name} from '{current}'") + + return str(Path(current, file_name)) + + +def find_close_match_config(word, possibilities, n=3): + """Finds closest matching configurations and displays message""" + closest_matches = difflib.get_close_matches(word, possibilities, n=n) + if not closest_matches: + display.message_html( + f"'{word}' is an invalid configuration. Please review our " + "" # noqa + "configuration guideline." + ) + else: + display.message( + f"'{word}' is an invalid configuration. Did you mean " + f"{pretty_print(closest_matches, last_delimiter='or')}?" + ) + + +def get_line_content_from_toml(file_path, line_number): + """ + Locates a line that error occurs when loading a toml file + and returns the line, key, and value + """ + with open(file_path, "r") as file: + lines = file.readlines() + eline = lines[line_number - 1].strip() + ekey, evalue = None, None + if "=" in eline: + ekey, evalue = map(str.strip, eline.split("=")) + return eline, ekey, evalue + + +@requires(["toml"]) +def load_toml(file_path): + """ + Returns toml file content in a dictionary format + and raises error if it fails to load the toml file + """ + try: + with open(file_path, "r") as file: + content = file.read() + return toml.loads(content) + except toml.TomlDecodeError as e: + raise parse_toml_error(e, file_path) + + +def parse_toml_error(e, file_path): + eline, ekey, evalue = get_line_content_from_toml(file_path, e.lineno) + if "Duplicate keys!" in str(e): + return exceptions.ConfigurationError(f"Duplicate key found : '{ekey}'") + elif "Only all lowercase booleans" in str(e): + return exceptions.ConfigurationError( + f"Invalid value '{evalue}' in '{eline}'. " + "Valid boolean values: true, false" + ) + elif "invalid literal for int()" in str(e): + return exceptions.ConfigurationError( + f"Invalid value '{evalue}' in '{eline}'. " + "To use str value, enclose it with ' or \"." + ) + else: + return e + + +def get_user_configs(file_path, section_names): + """ + Returns saved configuration settings in a toml file from given file_path + + Parameters + ---------- + file_path : str + file path to a toml file + section_names : list + section names that contains the configuration settings + (e.g., ["tool", "jupysql", "SqlMagic"]) + + Returns + ------- + dict + saved configuration settings + """ + data = load_toml(file_path) + while section_names: + section_to_find, sections_from_user = section_names.pop(0), data.keys() + if section_to_find not in sections_from_user: + close_match = difflib.get_close_matches(section_to_find, sections_from_user) + if not close_match: + return {} + else: + raise exceptions.ConfigurationError( + f"{pretty_print(close_match)} is an invalid section name. " + f"Did you mean '{section_to_find}'?" + ) + data = data[section_to_find] + return data + + +def get_default_configs(sql): + """ + Returns a dictionary of SqlMagic configuration settings users can set + with their default values. + """ + default_configs = sql.trait_defaults() + del default_configs["parent"] + del default_configs["config"] + return default_configs diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index bb7bd57de..48198178d 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1342,3 +1342,183 @@ def test_interact_and_missing_ipywidgets_installed(ip): assert "'ipywidgets' is required to use '--interactive argument'" in str( excinfo.value ) + + +@pytest.mark.parametrize( + "file_content, expect, revert", + [ + ( + """ +[tool.jupysql.SqlMagic] +autocommit = false +autolimit = 1 +style = "RANDOM" +""", + [ + "Found pyproject.toml from '%s'", + "Settings changed:", + r"autocommit\s*\|\s*False", + r"autolimit\s*\|\s*1", + r"style\s*\|\s*RANDOM", + ], + {"autocommit": True, "autolimit": 0, "style": "DEFAULT"}, + ), + ( + """ +[tool.jupysql.SqlMagic] +""", + ["Found pyproject.toml from '%s'"], + {}, + ), + ( + """ +[test] +github = "ploomber/jupysql" +""", + ["Found pyproject.toml from '%s'"], + {}, + ), + ( + """ +[tool.pkgmt] +github = "ploomber/jupysql" +""", + ["Found pyproject.toml from '%s'"], + {}, + ), + ( + """ +[tool.jupysql.test] +github = "ploomber/jupysql" +""", + ["Found pyproject.toml from '%s'"], + {}, + ), + ( + "", + ["Found pyproject.toml from '%s'"], + {}, + ), + ], +) +def test_valid_loading_toml(tmp_empty, ip, capsys, file_content, expect, revert): + Path("pyproject.toml").write_text(file_content) + toml_dir = os.getcwd() + os.makedirs("sub") + os.chdir("sub") + + ip.run_cell("%load_ext sql").result + out, _ = capsys.readouterr() + + expect[0] = expect[0] % (re.escape(toml_dir)) + assert all(re.search(substring, out) for substring in expect) + + sql = ip.find_cell_magic("sql").__self__ + [setattr(sql, config, value) for config, value in revert.items()] + + +def test_no_toml(tmp_empty, ip, capsys): + os.makedirs("sub") + os.chdir("sub") + + ip.run_cell("%load_ext sql").result + out, _ = capsys.readouterr() + + assert out == "" + + +@pytest.mark.parametrize( + "file_content, error_msg", + [ + ( + """ +[tool.jupysql.SqlMagic] +autocommit = true +autocommit = true +""", + "Duplicate key found : 'autocommit'", + ), + ( + """ +[tool.jupySql.SqlMagic] +autocommit = true +""", + "'jupySql' is an invalid section name. Did you mean 'jupysql'?", + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = True +""", + ( + "Invalid value 'True' in 'autocommit = True'. " + "Valid boolean values: true, false" + ), + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = invalid +""", + ( + "Invalid value 'invalid' in 'autocommit = invalid'. " + "To use str value, enclose it with ' or \"." + ), + ), + ], +) +def test_error_on_toml_parsing(tmp_empty, ip, capsys, file_content, error_msg): + Path("pyproject.toml").write_text(file_content) + toml_dir = os.getcwd() + found_statement = "Found pyproject.toml from '%s'" % (toml_dir) + os.makedirs("sub") + os.chdir("sub") + + with pytest.raises(UsageError) as excinfo: + ip.run_cell("%load_ext sql") + out, _ = capsys.readouterr() + + assert out.strip() == found_statement + assert excinfo.value.error_type == "ConfigurationError" + assert str(excinfo.value) == error_msg + + +def test_valid_and_invalid_configs(tmp_empty, ip, capsys): + Path("pyproject.toml").write_text( + """ +[tool.jupysql.SqlMagic] +autocomm = true +autop = false +autolimit = "text" +invalid = false +displaycon = false +""" + ) + toml_dir = os.getcwd() + os.makedirs("sub") + os.chdir("sub") + + ip.run_cell("%load_ext sql") + out, _ = capsys.readouterr() + expect = [ + "Found pyproject.toml from '%s'" % (re.escape(toml_dir)), + "'autocomm' is an invalid configuration. Did you mean 'autocommit'?", + ( + "'autop' is an invalid configuration. " + "Did you mean 'autopandas', or 'autopolars'?" + ), + ( + "'text' is an invalid value for 'autolimit'. " + "Please use int value instead." + ), + r"displaycon\s*\|\s*False", + ] + assert all(re.search(substring, out) for substring in expect) + + # confirm the correct changes are applied + confirm = {"displaycon": False, "autolimit": 0} + sql = ip.find_cell_magic("sql").__self__ + assert all([getattr(sql, config) == value for config, value in confirm.items()]) + + # revert back to a default setting + setattr(sql, "displaycon", True)