From 435988572a9d51407700e99d37490b64d566369d Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 13:00:02 +0300 Subject: [PATCH 1/8] Add base formatter --- .../formatter/__init__.py | 0 datamodel_code_generator/formatter/base.py | 91 +++++++++++++++++++ .../add_license_formatter.py | 17 ++++ tests/formatter/__init__.py | 0 tests/formatter/test_base.py | 84 +++++++++++++++++ 5 files changed, 192 insertions(+) create mode 100644 datamodel_code_generator/formatter/__init__.py create mode 100644 datamodel_code_generator/formatter/base.py create mode 100644 tests/data/python/custom_formatters/add_license_formatter.py create mode 100644 tests/formatter/__init__.py create mode 100644 tests/formatter/test_base.py diff --git a/datamodel_code_generator/formatter/__init__.py b/datamodel_code_generator/formatter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datamodel_code_generator/formatter/base.py b/datamodel_code_generator/formatter/base.py new file mode 100644 index 000000000..58e7784dd --- /dev/null +++ b/datamodel_code_generator/formatter/base.py @@ -0,0 +1,91 @@ +from importlib import import_module +from typing import Any, ClassVar, Dict + +from datamodel_code_generator.imports import Import + + +class BaseCodeFormatter: + """An abstract class for representing a code formatter. + + All formatters that format a generated code should subclass + it. All subclass should override `apply` method which + has a string with code in input and returns a formatted code in string. + We also need to determine a `formatter_name` field + which is unique name of formatter. + + Example: + >>> class CustomHeaderCodeFormatter(BaseCodeFormatter): + ... formatter_name: ClassVar[str] = "custom" + ... def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + ... super().__init__(formatter_kwargs=formatter_kwargs) + ... + ... default_header = "my header" + ... self.header: str = self.formatter_kwargs.get("header", default_header) + ... def apply(self, code: str) -> str: + ... return f'# {self.header}\\n{code}' + ... + ... formatter_kwargs = {"header": "formatted with CustomHeaderCodeFormatter"} + ... formatter = CustomHeaderCodeFormatter(formatter_kwargs) + ... code = '''x = 1\ny = 2''' + ... print(formatter.apply(code)) + # formatted with CustomHeaderCodeFormatter + x = 1 + y = 2 + + """ + + formatter_name: ClassVar[str] = '' + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + if self.formatter_name == '': + raise ValueError('`formatter_name` should be not empty string') + + self.formatter_kwargs = formatter_kwargs + + def apply(self, code: str) -> str: + raise NotImplementedError + + +def load_code_formatter( + custom_formatter_import: str, custom_formatters_kwargs: Dict[str, Any] +) -> BaseCodeFormatter: + """Load a formatter by import path as string. + + Args: + custom_formatter_import: custom formatter module. + custom_formatters_kwargs: kwargs for custom formatters from config. + + Examples: + for default formatters use + >>> custom_formatter_import = "datamodel_code_generator.formatter.BlackCodeFormatter" + this is equivalent to code + >>> from datamodel_code_generator.formatter import BlackCodeFormatter + + custom formatter + >>> custom_formatter_import = "my_package.my_sub_package.FormatterName" + this is equivalent to code + >>> from my_package.my_sub_package import FormatterName + + """ + + import_ = Import.from_full_path(custom_formatter_import) + imported_module_ = import_module(import_.from_) + + if not hasattr(imported_module_, import_.import_): + raise NameError( + f'Custom formatter module `{import_.from_}` not contains formatter with name `{import_.import_}`' + ) + + formatter_class = imported_module_.__getattribute__(import_.import_) + + if not issubclass(formatter_class, BaseCodeFormatter): + raise TypeError( + f'The custom module `{custom_formatter_import}` must inherit from ' + '`datamodel-code-generator.formatter.BaseCodeFormatter`' + ) + + custom_formatter_kwargs = custom_formatters_kwargs.get( + formatter_class.formatter_name, {} + ) + + return formatter_class(formatter_kwargs=custom_formatter_kwargs) diff --git a/tests/data/python/custom_formatters/add_license_formatter.py b/tests/data/python/custom_formatters/add_license_formatter.py new file mode 100644 index 000000000..1632dd90e --- /dev/null +++ b/tests/data/python/custom_formatters/add_license_formatter.py @@ -0,0 +1,17 @@ +from typing import Any, Dict, ClassVar + +from datamodel_code_generator.formatter.base import BaseCodeFormatter + + +class LicenseFormatter(BaseCodeFormatter): + """Add a license to file from license file path.""" + formatter_name: ClassVar[str] = "license_formatter" + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + super().__init__(formatter_kwargs) + + license_txt = formatter_kwargs.get('license_txt', "a license") + self.license_header = '\n'.join([f'# {line}' for line in license_txt.split('\n')]) + + def apply(self, code: str) -> str: + return f'{self.license_header}\n{code}' diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/formatter/test_base.py b/tests/formatter/test_base.py new file mode 100644 index 000000000..05167481d --- /dev/null +++ b/tests/formatter/test_base.py @@ -0,0 +1,84 @@ +from typing import ClassVar + +import pytest + +from datamodel_code_generator.formatter.base import ( + BaseCodeFormatter, + load_code_formatter, +) + +UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist.CustomFormatter' +WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong.WrongFormatterName_' +NOT_SUBCLASS_FORMATTER = ( + 'tests.data.python.custom_formatters.not_subclass.CodeFormatter' +) +ADD_LICENSE_FORMATTER = ( + 'tests.data.python.custom_formatters.add_license_formatter.LicenseFormatter' +) + + +def test_incorrect_from_base_not_implemented_apply(): + class CustomFormatter(BaseCodeFormatter): + formatter_name: ClassVar[str] = 'formatter' + + with pytest.raises(NotImplementedError): + formatter = CustomFormatter({}) + formatter.apply('') + + +def test_incorrect_from_base(): + class CustomFormatter(BaseCodeFormatter): + def apply(self, code: str) -> str: + return code + + with pytest.raises(ValueError): + _ = CustomFormatter({}) + + +def test_load_code_formatter_un_exist_custom_formatter(): + with pytest.raises(ModuleNotFoundError): + load_code_formatter(UN_EXIST_FORMATTER, {}) + + +def test_load_code_formatter_invalid_formatter_name(): + with pytest.raises(NameError): + load_code_formatter(WRONG_FORMATTER, {}) + + +def test_load_code_formatter_is_not_subclass(): + with pytest.raises(TypeError): + load_code_formatter(NOT_SUBCLASS_FORMATTER, {}) + + +def test_add_license_formatter_without_kwargs(): + formatter = load_code_formatter(ADD_LICENSE_FORMATTER, {}) + formatted_code = formatter.apply('x = 1\ny = 2') + + assert ( + formatted_code + == """# a license +x = 1 +y = 2""" + ) + + +def test_add_license_formatter_with_kwargs(): + formatter = load_code_formatter( + ADD_LICENSE_FORMATTER, + { + 'license_formatter': { + 'license_txt': 'MIT License\n\nCopyright (c) 2023 Blah-blah\n' + } + }, + ) + formatted_code = formatter.apply('x = 1\ny = 2') + + assert ( + formatted_code + == """# MIT License +# +# Copyright (c) 2023 Blah-blah +# +x = 1 +y = 2""" + ) From c26f5d7d18bb55913c510759c7b3cdd6e8fbd690 Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 13:26:36 +0300 Subject: [PATCH 2/8] Add isort formatter --- .../formatter/__init__.py | 5 ++ datamodel_code_generator/formatter/isort.py | 47 +++++++++++++++++++ tests/formatter/test_isort.py | 15 ++++++ 3 files changed, 67 insertions(+) create mode 100644 datamodel_code_generator/formatter/isort.py create mode 100644 tests/formatter/test_isort.py diff --git a/datamodel_code_generator/formatter/__init__.py b/datamodel_code_generator/formatter/__init__.py index e69de29bb..32227bc90 100644 --- a/datamodel_code_generator/formatter/__init__.py +++ b/datamodel_code_generator/formatter/__init__.py @@ -0,0 +1,5 @@ +from .isort import IsortCodeFormatter + +__all__ = [ + 'IsortCodeFormatter', +] diff --git a/datamodel_code_generator/formatter/isort.py b/datamodel_code_generator/formatter/isort.py new file mode 100644 index 000000000..e1b2e2965 --- /dev/null +++ b/datamodel_code_generator/formatter/isort.py @@ -0,0 +1,47 @@ +from pathlib import Path +from typing import Any, ClassVar, Dict + +import isort + +from .base import BaseCodeFormatter + + +class IsortCodeFormatter(BaseCodeFormatter): + formatter_name: ClassVar[str] = 'isort' + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + super().__init__(formatter_kwargs=formatter_kwargs) + + if 'settings_path' not in self.formatter_kwargs: + settings_path = Path().resolve() + else: + settings_path = Path(self.formatter_kwargs['settings_path']) + + self.settings_path: str = str(settings_path) + self.isort_config_kwargs: Dict[str, Any] = {} + + if 'known_third_party' in self.formatter_kwargs: + self.isort_config_kwargs['known_third_party'] = self.formatter_kwargs[ + 'known_third_party' + ] + + if isort.__version__.startswith('4.'): + self.isort_config = None + else: + self.isort_config = isort.Config( + settings_path=self.settings_path, **self.isort_config_kwargs + ) + + if isort.__version__.startswith('4.'): + + def apply(self, code: str) -> str: + return isort.SortImports( + file_contents=code, + settings_path=self.settings_path, + **self.isort_config_kwargs, + ).output + + else: + + def apply(self, code: str) -> str: + return isort.code(code, config=self.isort_config) diff --git a/tests/formatter/test_isort.py b/tests/formatter/test_isort.py new file mode 100644 index 000000000..1222c8954 --- /dev/null +++ b/tests/formatter/test_isort.py @@ -0,0 +1,15 @@ +from datamodel_code_generator.formatter.base import ( + BaseCodeFormatter, + load_code_formatter, +) +from datamodel_code_generator.formatter.isort import IsortCodeFormatter + + +def test_isort_formatter_is_subclass_if_base(): + assert issubclass(IsortCodeFormatter, BaseCodeFormatter) + assert IsortCodeFormatter.formatter_name == 'isort' + assert hasattr(IsortCodeFormatter, 'apply') + + +def test_load_isort_formatter(): + _ = load_code_formatter('datamodel_code_generator.formatter.IsortCodeFormatter', {}) From 09188e3b432f2682247f3a295b63dfcf4e9f8aa7 Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 13:39:57 +0300 Subject: [PATCH 3/8] Add black formatter --- .../formatter/__init__.py | 2 + datamodel_code_generator/formatter/black.py | 108 ++++++++++++++++++ tests/formatter/test_black.py | 15 +++ 3 files changed, 125 insertions(+) create mode 100644 datamodel_code_generator/formatter/black.py create mode 100644 tests/formatter/test_black.py diff --git a/datamodel_code_generator/formatter/__init__.py b/datamodel_code_generator/formatter/__init__.py index 32227bc90..b77445b58 100644 --- a/datamodel_code_generator/formatter/__init__.py +++ b/datamodel_code_generator/formatter/__init__.py @@ -1,5 +1,7 @@ +from .black import BlackCodeFormatter from .isort import IsortCodeFormatter __all__ = [ 'IsortCodeFormatter', + 'BlackCodeFormatter', ] diff --git a/datamodel_code_generator/formatter/black.py b/datamodel_code_generator/formatter/black.py new file mode 100644 index 000000000..9e367e96b --- /dev/null +++ b/datamodel_code_generator/formatter/black.py @@ -0,0 +1,108 @@ +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Dict + +import black + +from datamodel_code_generator.util import cached_property + +from .base import BaseCodeFormatter + + +class PythonVersion(Enum): + PY_36 = '3.6' + PY_37 = '3.7' + PY_38 = '3.8' + PY_39 = '3.9' + PY_310 = '3.10' + PY_311 = '3.11' + PY_312 = '3.12' + + @cached_property + def _is_py_38_or_later(self) -> bool: # pragma: no cover + return self.value not in {self.PY_36.value, self.PY_37.value} # type: ignore + + @cached_property + def _is_py_39_or_later(self) -> bool: # pragma: no cover + return self.value not in {self.PY_36.value, self.PY_37.value, self.PY_38.value} # type: ignore + + @cached_property + def _is_py_310_or_later(self) -> bool: # pragma: no cover + return self.value not in { + self.PY_36.value, + self.PY_37.value, + self.PY_38.value, + self.PY_39.value, + } # type: ignore + + @cached_property + def _is_py_311_or_later(self) -> bool: # pragma: no cover + return self.value not in { + self.PY_36.value, + self.PY_37.value, + self.PY_38.value, + self.PY_39.value, + self.PY_310.value, + } # type: ignore + + @property + def has_literal_type(self) -> bool: + return self._is_py_38_or_later + + @property + def has_union_operator(self) -> bool: # pragma: no cover + return self._is_py_310_or_later + + @property + def has_annotated_type(self) -> bool: + return self._is_py_39_or_later + + @property + def has_typed_dict(self) -> bool: + return self._is_py_38_or_later + + @property + def has_typed_dict_non_required(self) -> bool: + return self._is_py_311_or_later + + +if TYPE_CHECKING: + + class _TargetVersion(Enum): + ... + + BLACK_PYTHON_VERSION: Dict[PythonVersion, _TargetVersion] +else: + BLACK_PYTHON_VERSION: Dict[PythonVersion, black.TargetVersion] = { + v: getattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}') + for v in PythonVersion + if hasattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}') + } + + +class BlackCodeFormatter(BaseCodeFormatter): + formatter_name: ClassVar[str] = 'black' + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + super().__init__(formatter_kwargs=formatter_kwargs) + + if TYPE_CHECKING: + self.black_mode: black.FileMode + else: + self.black_mode = black.FileMode( + target_versions={ + BLACK_PYTHON_VERSION[formatter_kwargs.get('target-version', '3.7')] + }, + line_length=formatter_kwargs.get( + 'line-length', black.DEFAULT_LINE_LENGTH + ), + string_normalization=not formatter_kwargs.get( + 'skip-string-normalization', True + ), + **formatter_kwargs, + ) + + def apply(self, code: str) -> str: + return black.format_str( + code, + mode=self.black_mode, + ) diff --git a/tests/formatter/test_black.py b/tests/formatter/test_black.py new file mode 100644 index 000000000..8b669f9f4 --- /dev/null +++ b/tests/formatter/test_black.py @@ -0,0 +1,15 @@ +from datamodel_code_generator.formatter.base import ( + BaseCodeFormatter, + load_code_formatter, +) +from datamodel_code_generator.formatter.black import BlackCodeFormatter + + +def test_black_formatter_is_subclass_if_base(): + assert issubclass(BlackCodeFormatter, BaseCodeFormatter) + assert BlackCodeFormatter.formatter_name == 'black' + assert hasattr(BlackCodeFormatter, 'apply') + + +def test_load_black_formatter(): + _ = load_code_formatter('datamodel_code_generator.formatter.BlackCodeFormatter', {}) From 5f0e09baf3a4af4c377ebd0ba729ee87de50af0b Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 13:58:17 +0300 Subject: [PATCH 4/8] Fix test for black --- datamodel_code_generator/formatter/black.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datamodel_code_generator/formatter/black.py b/datamodel_code_generator/formatter/black.py index 9e367e96b..fed2415c9 100644 --- a/datamodel_code_generator/formatter/black.py +++ b/datamodel_code_generator/formatter/black.py @@ -90,7 +90,7 @@ def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: else: self.black_mode = black.FileMode( target_versions={ - BLACK_PYTHON_VERSION[formatter_kwargs.get('target-version', '3.7')] + BLACK_PYTHON_VERSION[formatter_kwargs.get('target-version', PythonVersion.PY_37)] }, line_length=formatter_kwargs.get( 'line-length', black.DEFAULT_LINE_LENGTH From e467f8b4fac55313a9db1d013522f4da165140f7 Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 14:06:31 +0300 Subject: [PATCH 5/8] Run poetry run scripts/format.sh --- datamodel_code_generator/formatter/black.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datamodel_code_generator/formatter/black.py b/datamodel_code_generator/formatter/black.py index fed2415c9..3531e0207 100644 --- a/datamodel_code_generator/formatter/black.py +++ b/datamodel_code_generator/formatter/black.py @@ -90,7 +90,9 @@ def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: else: self.black_mode = black.FileMode( target_versions={ - BLACK_PYTHON_VERSION[formatter_kwargs.get('target-version', PythonVersion.PY_37)] + BLACK_PYTHON_VERSION[ + formatter_kwargs.get('target-version', PythonVersion.PY_37) + ] }, line_length=formatter_kwargs.get( 'line-length', black.DEFAULT_LINE_LENGTH From ce02e72602ed2ec7ddc17d1b44d2e943e5f38538 Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 23:22:12 +0300 Subject: [PATCH 6/8] Add empty ruff formatter --- datamodel_code_generator/formatter/__init__.py | 2 ++ datamodel_code_generator/formatter/ruff.py | 13 +++++++++++++ tests/formatter/test_ruff.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 datamodel_code_generator/formatter/ruff.py create mode 100644 tests/formatter/test_ruff.py diff --git a/datamodel_code_generator/formatter/__init__.py b/datamodel_code_generator/formatter/__init__.py index b77445b58..097785dfc 100644 --- a/datamodel_code_generator/formatter/__init__.py +++ b/datamodel_code_generator/formatter/__init__.py @@ -1,7 +1,9 @@ from .black import BlackCodeFormatter from .isort import IsortCodeFormatter +from .ruff import RuffCodeFormatter __all__ = [ 'IsortCodeFormatter', 'BlackCodeFormatter', + 'RuffCodeFormatter', ] diff --git a/datamodel_code_generator/formatter/ruff.py b/datamodel_code_generator/formatter/ruff.py new file mode 100644 index 000000000..af8941a04 --- /dev/null +++ b/datamodel_code_generator/formatter/ruff.py @@ -0,0 +1,13 @@ +from typing import Any, ClassVar, Dict + +from .base import BaseCodeFormatter + + +class RuffCodeFormatter(BaseCodeFormatter): + formatter_name: ClassVar[str] = 'ruff' + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + super().__init__(formatter_kwargs=formatter_kwargs) + + def apply(self, code: str) -> str: + pass diff --git a/tests/formatter/test_ruff.py b/tests/formatter/test_ruff.py new file mode 100644 index 000000000..fb71ddc33 --- /dev/null +++ b/tests/formatter/test_ruff.py @@ -0,0 +1,15 @@ +from datamodel_code_generator.formatter.base import ( + BaseCodeFormatter, + load_code_formatter, +) +from datamodel_code_generator.formatter.ruff import RuffCodeFormatter + + +def test_ruff_formatter_is_subclass_if_base(): + assert issubclass(RuffCodeFormatter, BaseCodeFormatter) + assert RuffCodeFormatter.formatter_name == 'ruff' + assert hasattr(RuffCodeFormatter, 'apply') + + +def test_load_ruff_formatter(): + _ = load_code_formatter('datamodel_code_generator.formatter.RuffCodeFormatter', {}) From d8067f6ab7d97396a6c06798678916b0ddde0d3b Mon Sep 17 00:00:00 2001 From: denisart Date: Mon, 27 Nov 2023 23:22:42 +0300 Subject: [PATCH 7/8] Added runner for code formatters --- datamodel_code_generator/formatter/base.py | 95 +++++++++++++- datamodel_code_generator/formatter/black.py | 134 ++++++++------------ tests/formatter/test_base.py | 85 +++++++++++++ 3 files changed, 234 insertions(+), 80 deletions(-) diff --git a/datamodel_code_generator/formatter/base.py b/datamodel_code_generator/formatter/base.py index 58e7784dd..389e0fcb2 100644 --- a/datamodel_code_generator/formatter/base.py +++ b/datamodel_code_generator/formatter/base.py @@ -1,5 +1,6 @@ from importlib import import_module -from typing import Any, ClassVar, Dict +from pathlib import Path +from typing import Any, ClassVar, Dict, List, Optional from datamodel_code_generator.imports import Import @@ -89,3 +90,95 @@ def load_code_formatter( ) return formatter_class(formatter_kwargs=custom_formatter_kwargs) + + +class CodeFormattersRunner: + """Runner of code formatters.""" + + disable_default_formatter: bool + default_formatters: List[BaseCodeFormatter] + custom_formatters: List[BaseCodeFormatter] + custom_formatters_kwargs: Dict[str, Any] + + _mapping_from_formatter_name_to_formatter_module: Dict[str, str] = { + 'black': 'datamodel_code_generator.formatter.BlackCodeFormatter', + 'isort': 'datamodel_code_generator.formatter.IsortCodeFormatter', + 'ruff': 'datamodel_code_generator.formatter.RuffCodeFormatter', + } + _default_formatters: List[str] = [ + 'datamodel_code_generator.formatter.RuffCodeFormatter' + ] + + def __init__( + self, + disable_default_formatter: bool = False, + default_formatter: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, + settings_path: Optional[Path] = None, + wrap_string_literal: Optional[bool] = None, + skip_string_normalization: bool = True, + known_third_party: Optional[List[str]] = None, + ) -> None: + self.disable_default_formatter = disable_default_formatter + self.custom_formatters_kwargs = custom_formatters_kwargs or {} + + self.default_formatters = self._check_default_formatters(default_formatter) + self.custom_formatters = self._check_custom_formatters(custom_formatters) + + self.custom_formatters_kwargs['black'] = { + 'settings_path': settings_path, + 'wrap_string_literal': wrap_string_literal, + 'skip_string_normalization': skip_string_normalization, + } + self.custom_formatters_kwargs['isort'] = { + 'settings_path': settings_path, + 'known_third_party': known_third_party, + } + + def _load_formatters(self, formatters: List[str]) -> List[BaseCodeFormatter]: + return [ + load_code_formatter(custom_formatter_import, self.custom_formatters_kwargs) + for custom_formatter_import in formatters + ] + + def _check_default_formatters( + self, + default_formatters: Optional[List[str]], + ) -> List[BaseCodeFormatter]: + if self.disable_default_formatter is True: + return [] + + if default_formatters is None: + return self._load_formatters(self._default_formatters) + + formatters = [] + for formatter in default_formatters: + if formatter not in self._mapping_from_formatter_name_to_formatter_module: + raise ValueError(f'Unknown default formatter: {formatter}') + + formatters.append( + self._mapping_from_formatter_name_to_formatter_module[formatter] + ) + + return self._load_formatters(formatters) + + def _check_custom_formatters( + self, custom_formatters: Optional[List[str]] + ) -> List[BaseCodeFormatter]: + if custom_formatters is None: + return [] + + return self._load_formatters(custom_formatters) + + def format_code( + self, + code: str, + ) -> str: + for formatter in self.default_formatters: + code = formatter.apply(code) + + for formatter in self.custom_formatters: + code = formatter.apply(code) + + return code diff --git a/datamodel_code_generator/formatter/black.py b/datamodel_code_generator/formatter/black.py index 3531e0207..90d46ed9d 100644 --- a/datamodel_code_generator/formatter/black.py +++ b/datamodel_code_generator/formatter/black.py @@ -1,90 +1,56 @@ -from enum import Enum +from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Dict +from warnings import warn import black -from datamodel_code_generator.util import cached_property +from datamodel_code_generator.format import ( + BLACK_PYTHON_VERSION, + PythonVersion, + black_find_project_root, +) +from datamodel_code_generator.util import load_toml from .base import BaseCodeFormatter -class PythonVersion(Enum): - PY_36 = '3.6' - PY_37 = '3.7' - PY_38 = '3.8' - PY_39 = '3.9' - PY_310 = '3.10' - PY_311 = '3.11' - PY_312 = '3.12' - - @cached_property - def _is_py_38_or_later(self) -> bool: # pragma: no cover - return self.value not in {self.PY_36.value, self.PY_37.value} # type: ignore - - @cached_property - def _is_py_39_or_later(self) -> bool: # pragma: no cover - return self.value not in {self.PY_36.value, self.PY_37.value, self.PY_38.value} # type: ignore - - @cached_property - def _is_py_310_or_later(self) -> bool: # pragma: no cover - return self.value not in { - self.PY_36.value, - self.PY_37.value, - self.PY_38.value, - self.PY_39.value, - } # type: ignore - - @cached_property - def _is_py_311_or_later(self) -> bool: # pragma: no cover - return self.value not in { - self.PY_36.value, - self.PY_37.value, - self.PY_38.value, - self.PY_39.value, - self.PY_310.value, - } # type: ignore - - @property - def has_literal_type(self) -> bool: - return self._is_py_38_or_later - - @property - def has_union_operator(self) -> bool: # pragma: no cover - return self._is_py_310_or_later - - @property - def has_annotated_type(self) -> bool: - return self._is_py_39_or_later - - @property - def has_typed_dict(self) -> bool: - return self._is_py_38_or_later - - @property - def has_typed_dict_non_required(self) -> bool: - return self._is_py_311_or_later - - -if TYPE_CHECKING: - - class _TargetVersion(Enum): - ... - - BLACK_PYTHON_VERSION: Dict[PythonVersion, _TargetVersion] -else: - BLACK_PYTHON_VERSION: Dict[PythonVersion, black.TargetVersion] = { - v: getattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}') - for v in PythonVersion - if hasattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}') - } - - class BlackCodeFormatter(BaseCodeFormatter): formatter_name: ClassVar[str] = 'black' def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: super().__init__(formatter_kwargs=formatter_kwargs) + if 'settings_path' not in self.formatter_kwargs: + settings_path = Path().resolve() + else: + settings_path = Path(self.formatter_kwargs['settings_path']) + + wrap_string_literal = self.formatter_kwargs.get('wrap_string_literal', None) + skip_string_normalization = self.formatter_kwargs.get( + 'skip_string_normalization', True + ) + + config = self._load_config(settings_path) + + black_kwargs: Dict[str, Any] = {} + if wrap_string_literal is not None: + experimental_string_processing = wrap_string_literal + else: + experimental_string_processing = config.get( + 'experimental-string-processing' + ) + + if experimental_string_processing is not None: # pragma: no cover + if black.__version__.startswith('19.'): # type: ignore + warn( + f"black doesn't support `experimental-string-processing` option" # type: ignore + f' for wrapping string literal in {black.__version__}' + ) + else: + black_kwargs[ + 'experimental_string_processing' + ] = experimental_string_processing + if TYPE_CHECKING: self.black_mode: black.FileMode else: @@ -94,15 +60,25 @@ def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: formatter_kwargs.get('target-version', PythonVersion.PY_37) ] }, - line_length=formatter_kwargs.get( - 'line-length', black.DEFAULT_LINE_LENGTH - ), - string_normalization=not formatter_kwargs.get( - 'skip-string-normalization', True - ), + line_length=config.get('line-length', black.DEFAULT_LINE_LENGTH), + string_normalization=not skip_string_normalization + or not config.get('skip-string-normalization', True), **formatter_kwargs, ) + @staticmethod + def _load_config(settings_path: Path) -> Dict[str, Any]: + root = black_find_project_root((settings_path,)) + path = root / 'pyproject.toml' + + if path.is_file(): + pyproject_toml = load_toml(path) + config = pyproject_toml.get('tool', {}).get('black', {}) + else: + config = {} + + return config + def apply(self, code: str) -> str: return black.format_str( code, diff --git a/tests/formatter/test_base.py b/tests/formatter/test_base.py index 05167481d..ef04170ff 100644 --- a/tests/formatter/test_base.py +++ b/tests/formatter/test_base.py @@ -4,6 +4,7 @@ from datamodel_code_generator.formatter.base import ( BaseCodeFormatter, + CodeFormattersRunner, load_code_formatter, ) @@ -82,3 +83,87 @@ def test_add_license_formatter_with_kwargs(): x = 1 y = 2""" ) + + +def test_runner_quick_init(): + runner = CodeFormattersRunner() + + assert runner.disable_default_formatter is False + assert runner.custom_formatters_kwargs == { + 'black': { + 'settings_path': None, + 'wrap_string_literal': None, + 'skip_string_normalization': True, + }, + 'isort': { + 'settings_path': None, + 'known_third_party': None, + }, + } + + assert len(runner.default_formatters) == 1 + assert runner.default_formatters[0].formatter_name == 'ruff' + + assert runner.custom_formatters == [] + + +def test_runner_set_default_formatters(): + runner = CodeFormattersRunner( + default_formatter=['black', 'isort'], + ) + + assert runner.disable_default_formatter is False + + assert len(runner.default_formatters) == 2 + assert runner.default_formatters[0].formatter_name == 'black' + assert runner.default_formatters[1].formatter_name == 'isort' + + assert runner.custom_formatters == [] + + +def test_runner_set_default_formatters_disable(): + runner = CodeFormattersRunner( + default_formatter=['black', 'isort'], + disable_default_formatter=True, + ) + + assert runner.disable_default_formatter is True + + assert len(runner.default_formatters) == 0 + assert runner.custom_formatters == [] + + +def test_runner_custom_formatters(): + runner = CodeFormattersRunner(custom_formatters=[ADD_LICENSE_FORMATTER]) + + assert len(runner.custom_formatters) == 1 + assert runner.custom_formatters[0].formatter_name == 'license_formatter' + + +def test_runner_custom_formatters_kwargs(): + runner = CodeFormattersRunner( + custom_formatters=[ADD_LICENSE_FORMATTER], + custom_formatters_kwargs={ + 'license_formatter': { + 'license_txt': 'MIT License\n\nCopyright (c) 2023 Blah-blah\n' + } + }, + ) + + assert len(runner.custom_formatters) == 1 + assert runner.custom_formatters[0].formatter_name == 'license_formatter' + + assert runner.custom_formatters_kwargs == { + 'black': { + 'settings_path': None, + 'wrap_string_literal': None, + 'skip_string_normalization': True, + }, + 'isort': { + 'settings_path': None, + 'known_third_party': None, + }, + 'license_formatter': { + 'license_txt': 'MIT License\n\nCopyright (c) 2023 Blah-blah\n' + }, + } From c33815c8a124b5df73c9f75444894aa508d5954c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:42:08 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- datamodel_code_generator/formatter/black.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datamodel_code_generator/formatter/black.py b/datamodel_code_generator/formatter/black.py index 90d46ed9d..681cc2032 100644 --- a/datamodel_code_generator/formatter/black.py +++ b/datamodel_code_generator/formatter/black.py @@ -47,9 +47,9 @@ def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: f' for wrapping string literal in {black.__version__}' ) else: - black_kwargs[ - 'experimental_string_processing' - ] = experimental_string_processing + black_kwargs['experimental_string_processing'] = ( + experimental_string_processing + ) if TYPE_CHECKING: self.black_mode: black.FileMode