diff --git a/src/scribe_data/check/check_pyicu.py b/src/scribe_data/check/check_pyicu.py index c67b4d3b..a1f24cd8 100644 --- a/src/scribe_data/check/check_pyicu.py +++ b/src/scribe_data/check/check_pyicu.py @@ -28,6 +28,7 @@ import pkg_resources import requests +from questionary import confirm def check_if_pyicu_installed(): @@ -144,20 +145,17 @@ def check_and_install_pyicu(): package_name = "PyICU" installed_packages = {pkg.key for pkg in pkg_resources.working_set} if package_name.lower() not in installed_packages: - # print(f"{package_name} not found. Installing...") - # Fetch available wheels from GitHub to estimate download size. wheels, total_size_mb = fetch_wheel_releases() - print( - f"{package_name} is not installed.\nIt will be downloaded from 'https://github.com/repos/cgohlke/pyicu'" - f"\nApproximately {total_size_mb:.2f} MB will be downloaded.\nDo you want to proceed? (Y/n)?" - ) + # Use questionary to ask for user confirmation + user_wants_to_proceed = confirm( + f"{package_name} is not installed.\nIt will be downloaded from 'https://github.com/repos/cgohlke/pyicu-build'" + f"\nApproximately {total_size_mb:.2f} MB will be downloaded.\nDo you want to proceed?" + ).ask() - user_input = input().strip().lower() - if user_input in ["", "y", "yes"]: + if user_wants_to_proceed: print("Proceeding with installation...") - else: print("Installation aborted by the user.") return False diff --git a/src/scribe_data/cli/get.py b/src/scribe_data/cli/get.py index 3e4dd277..c915ff91 100644 --- a/src/scribe_data/cli/get.py +++ b/src/scribe_data/cli/get.py @@ -83,14 +83,12 @@ def get_data( output_type = output_type or "json" if output_dir is None: - if output_type == "csv": - output_dir = DEFAULT_CSV_EXPORT_DIR - elif output_type == "json": - output_dir = DEFAULT_JSON_EXPORT_DIR - elif output_type == "sqlite": - output_dir = DEFAULT_SQLITE_EXPORT_DIR - elif output_type == "tsv": - output_dir = DEFAULT_TSV_EXPORT_DIR + output_dir = { + "csv": DEFAULT_CSV_EXPORT_DIR, + "json": DEFAULT_JSON_EXPORT_DIR, + "sqlite": DEFAULT_SQLITE_EXPORT_DIR, + "tsv": DEFAULT_TSV_EXPORT_DIR, + }.get(output_type, DEFAULT_JSON_EXPORT_DIR) languages = [language] if language else None data_types = [data_type] if data_type else None @@ -98,10 +96,41 @@ def get_data( subprocess_result = False # MARK: Get All - if all: - print("Updating all languages and data types ...") - query_data(None, None, None, overwrite) + if language: + print(f"Updating all data types for language for {language}") + query_data( + languages=[language], + data_type=None, + output_dir=output_dir, + overwrite=overwrite, + ) + print( + f"Query completed for all data types with specified language for {language}." + ) + + elif data_type: + print(f"Updating all languages for data type: {data_type}") + query_data( + languages=None, + data_type=[data_type], + output_dir=output_dir, + overwrite=overwrite, + ) + print( + f"Query completed for all languages with specified data type for {data_type}." + ) + + else: + print("Updating all languages and data types ...") + query_data( + languages=None, + data_type=None, + output_dir=output_dir, + overwrite=overwrite, + ) + print("Query completed for all languages and all data types.") + subprocess_result = True # MARK: Emojis @@ -113,10 +142,7 @@ def get_data( elif language or data_type: data_type = data_type[0] if isinstance(data_type, list) else data_type - - print( - f"Updating data for language(s): {language}; data type(s): {', '.join([data_type])}" - ) + print(f"Updating data for language(s): {language}; data type(s): {data_type}") query_data( languages=languages, data_type=data_types, @@ -132,9 +158,13 @@ def get_data( ) if ( - isinstance(subprocess_result, subprocess.CompletedProcess) - and subprocess_result.returncode != 1 - ) or (isinstance(subprocess_result, bool) and subprocess_result is not False): + ( + isinstance(subprocess_result, subprocess.CompletedProcess) + and subprocess_result.returncode != 1 + ) + or isinstance(subprocess_result, bool) + and subprocess_result + ): print(f"Updated data was saved in: {Path(output_dir).resolve()}.") json_input_path = Path(output_dir) / f"{language}/{data_type}.json" diff --git a/src/scribe_data/cli/main.py b/src/scribe_data/cli/main.py index 83bd4d81..313ab74d 100644 --- a/src/scribe_data/cli/main.py +++ b/src/scribe_data/cli/main.py @@ -24,6 +24,8 @@ import argparse from pathlib import Path +from rich import print as rprint + from scribe_data.cli.cli_utils import validate_language_and_data_type from scribe_data.cli.convert import convert_wrapper from scribe_data.cli.get import get_data @@ -263,43 +265,47 @@ def main() -> None: parser.print_help() return - if args.command in ["list", "l"]: - list_wrapper( - language=args.language, data_type=args.data_type, all_bool=args.all - ) + try: + if args.command in ["list", "l"]: + list_wrapper( + language=args.language, data_type=args.data_type, all_bool=args.all + ) - elif args.command in ["get", "g"]: - if args.interactive: - start_interactive_mode() + elif args.command in ["get", "g"]: + if args.interactive: + start_interactive_mode() + + else: + get_data( + language=args.language, + data_type=args.data_type, + output_type=args.output_type, + output_dir=args.output_dir, + outputs_per_entry=args.outputs_per_entry, + overwrite=args.overwrite, + all=args.all, + ) + + elif args.command in ["total", "t"]: + total_wrapper( + language=args.language, data_type=args.data_type, all_bool=args.all + ) - else: - get_data( + elif args.command in ["convert", "c"]: + convert_wrapper( language=args.language, data_type=args.data_type, output_type=args.output_type, + input_file=args.input_file, output_dir=args.output_dir, - outputs_per_entry=args.outputs_per_entry, overwrite=args.overwrite, - all=args.all, ) - elif args.command in ["total", "t"]: - total_wrapper( - language=args.language, data_type=args.data_type, all_bool=args.all - ) - - elif args.command in ["convert", "c"]: - convert_wrapper( - language=args.language, - data_type=args.data_type, - output_type=args.output_type, - input_file=args.input_file, - output_dir=args.output_dir, - overwrite=args.overwrite, - ) - - else: - parser.print_help() + else: + parser.print_help() + + except KeyboardInterrupt: + rprint("[bold red]Execution was interrupted by the user.[/bold red]") if __name__ == "__main__": diff --git a/src/scribe_data/wikidata/query_data.py b/src/scribe_data/wikidata/query_data.py index f54ccce3..a0e1b95d 100644 --- a/src/scribe_data/wikidata/query_data.py +++ b/src/scribe_data/wikidata/query_data.py @@ -66,12 +66,18 @@ def execute_formatting_script(formatting_file_path, output_dir): env = os.environ.copy() env["PYTHONPATH"] = str(project_root) - # Use subprocess to run the formatting file. - subprocess.run( - [python_executable, str(formatting_file_path), "--file-path", output_dir], - env=env, - check=True, - ) + try: + subprocess.run( + [python_executable, str(formatting_file_path), "--file-path", output_dir], + env=env, + check=True, + ) + except FileNotFoundError: + print( + f"Error: The formatting script file '{formatting_file_path}' does not exist." + ) + except subprocess.CalledProcessError as e: + print(f"Error: The formatting script failed with exit status {e.returncode}.") def query_data( diff --git a/tests/cli/test_get.py b/tests/cli/test_get.py index a1e21e75..99690733 100644 --- a/tests/cli/test_get.py +++ b/tests/cli/test_get.py @@ -48,9 +48,34 @@ def test_invalid_arguments(self): # MARK: All Data @patch("scribe_data.cli.get.query_data") - def test_get_all_data(self, mock_query_data): - get_data(all=True) - mock_query_data.assert_called_once_with(None, None, None, False) + def test_get_all_data_types_for_language(self, mock_query_data): + get_data(all=True, language="English") + mock_query_data.assert_called_once_with( + languages=["English"], + data_type=None, + output_dir="scribe_data_json_export", + overwrite=False, + ) + + @patch("scribe_data.cli.get.query_data") + def test_get_all_languages_for_data_type(self, mock_query_data): + get_data(all=True, data_type="nouns") + mock_query_data.assert_called_once_with( + languages=None, + data_type=["nouns"], + output_dir="scribe_data_json_export", + overwrite=False, + ) + + @patch("scribe_data.cli.get.query_data") + def test_get_all_languages_and_data_types(self, mock_query_data): + get_data(all=True, output_dir="./test_output") + mock_query_data.assert_called_once_with( + languages=None, + data_type=None, + output_dir="./test_output", + overwrite=False, + ) # MARK: Language and Data Type