Skip to content

Commit

Permalink
Inject TagIterator into BlockIterator (#39)
Browse files Browse the repository at this point in the history
* Inject tag iterator for greater flexibility. Fix some names.

* Inject TagIterator into BlockIterator for greater flexibility.

* Refine names.

* Fix name mismatch.

* Add type annotations

* Fix formatting for black.

* Tweak type annotation.

* Use Union type to make earlier python versions happy.
  • Loading branch information
peterallenwebb authored Jan 24, 2024
1 parent 8d69478 commit 63a4861
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 56 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240123-161107.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Inject TagIterator into BlockIterator for greater flexibility.
time: 2024-01-23T16:11:07.24321-05:00
custom:
Author: peterallenwebb
Issue: "38"
107 changes: 54 additions & 53 deletions dbt_common/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from typing import Iterator, List, Optional, Set, Union

from dbt_common.exceptions import (
BlockDefinitionNotAtTopError,
Expand All @@ -12,40 +13,42 @@
)


def regex(pat):
def regex(pat: str) -> re.Pattern:
return re.compile(pat, re.DOTALL | re.MULTILINE)


class BlockData:
"""raw plaintext data from the top level of the file."""

def __init__(self, contents):
def __init__(self, contents: str) -> None:
self.block_type_name = "__dbt__data"
self.contents = contents
self.contents: str = contents
self.full_block = contents


class BlockTag:
def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw):
def __init__(
self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None
) -> None:
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block

def __str__(self):
def __str__(self) -> str:
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

def __repr__(self):
def __repr__(self) -> str:
return str(self)

@property
def end_block_type_name(self):
def end_block_type_name(self) -> str:
return "end{}".format(self.block_type_name)

def end_pat(self):
def end_pat(self) -> re.Pattern:
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = "".join(
pattern: str = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
Expand Down Expand Up @@ -98,44 +101,38 @@ def end_pat(self):


class TagIterator:
def __init__(self, data):
self.data = data
self.blocks = []
self._parenthesis_stack = []
self.pos = 0

def linepos(self, end=None) -> str:
"""Given an absolute position in the input data, return a pair of
def __init__(self, text: str) -> None:
self.text: str = text
self.pos: int = 0

def linepos(self, end: Optional[int] = None) -> str:
"""Given an absolute position in the input text, return a pair of
line number + relative position to the start of the line.
"""
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
text = self.text[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind("\n") + 1
last_line_start = text.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count("\n") + 1
line_number = text.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"

def advance(self, new_position):
def advance(self, new_position: int) -> None:
self.pos = new_position

def rewind(self, amount=1):
def rewind(self, amount: int = 1) -> None:
self.pos -= amount

def _search(self, pattern):
return pattern.search(self.data, self.pos)
def _search(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.search(self.text, self.pos)

def _match(self, pattern):
return pattern.match(self.data, self.pos)
def _match(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.match(self.text, self.pos)

def _first_match(self, *patterns, **kwargs):
def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
match = self._search(pattern)
if match:
matches.append(match)
if not matches:
Expand All @@ -144,13 +141,13 @@ def _first_match(self, *patterns, **kwargs):
# TODO: do I need to account for m.start(), or is this ok?
return min(matches, key=lambda m: m.end())

def _expect_match(self, expected_name, *patterns, **kwargs):
match = self._first_match(*patterns, **kwargs)
def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore
match = self._first_match(*patterns)
if match is None:
raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :])
raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :])
return match

def handle_expr(self, match):
def handle_expr(self, match: re.Match) -> None:
"""Handle an expression. At this point we're at a string like:
{{ 1 + 2 }}
^ right here
Expand All @@ -176,12 +173,12 @@ def handle_expr(self, match):

self.advance(match.end())

def handle_comment(self, match):
def handle_comment(self, match: re.Match) -> None:
self.advance(match.end())
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())

def _expect_block_close(self):
def _expect_block_close(self) -> None:
"""Search for the tag close marker.
To the right of the type name, there are a few possiblities:
- a name (handled by the regex's 'block_name')
Expand All @@ -203,13 +200,13 @@ def _expect_block_close(self):
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())

def handle_raw(self):
def handle_raw(self) -> int:
# raw blocks are super special, they are a single complete regex
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()

def handle_tag(self, match):
def handle_tag(self, match: re.Match) -> Tag:
"""The tag could be one of a few things:
{% mytag %}
Expand All @@ -234,7 +231,7 @@ def handle_tag(self, match):
self._expect_block_close()
return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)

def find_tags(self):
def find_tags(self) -> Iterator[Tag]:
while True:
match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
if match is None:
Expand All @@ -259,7 +256,7 @@ def find_tags(self):
"Invalid regex match in next_block, expected block start, " "expr start, or comment start"
)

def __iter__(self):
def __iter__(self) -> Iterator[Tag]:
return self.find_tags()


Expand All @@ -272,31 +269,33 @@ def __iter__(self):


class BlockIterator:
def __init__(self, data):
self.tag_parser = TagIterator(data)
self.current = None
self.stack = []
self.last_position = 0
def __init__(self, tag_iterator: TagIterator) -> None:
self.tag_parser = tag_iterator
self.current: Optional[Tag] = None
self.stack: List[str] = []
self.last_position: int = 0

@property
def current_end(self):
def current_end(self) -> int:
if self.current is None:
return 0
else:
return self.current.end

@property
def data(self):
return self.tag_parser.data
def data(self) -> str:
return self.tag_parser.text

def is_current_end(self, tag):
def is_current_end(self, tag: Tag) -> bool:
return (
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)

def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
def find_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> Iterator[Union[BlockData, BlockTag]]:
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
Expand Down Expand Up @@ -347,5 +346,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
if raw_data:
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
def lex_for_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> List[Union[BlockData, BlockTag]]:
return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))
7 changes: 4 additions & 3 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_materialization_macro_name,
get_test_macro_name,
)
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator

from dbt_common.exceptions import (
CompilationError,
Expand Down Expand Up @@ -516,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:


def extract_toplevel_blocks(
data: str,
text: str,
allowed_blocks: Optional[Set[str]] = None,
collect_raw_data: bool = True,
) -> List[Union[BlockData, BlockTag]]:
Expand All @@ -534,4 +534,5 @@ def extract_toplevel_blocks(
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
tag_iterator = TagIterator(text)
return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)

0 comments on commit 63a4861

Please sign in to comment.