Skip to content

Commit

Permalink
Rip out self.next() entirely
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Aug 3, 2023
1 parent 6afef47 commit 0da42bf
Showing 1 changed file with 76 additions and 72 deletions.
148 changes: 76 additions & 72 deletions Tools/clinic/clinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4388,10 +4388,6 @@ def dedent(self, line: str) -> str:
return line[indent:]


class StateKeeper(Protocol):
def __call__(self, function: Function | None, line: str) -> Function | None: ...


ConverterArgs = dict[str, Any]

class ParamState(enum.IntEnum):
Expand Down Expand Up @@ -4425,8 +4421,18 @@ class ParamState(enum.IntEnum):
RIGHT_SQUARE_AFTER = 6


class DSLParserState(enum.Enum):
DSL_START = enum.auto()
MODULENAME_NAME = enum.auto()
PARAMETERS_START = enum.auto()
PARAMETER = enum.auto()
PARAMETER_DOCSTRING_START = enum.auto()
PARAMETER_DOCSTRING = enum.auto()
FUNCTION_DOCSTRING = enum.auto()


class DSLParser:
state: StateKeeper
state: DSLParserState
keyword_only: bool
positional_only: bool
group: int
Expand Down Expand Up @@ -4456,7 +4462,7 @@ def __init__(self, clinic: Clinic) -> None:
self.reset()

def reset(self) -> None:
self.state = self.state_dsl_start
self.state = DSLParserState.DSL_START
self.keyword_only = False
self.positional_only = False
self.group = 0
Expand Down Expand Up @@ -4611,7 +4617,7 @@ def parse(self, block: Block) -> None:
if '\t' in line:
fail('Tab characters are illegal in the Clinic DSL.\n\t' + repr(line), line_number=block_start)
try:
function = self.state(function, line)
function = self.handle_line(function, line)
except ClinicError as exc:
exc.lineno = line_number
raise
Expand All @@ -4624,11 +4630,43 @@ def parse(self, block: Block) -> None:
fail("'preserve' only works for blocks that don't produce any output!")
block.output = self.saved_output

def handle_line(self, function: Function | None, line: str) -> Function | None:
match function:
case None:
match self.state:
case DSLParserState.DSL_START:
return self.handle_dsl_start(line)
case DSLParserState.MODULENAME_NAME:
return self.handle_modulename_name(line)
case _ as state:
raise AssertionError(
f"self.state is {state!r} but function is still None!"
)
case Function():
match self.state:
case DSLParserState.PARAMETERS_START:
return self.handle_parameters_start(function, line)
case DSLParserState.PARAMETER:
return self.handle_parameter(function, line)
case DSLParserState.PARAMETER_DOCSTRING_START:
return self.handle_parameter_docstring_start(function, line)
case DSLParserState.PARAMETER_DOCSTRING:
return self.handle_parameter_docstring(function, line)
case DSLParserState.FUNCTION_DOCSTRING:
return self.handle_function_docstring(function, line)
case _ as state:
raise AssertionError(f"Unexpected state: {state!r}")
case _:
raise AssertionError(
f"Expected function to be a Function or None, "
f"got {type(function)!r}"
)

def in_docstring(self) -> bool:
"""Return true if we are processing a docstring."""
return self.state in {
self.state_parameter_docstring,
self.state_function_docstring,
DSLParserState.PARAMETER_DOCSTRING,
DSLParserState.FUNCTION_DOCSTRING,
}

def valid_line(self, line: str) -> bool:
Expand All @@ -4647,21 +4685,7 @@ def valid_line(self, line: str) -> bool:
def calculate_indent(line: str) -> int:
return len(line) - len(line.strip())

def next(
self,
state: StateKeeper,
*,
function: Function | None,
line: str | None = None,
) -> Function | None:
self.state = state
if line is not None:
function = self.state(function=function, line=line)
return function

def state_dsl_start(self, function: Function | None, line: str) -> Function | None:
assert function is None

def handle_dsl_start(self, line: str) -> Function | None:
if not self.valid_line(line):
return None

Expand All @@ -4676,11 +4700,10 @@ def state_dsl_start(self, function: Function | None, line: str) -> Function | No
fail(str(e))
return None

return self.next(self.state_modulename_name, function=None, line=line)
self.state = DSLParserState.MODULENAME_NAME
return self.handle_modulename_name(line)

