From d030ce27a197e0a3e819b311dca5c5421d1cf5ba Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 Apr 2024 00:04:10 -0500 Subject: [PATCH] [TVMScript] Optionally use `ruff format` instead of `black` (#16876) * [TVMScript] Optionally use `ruff format` instead of `black` The `ruff format` tool is significantly faster than the `black` formatter. For some particularly long TVMScript modules, using it can reduce the time required to show a formatted module from ~5 minutes to ~1 minute. This commit updates the `.show()` function to apply the optionally formatting using `ruff format` if available, falling back to `black` otherwise. * Fix lint error --- python/tvm/script/highlight.py | 95 +++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 18 deletions(-) diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py index be0de5a6bf2b..e017c1e6cab2 100644 --- a/python/tvm/script/highlight.py +++ b/python/tvm/script/highlight.py @@ -17,7 +17,10 @@ """Highlight printed TVM script. """ +import functools import os +import shutil +import subprocess import sys import warnings from typing import Any, Optional, Union @@ -92,7 +95,73 @@ def cprint( print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style))) -def _format(code_str: str) -> str: +@functools.lru_cache +def _get_formatter(formatter: Optional[str] = None): + def get_ruff_formatter(): + if shutil.which("ruff") is None: + return None + + def formatter(code_str): + proc = subprocess.Popen( + ["ruff", "format", "--stdin-filename=TVMScript"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding="utf-8", + ) + stdout, _stderr = proc.communicate(code_str) + return stdout + + return formatter + + def get_black_formatter(): + try: + # pylint: disable=import-outside-toplevel + import black + except ImportError: + return None + + def formatter(code_str): + return black.format_str(code_str, mode=black.FileMode()) + + return formatter + + def get_fallback_formatter(): + def formatter(code_str): + with warnings.catch_warnings(): + warnings.simplefilter("once", UserWarning) + ruff_install_cmd = sys.executable + " -m pip install ruff" + black_install_cmd = ( + sys.executable + ' -m pip install "black==22.3.0" --upgrade --user' + ) + warnings.warn( + f"Neither the 'ruff' formatter nor the 'black' formatter is available. " + f"To print formatted TVM script, please a formatter. \n" + f"To install ruff: {ruff_install_cmd}\n" + f"To install black: {black_install_cmd}", + category=UserWarning, + ) + return code_str + + return formatter + + # formatter = "black" + if formatter is None: + options = [get_ruff_formatter, get_black_formatter] + elif formatter == "ruff": + options = [get_ruff_formatter] + elif formatter == "black": + options = [get_black_formatter] + else: + raise ValueError(f"Unknown formatter: {formatter}") + + for option in options: + func = option() + if func is not None: + return func + return get_fallback_formatter() + + +def _format(code_str: str, formatter: Optional[str] = None) -> str: """Format a code string using Black. Parameters @@ -101,29 +170,19 @@ def _format(code_str: str) -> str: The string containing Python/TVMScript code to format + formatter: Optional[str] + + The formatter to use. Can specify `ruff`, `black`, or + auto-select by passing `None`. + Returns ------- formatted: str The formatted Python/TVMScript code + """ - try: - # pylint: disable=import-outside-toplevel - import black - except ImportError as err: - with warnings.catch_warnings(): - warnings.simplefilter("once", UserWarning) - install_cmd = sys.executable + ' -m pip install "black==22.3.0" --upgrade --user' - warnings.warn( - str(err) - + "\n" - + "To print formatted TVM script, please install the formatter 'Black':\n" - + install_cmd, - category=UserWarning, - ) - return code_str - else: - return black.format_str(code_str, mode=black.FileMode()) + return _get_formatter(formatter)(code_str) def _get_pygments_style(