-
-
Notifications
You must be signed in to change notification settings - Fork 304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add formatter module #1746
base: main
Are you sure you want to change the base?
Add formatter module #1746
Changes from 8 commits
4359885
c26f5d7
09188e3
5f0e09b
e467f8b
ce02e72
d8067f6
2894f63
c317d71
c33815c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from .black import BlackCodeFormatter | ||
from .isort import IsortCodeFormatter | ||
from .ruff import RuffCodeFormatter | ||
|
||
__all__ = [ | ||
'IsortCodeFormatter', | ||
'BlackCodeFormatter', | ||
'RuffCodeFormatter', | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from importlib import import_module | ||
from pathlib import Path | ||
from typing import Any, ClassVar, Dict, List, Optional | ||
|
||
from datamodel_code_generator.imports import Import | ||
|
||
|
||
class BaseCodeFormatter: | ||
"""An abstract class for representing a code formatter. | ||
|
||
All formatters that format a generated code should subclass | ||
it. All subclass should override `apply` method which | ||
has a string with code in input and returns a formatted code in string. | ||
We also need to determine a `formatter_name` field | ||
which is unique name of formatter. | ||
|
||
Example: | ||
>>> class CustomHeaderCodeFormatter(BaseCodeFormatter): | ||
... formatter_name: ClassVar[str] = "custom" | ||
... def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
... super().__init__(formatter_kwargs=formatter_kwargs) | ||
... | ||
... default_header = "my header" | ||
... self.header: str = self.formatter_kwargs.get("header", default_header) | ||
... def apply(self, code: str) -> str: | ||
... return f'# {self.header}\\n{code}' | ||
... | ||
... formatter_kwargs = {"header": "formatted with CustomHeaderCodeFormatter"} | ||
... formatter = CustomHeaderCodeFormatter(formatter_kwargs) | ||
... code = '''x = 1\ny = 2''' | ||
... print(formatter.apply(code)) | ||
# formatted with CustomHeaderCodeFormatter | ||
x = 1 | ||
y = 2 | ||
|
||
""" | ||
|
||
formatter_name: ClassVar[str] = '' | ||
|
||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
if self.formatter_name == '': | ||
raise ValueError('`formatter_name` should be not empty string') | ||
|
||
self.formatter_kwargs = formatter_kwargs | ||
|
||
def apply(self, code: str) -> str: | ||
raise NotImplementedError | ||
|
||
|
||
def load_code_formatter( | ||
custom_formatter_import: str, custom_formatters_kwargs: Dict[str, Any] | ||
) -> BaseCodeFormatter: | ||
"""Load a formatter by import path as string. | ||
|
||
Args: | ||
custom_formatter_import: custom formatter module. | ||
custom_formatters_kwargs: kwargs for custom formatters from config. | ||
|
||
Examples: | ||
for default formatters use | ||
>>> custom_formatter_import = "datamodel_code_generator.formatter.BlackCodeFormatter" | ||
this is equivalent to code | ||
>>> from datamodel_code_generator.formatter import BlackCodeFormatter | ||
|
||
custom formatter | ||
>>> custom_formatter_import = "my_package.my_sub_package.FormatterName" | ||
this is equivalent to code | ||
>>> from my_package.my_sub_package import FormatterName | ||
|
||
""" | ||
|
||
import_ = Import.from_full_path(custom_formatter_import) | ||
imported_module_ = import_module(import_.from_) | ||
|
||
if not hasattr(imported_module_, import_.import_): | ||
raise NameError( | ||
f'Custom formatter module `{import_.from_}` not contains formatter with name `{import_.import_}`' | ||
) | ||
|
||
formatter_class = imported_module_.__getattribute__(import_.import_) | ||
|
||
if not issubclass(formatter_class, BaseCodeFormatter): | ||
raise TypeError( | ||
f'The custom module `{custom_formatter_import}` must inherit from ' | ||
'`datamodel-code-generator.formatter.BaseCodeFormatter`' | ||
) | ||
|
||
custom_formatter_kwargs = custom_formatters_kwargs.get( | ||
formatter_class.formatter_name, {} | ||
) | ||
|
||
return formatter_class(formatter_kwargs=custom_formatter_kwargs) | ||
|
||
|
||
class CodeFormattersRunner: | ||
"""Runner of code formatters.""" | ||
|
||
disable_default_formatter: bool | ||
default_formatters: List[BaseCodeFormatter] | ||
custom_formatters: List[BaseCodeFormatter] | ||
custom_formatters_kwargs: Dict[str, Any] | ||
|
||
_mapping_from_formatter_name_to_formatter_module: Dict[str, str] = { | ||
'black': 'datamodel_code_generator.formatter.BlackCodeFormatter', | ||
'isort': 'datamodel_code_generator.formatter.IsortCodeFormatter', | ||
'ruff': 'datamodel_code_generator.formatter.RuffCodeFormatter', | ||
} | ||
_default_formatters: List[str] = [ | ||
'datamodel_code_generator.formatter.RuffCodeFormatter' | ||
] | ||
|
||
def __init__( | ||
self, | ||
disable_default_formatter: bool = False, | ||
default_formatter: Optional[List[str]] = None, | ||
custom_formatters: Optional[List[str]] = None, | ||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None, | ||
settings_path: Optional[Path] = None, | ||
wrap_string_literal: Optional[bool] = None, | ||
skip_string_normalization: bool = True, | ||
known_third_party: Optional[List[str]] = None, | ||
) -> None: | ||
self.disable_default_formatter = disable_default_formatter | ||
self.custom_formatters_kwargs = custom_formatters_kwargs or {} | ||
|
||
self.default_formatters = self._check_default_formatters(default_formatter) | ||
self.custom_formatters = self._check_custom_formatters(custom_formatters) | ||
|
||
self.custom_formatters_kwargs['black'] = { | ||
'settings_path': settings_path, | ||
'wrap_string_literal': wrap_string_literal, | ||
'skip_string_normalization': skip_string_normalization, | ||
} | ||
self.custom_formatters_kwargs['isort'] = { | ||
'settings_path': settings_path, | ||
'known_third_party': known_third_party, | ||
} | ||
|
||
def _load_formatters(self, formatters: List[str]) -> List[BaseCodeFormatter]: | ||
return [ | ||
load_code_formatter(custom_formatter_import, self.custom_formatters_kwargs) | ||
for custom_formatter_import in formatters | ||
] | ||
|
||
def _check_default_formatters( | ||
self, | ||
default_formatters: Optional[List[str]], | ||
) -> List[BaseCodeFormatter]: | ||
if self.disable_default_formatter is True: | ||
return [] | ||
|
||
if default_formatters is None: | ||
return self._load_formatters(self._default_formatters) | ||
Comment on lines
+152
to
+153
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does the method expect the path string? If we want to pass the custom formatted class from CLI, how about loading the custom formatted class from str only for that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems I didn't fully understand your proposal. Can you explain or give an example? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @denisart class CodeFormattersRunner:
...
def __init__(
self,
disable_default_formatter: bool = False,
default_formatter: Optional[List[BaseCodeFormatter]] = None,
custom_formatters: Optional[List[BaseCodeFormatter]] = None, if we give the external formatter to custom_comatter = load_code_formatter("my_package.my_sub_package.FormatterName")
runner = CodeFormattersRunner(custom_formatters=[custom_comatter]) I don't think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @koxudaxi , Do you propose loading formatters by path in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes!!
Yes, it is. Thank you for saying what I was trying to say. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I agree with you. I'll create a fix. |
||
|
||
formatters = [] | ||
for formatter in default_formatters: | ||
if formatter not in self._mapping_from_formatter_name_to_formatter_module: | ||
raise ValueError(f'Unknown default formatter: {formatter}') | ||
|
||
formatters.append( | ||
self._mapping_from_formatter_name_to_formatter_module[formatter] | ||
) | ||
|
||
return self._load_formatters(formatters) | ||
|
||
def _check_custom_formatters( | ||
self, custom_formatters: Optional[List[str]] | ||
) -> List[BaseCodeFormatter]: | ||
if custom_formatters is None: | ||
return [] | ||
|
||
return self._load_formatters(custom_formatters) | ||
|
||
def format_code( | ||
self, | ||
code: str, | ||
) -> str: | ||
for formatter in self.default_formatters: | ||
code = formatter.apply(code) | ||
|
||
for formatter in self.custom_formatters: | ||
code = formatter.apply(code) | ||
|
||
return code |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Any, ClassVar, Dict | ||
from warnings import warn | ||
|
||
import black | ||
|
||
from datamodel_code_generator.format import ( | ||
BLACK_PYTHON_VERSION, | ||
PythonVersion, | ||
black_find_project_root, | ||
) | ||
from datamodel_code_generator.util import load_toml | ||
|
||
from .base import BaseCodeFormatter | ||
|
||
|
||
class BlackCodeFormatter(BaseCodeFormatter): | ||
formatter_name: ClassVar[str] = 'black' | ||
|
||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
super().__init__(formatter_kwargs=formatter_kwargs) | ||
|
||
if 'settings_path' not in self.formatter_kwargs: | ||
settings_path = Path().resolve() | ||
else: | ||
settings_path = Path(self.formatter_kwargs['settings_path']) | ||
|
||
wrap_string_literal = self.formatter_kwargs.get('wrap_string_literal', None) | ||
skip_string_normalization = self.formatter_kwargs.get( | ||
'skip_string_normalization', True | ||
) | ||
|
||
config = self._load_config(settings_path) | ||
|
||
black_kwargs: Dict[str, Any] = {} | ||
if wrap_string_literal is not None: | ||
experimental_string_processing = wrap_string_literal | ||
else: | ||
experimental_string_processing = config.get( | ||
'experimental-string-processing' | ||
) | ||
|
||
if experimental_string_processing is not None: # pragma: no cover | ||
if black.__version__.startswith('19.'): # type: ignore | ||
warn( | ||
f"black doesn't support `experimental-string-processing` option" # type: ignore | ||
f' for wrapping string literal in {black.__version__}' | ||
) | ||
else: | ||
black_kwargs[ | ||
'experimental_string_processing' | ||
] = experimental_string_processing | ||
|
||
if TYPE_CHECKING: | ||
self.black_mode: black.FileMode | ||
else: | ||
self.black_mode = black.FileMode( | ||
target_versions={ | ||
BLACK_PYTHON_VERSION[ | ||
formatter_kwargs.get('target-version', PythonVersion.PY_37) | ||
] | ||
}, | ||
line_length=config.get('line-length', black.DEFAULT_LINE_LENGTH), | ||
string_normalization=not skip_string_normalization | ||
or not config.get('skip-string-normalization', True), | ||
**formatter_kwargs, | ||
) | ||
|
||
@staticmethod | ||
def _load_config(settings_path: Path) -> Dict[str, Any]: | ||
root = black_find_project_root((settings_path,)) | ||
path = root / 'pyproject.toml' | ||
|
||
if path.is_file(): | ||
pyproject_toml = load_toml(path) | ||
config = pyproject_toml.get('tool', {}).get('black', {}) | ||
else: | ||
config = {} | ||
|
||
return config | ||
|
||
def apply(self, code: str) -> str: | ||
return black.format_str( | ||
code, | ||
mode=self.black_mode, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from pathlib import Path | ||
from typing import Any, ClassVar, Dict | ||
|
||
import isort | ||
|
||
from .base import BaseCodeFormatter | ||
|
||
|
||
class IsortCodeFormatter(BaseCodeFormatter): | ||
formatter_name: ClassVar[str] = 'isort' | ||
|
||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
super().__init__(formatter_kwargs=formatter_kwargs) | ||
|
||
if 'settings_path' not in self.formatter_kwargs: | ||
settings_path = Path().resolve() | ||
else: | ||
settings_path = Path(self.formatter_kwargs['settings_path']) | ||
|
||
self.settings_path: str = str(settings_path) | ||
self.isort_config_kwargs: Dict[str, Any] = {} | ||
|
||
if 'known_third_party' in self.formatter_kwargs: | ||
self.isort_config_kwargs['known_third_party'] = self.formatter_kwargs[ | ||
'known_third_party' | ||
] | ||
|
||
if isort.__version__.startswith('4.'): | ||
self.isort_config = None | ||
else: | ||
self.isort_config = isort.Config( | ||
settings_path=self.settings_path, **self.isort_config_kwargs | ||
) | ||
|
||
if isort.__version__.startswith('4.'): | ||
|
||
def apply(self, code: str) -> str: | ||
return isort.SortImports( | ||
file_contents=code, | ||
settings_path=self.settings_path, | ||
**self.isort_config_kwargs, | ||
).output | ||
|
||
else: | ||
|
||
def apply(self, code: str) -> str: | ||
return isort.code(code, config=self.isort_config) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import Any, ClassVar, Dict | ||
|
||
from .base import BaseCodeFormatter | ||
|
||
|
||
class RuffCodeFormatter(BaseCodeFormatter): | ||
formatter_name: ClassVar[str] = 'ruff' | ||
|
||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
super().__init__(formatter_kwargs=formatter_kwargs) | ||
|
||
def apply(self, code: str) -> str: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Any, Dict, ClassVar | ||
|
||
from datamodel_code_generator.formatter.base import BaseCodeFormatter | ||
|
||
|
||
class LicenseFormatter(BaseCodeFormatter): | ||
"""Add a license to file from license file path.""" | ||
formatter_name: ClassVar[str] = "license_formatter" | ||
|
||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: | ||
super().__init__(formatter_kwargs) | ||
|
||
license_txt = formatter_kwargs.get('license_txt', "a license") | ||
self.license_header = '\n'.join([f'# {line}' for line in license_txt.split('\n')]) | ||
|
||
def apply(self, code: str) -> str: | ||
return f'{self.license_header}\n{code}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think the uses case 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also had the same reaction. But the
Import
class is very suitable in this case.