def state_modulename_name(
self, function: Function | None, line: str
) -> Function | None:
def handle_modulename_name(self, line: str) -> Function | None:
# looking for declaration, which establishes the leftmost column
# line should be
# modulename.fnname [as c_basename] [-> return annotation]
Expand All @@ -4697,8 +4720,6 @@ def state_modulename_name(
# this line is permitted to start with whitespace.
# we'll call this number of spaces F (for "function").

assert function is None

if not self.valid_line(line):
return None

Expand Down Expand Up @@ -4740,7 +4761,8 @@ def state_modulename_name(
)
self.block.signatures.append(function)
(cls or module).functions.append(function)
return self.next(self.state_function_docstring, function=function)
self.state = DSLParserState.FUNCTION_DOCSTRING
return function

line, _, returns = line.partition('->')
returns = returns.strip()
Expand Down Expand Up @@ -4821,7 +4843,8 @@ def state_modulename_name(
function.parameters[name] = p_self

(cls or module).functions.append(function)
return self.next(self.state_parameters_start, function=function)
self.state = DSLParserState.PARAMETERS_START
return function

# Now entering the parameters section. The rules, formally stated:
#
Expand Down Expand Up @@ -4878,18 +4901,16 @@ def state_modulename_name(
# separate boolean state variables.) The states are defined in the
# ParamState class.

def state_parameters_start(self, function: Function | None, line: str) -> Function:
assert function is not None

def handle_parameters_start(self, function: Function, line: str) -> Function:
if self.valid_line(line):
# if this line is not indented, we have no parameters
if not self.indent.infer(line):
self.next(
self.state_function_docstring, function=function, line=line
)
self.state = DSLParserState.FUNCTION_DOCSTRING
self.handle_function_docstring(function=function, line=line)
else:
self.parameter_continuation = ''
self.next(self.state_parameter, function=function, line=line)
self.state = DSLParserState.PARAMETER
self.handle_parameter(function=function, line=line)

return function

Expand All @@ -4902,11 +4923,7 @@ def to_required(self, function: Function) -> None:
for p in function.parameters.values():
p.group = -p.group

def state_parameter(
self, function: Function | None, line: str
) -> Function:
assert function is not None

def handle_parameter(self, function: Function, line: str) -> Function:
if not self.valid_line(line):
return function

Expand All @@ -4918,17 +4935,13 @@ def state_parameter(
indent = self.indent.infer(line)
if indent == -1:
# we outdented, must be to definition column
self.next(
self.state_function_docstring, function=function, line=line
)
return function
self.state = DSLParserState.FUNCTION_DOCSTRING
return self.handle_function_docstring(function=function, line=line)

if indent == 1:
# we indented, must be to new parameter docstring column
self.next(
self.state_parameter_docstring_start, function=function, line=line
)
return function
self.state = DSLParserState.PARAMETER_DOCSTRING_START
return self.handle_parameter_docstring_start(function=function, line=line)

line = line.rstrip()
if line.endswith('\\'):
Expand Down Expand Up @@ -5318,15 +5331,14 @@ def parse_slash(self, function: Function) -> None:
"positional-only parameters, which is unsupported.")
p.kind = inspect.Parameter.POSITIONAL_ONLY

def state_parameter_docstring_start(
self, function: Function | None, line: str
def handle_parameter_docstring_start(
self, function: Function, line: str
) -> Function:
assert function is not None
assert self.indent.margin is not None, "self.margin.infer() has not yet been called to set the margin"
self.parameter_docstring_indent = len(self.indent.margin)
assert self.indent.depth == 3
self.next(self.state_parameter_docstring, function=function, line=line)
return function
self.state = DSLParserState.PARAMETER_DOCSTRING
return self.handle_parameter_docstring(function=function, line=line)

def docstring_append(self, obj: Function | Parameter, line: str) -> None:
"""Add a rstripped line to the current docstring."""
Expand All @@ -5345,11 +5357,9 @@ def docstring_append(self, obj: Function | Parameter, line: str) -> None:
# every line of the docstring must start with at least F spaces,
# where F > P.
# these F spaces will be stripped.
def state_parameter_docstring(
self, function: Function | None, line: str
def handle_parameter_docstring(
self, function: Function, line: str
) -> Function:
assert function is not None

if not self.valid_line(line):
return function

Expand All @@ -5359,25 +5369,19 @@ def state_parameter_docstring(
assert self.indent.depth < 3
if self.indent.depth == 2:
# back to a parameter
self.next(self.state_parameter, function=function, line=line)
return function
self.state = DSLParserState.PARAMETER
return self.handle_parameter(function=function, line=line)
assert self.indent.depth == 1
self.next(
self.state_function_docstring, function=function, line=line
)
return function
self.state = DSLParserState.FUNCTION_DOCSTRING
return self.handle_function_docstring(function=function, line=line)

assert function.parameters
last_param = next(reversed(function.parameters.values()))
self.docstring_append(last_param, line)
return function

# the final stanza of the DSL is the docstring.
def state_function_docstring(
self, function: Function | None, line: str
) -> Function:
assert function is not None

def handle_function_docstring(self, function: Function, line: str) -> Function:
if self.group:
fail(f"Function {function.name} has a ] without a matching [.")

Expand Down

0 comments on commit 0da42bf

Please sign in to comment.