From 9697bae7dd95078bbcc59cb10fc1a85ca486b93f Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 20 Nov 2024 02:21:15 +0800 Subject: [PATCH] refactor[ux]: refactor preparser (#4293) this PR refactors the pre-parsing routine to use a new `PreParser` object. this will make it easier in the future to keep track of state during pre-parsing. --------- Co-authored-by: Charles Cooper --- tests/functional/grammar/test_grammar.py | 7 +- .../ast/test_annotate_and_optimize_ast.py | 11 +- tests/unit/ast/test_pre_parser.py | 14 ++- vyper/ast/parse.py | 39 +++--- vyper/ast/pre_parser.py | 112 +++++++----------- 5 files changed, 84 insertions(+), 99 deletions(-) diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 0ff8c23477..871ba4547f 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -9,7 +9,7 @@ from vyper.ast import Module, parse_to_ast from vyper.ast.grammar import parse_vyper_source, vyper_grammar -from vyper.ast.pre_parser import pre_parse +from vyper.ast.pre_parser import PreParser def test_basic_grammar(): @@ -102,6 +102,7 @@ def has_no_docstrings(c): max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much] ) def test_grammar_bruteforce(code): - pre_parse_result = pre_parse(code + "\n") - tree = parse_to_ast(pre_parse_result.reformatted_code) + pre_parser = PreParser() + pre_parser.parse(code + "\n") + tree = parse_to_ast(pre_parser.reformatted_code) assert isinstance(tree, Module) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 39ea899bd9..afba043113 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,6 +1,6 @@ import ast as python_ast -from vyper.ast.parse import annotate_python_ast, pre_parse +from vyper.ast.parse import PreParser, annotate_python_ast class AssertionVisitor(python_ast.NodeVisitor): @@ -28,12 +28,13 @@ def foo() -> int128: def get_contract_info(source_code): - pre_parse_result = pre_parse(source_code) - py_ast = python_ast.parse(pre_parse_result.reformatted_code) + pre_parser = PreParser() + pre_parser.parse(source_code) + py_ast = python_ast.parse(pre_parser.reformatted_code) - annotate_python_ast(py_ast, pre_parse_result.reformatted_code, pre_parse_result) + annotate_python_ast(py_ast, pre_parser.reformatted_code, pre_parser) - return py_ast, pre_parse_result.reformatted_code + return py_ast, pre_parser.reformatted_code def test_it_annotates_ast_with_source_code(): diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 5d3f30481c..73712aadb8 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.ast.pre_parser import PreParser, validate_version_pragma from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException, VersionException @@ -174,9 +174,10 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - pre_parse_result = pre_parse(code) + pre_parser = PreParser() + pre_parser.parse(code) - assert pre_parse_result.settings == pre_parse_settings + assert pre_parser.settings == pre_parse_settings compiler_data = CompilerData(code) @@ -203,8 +204,9 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve @pytest.mark.parametrize("code", pragma_venom) def test_parse_venom_pragma(code): - pre_parse_result = pre_parse(code) - assert pre_parse_result.settings.experimental_codegen is True + pre_parser = PreParser() + pre_parser.parse(code) + assert pre_parser.settings.experimental_codegen is True compiler_data = CompilerData(code) assert compiler_data.settings.experimental_codegen is True @@ -252,7 +254,7 @@ def test_parse_venom_pragma(code): @pytest.mark.parametrize("code", invalid_pragmas) def test_invalid_pragma(code): with pytest.raises(StructureException): - pre_parse(code) + PreParser().parse(code) def test_version_exception_in_import(make_input_bundle): diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 1e88241186..5d62072b9e 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -6,7 +6,7 @@ import asttokens from vyper.ast import nodes as vy_ast -from vyper.ast.pre_parser import PreParseResult, pre_parse +from vyper.ast.pre_parser import PreParser from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic, ParserException, SyntaxException from vyper.utils import sha256sum, vyper_warn @@ -54,9 +54,10 @@ def parse_to_ast_with_settings( """ if "\x00" in vyper_source: raise ParserException("No null bytes (\\x00) allowed in the source code.") - pre_parse_result = pre_parse(vyper_source) + pre_parser = PreParser() + pre_parser.parse(vyper_source) try: - py_ast = python_ast.parse(pre_parse_result.reformatted_code) + py_ast = python_ast.parse(pre_parser.reformatted_code) except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors raise SyntaxException(str(e), vyper_source, e.lineno, e.offset) from None @@ -72,20 +73,20 @@ def parse_to_ast_with_settings( annotate_python_ast( py_ast, vyper_source, - pre_parse_result, + pre_parser, source_id=source_id, module_path=module_path, resolved_path=resolved_path, ) # postcondition: consumed all the for loop annotations - assert len(pre_parse_result.for_loop_annotations) == 0 + assert len(pre_parser.for_loop_annotations) == 0 # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint - return pre_parse_result.settings, module + return pre_parser.settings, module def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: @@ -116,7 +117,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: def annotate_python_ast( parsed_ast: python_ast.AST, vyper_source: str, - pre_parse_result: PreParseResult, + pre_parser: PreParser, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -130,8 +131,8 @@ def annotate_python_ast( The AST to be annotated and optimized. vyper_source: str The original vyper source code - pre_parse_result: PreParseResult - Outputs from pre-parsing. + pre_parser: PreParser + PreParser object. Returns ------- @@ -142,7 +143,7 @@ def annotate_python_ast( tokens.mark_tokens(parsed_ast) visitor = AnnotatingVisitor( vyper_source, - pre_parse_result, + pre_parser, tokens, source_id, module_path=module_path, @@ -155,12 +156,12 @@ def annotate_python_ast( class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str - _pre_parse_result: PreParseResult + _pre_parser: PreParser def __init__( self, source_code: str, - pre_parse_result: PreParseResult, + pre_parser: PreParser, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -171,7 +172,7 @@ def __init__( self._module_path = module_path self._resolved_path = resolved_path self._source_code = source_code - self._pre_parse_result = pre_parse_result + self._pre_parser = pre_parser self.counter: int = 0 @@ -265,7 +266,7 @@ def visit_ClassDef(self, node): """ self.generic_visit(node) - node.ast_type = self._pre_parse_result.modification_offsets[(node.lineno, node.col_offset)] + node.ast_type = self._pre_parser.modification_offsets[(node.lineno, node.col_offset)] return node def visit_For(self, node): @@ -274,7 +275,7 @@ def visit_For(self, node): the pre-parser """ key = (node.lineno, node.col_offset) - annotation_tokens = self._pre_parse_result.for_loop_annotations.pop(key) + annotation_tokens = self._pre_parser.for_loop_annotations.pop(key) if not annotation_tokens: # a common case for people migrating to 0.4.0, provide a more @@ -342,14 +343,14 @@ def visit_Expr(self, node): # CMC 2024-03-03 consider unremoving this from the enclosing Expr node = node.value key = (node.lineno, node.col_offset) - node.ast_type = self._pre_parse_result.modification_offsets[key] + node.ast_type = self._pre_parser.modification_offsets[key] return node def visit_Await(self, node): start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them self.generic_visit(node) - node.ast_type = self._pre_parse_result.modification_offsets[start_pos] + node.ast_type = self._pre_parser.modification_offsets[start_pos] return node def visit_Call(self, node): @@ -394,10 +395,10 @@ def visit_Constant(self, node): node.ast_type = "NameConstant" elif isinstance(node.value, str): key = (node.lineno, node.col_offset) - if key in self._pre_parse_result.native_hex_literal_locations: + if key in self._pre_parser.hex_string_locations: if len(node.value) % 2 != 0: raise SyntaxException( - "Native hex string must have an even number of characters", + "Hex string must have an even number of characters", self._source_code, node.lineno, node.col_offset, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 5d2abcf645..dbeb6181f9 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -158,67 +158,52 @@ def consume(self, token, result): CUSTOM_EXPRESSION_TYPES = {"extcall": "ExtCall", "staticcall": "StaticCall"} -class PreParseResult: +class PreParser: # Compilation settings based on the directives in the source code settings: Settings # A mapping of class names to their original class types. modification_offsets: dict[tuple[int, int], str] # A mapping of line/column offsets of `For` nodes to the annotation of the for loop target for_loop_annotations: dict[tuple[int, int], list[TokenInfo]] - # A list of line/column offsets of native hex literals - native_hex_literal_locations: list[tuple[int, int]] + # A list of line/column offsets of hex string literals + hex_string_locations: list[tuple[int, int]] # Reformatted python source string. reformatted_code: str - def __init__( - self, - settings, - modification_offsets, - for_loop_annotations, - native_hex_literal_locations, - reformatted_code, - ): - self.settings = settings - self.modification_offsets = modification_offsets - self.for_loop_annotations = for_loop_annotations - self.native_hex_literal_locations = native_hex_literal_locations - self.reformatted_code = reformatted_code - - -def pre_parse(code: str) -> PreParseResult: - """ - Re-formats a vyper source string into a python source string and performs - some validation. More specifically, - - * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword - * Validates "@version" pragma against current compiler version - * Prevents direct use of python "class" keyword - * Prevents use of python semi-colon statement separator - * Extracts type annotation of for loop iterators into a separate dictionary - - Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"), and a mapping of line numbers - of for loops to the type annotation of their iterators. - - Parameters - ---------- - code : str - The vyper source code to be re-formatted. - - Returns - ------- - PreParseResult - Outputs for transforming the python AST to vyper AST - """ - result: list[TokenInfo] = [] - modification_offsets: dict[tuple[int, int], str] = {} - settings = Settings() - for_parser = ForParser(code) - native_hex_parser = HexStringParser() + def parse(self, code: str): + """ + Re-formats a vyper source string into a python source string and performs + some validation. More specifically, + + * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword + * Validates "@version" pragma against current compiler version + * Prevents direct use of python "class" keyword + * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary + + Stores a mapping of detected interface and struct names to their + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. + + Parameters + ---------- + code : str + The vyper source code to be re-formatted. + """ + try: + self._parse(code) + except TokenError as e: + raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e + + def _parse(self, code: str): + result: list[TokenInfo] = [] + modification_offsets: dict[tuple[int, int], str] = {} + settings = Settings() + for_parser = ForParser(code) + hex_string_parser = HexStringParser() + + _col_adjustments: dict[int, int] = defaultdict(lambda: 0) - _col_adjustments: dict[int, int] = defaultdict(lambda: 0) - - try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) @@ -301,7 +286,7 @@ def pre_parse(code: str) -> PreParseResult: # a bit cursed technique to get untokenize to put # the new tokens in the right place so that modification_offsets # will work correctly. - # (recommend comparing the result of pre_parse with the + # (recommend comparing the result of parse with the # source code side by side to visualize the whitespace) new_keyword = "await" vyper_type = CUSTOM_EXPRESSION_TYPES[string] @@ -322,20 +307,15 @@ def pre_parse(code: str) -> PreParseResult: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - if not for_parser.consume(token) and not native_hex_parser.consume(token, result): + if not for_parser.consume(token) and not hex_string_parser.consume(token, result): result.extend(toks) - except TokenError as e: - raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - - for_loop_annotations = {} - for k, v in for_parser.annotations.items(): - for_loop_annotations[k] = v.copy() + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + for_loop_annotations[k] = v.copy() - return PreParseResult( - settings, - modification_offsets, - for_loop_annotations, - native_hex_parser.locations, - untokenize(result).decode("utf-8"), - ) + self.settings = settings + self.modification_offsets = modification_offsets + self.for_loop_annotations = for_loop_annotations + self.hex_string_locations = hex_string_parser.locations + self.reformatted_code = untokenize(result).decode("utf-8")