Skip to content

Commit

Permalink
Remove Placeholder and replace_grammar_node in favor of pointer-l…
Browse files Browse the repository at this point in the history
…ike container object (#1007)

Essentially reopening #638 with some major simplifications allowed by
the Rust rewrite (the serialized grammar now directly supports
"references").

Reopening because:
#995 still takes far too long to process the large JSON schema.
Line-profiling revealed that ~80% of the time spent constructing the
`GrammarFunction` is due to `replace_grammar_node`. In particular, this
function has to traverse the entire grammar tree, and we potentially
have to call it many times if there is a lot of mutual recursion.

Simply adding a node type that acts as a container (which we can fill
after we decide what its contents should be) side-steps this problem
entirely.
  • Loading branch information
hudson-ai authored Sep 6, 2024
1 parent 5eaf240 commit 2b252fc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
2 changes: 0 additions & 2 deletions guidance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from ._guidance import _decorator, guidance

from ._grammar import (
Placeholder,
RawFunction,
GrammarFunction,
Terminal,
replace_grammar_node,
string,
)
from ._utils import strip_multiline_string_indents
Expand Down
60 changes: 30 additions & 30 deletions guidance/_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ def match_byte(self, byte):
def max_tokens(self):
return 1000000000000

class DeferredReference(Terminal):
    """Container to hold a value that is resolved at a later time.

    The reference is created empty and filled in exactly once through the
    ``value`` setter. This is useful for recursive definitions, where a node
    has to be handed out before the grammar it stands for has been built.
    """

    # Declare every attribute assigned in __init__, not just "_value".
    __slots__ = ("_value", "_resolved")

    def __init__(self) -> None:
        super().__init__(temperature=-1, capture_name=None)
        # Tracked separately from _value so "unset" is distinguishable from
        # whatever object is eventually stored.
        self._resolved = False
        self._value: Optional[GrammarFunction] = None

    @property
    def value(self) -> GrammarFunction:
        """Return the grammar this reference points to.

        Raises:
            ValueError: If no value has been assigned yet.
        """
        if not self._resolved:
            raise ValueError("DeferredReference does not have a value yet")
        return cast(GrammarFunction, self._value)

    @value.setter
    def value(self, value: GrammarFunction) -> None:
        """Assign the grammar this reference points to (allowed only once).

        Raises:
            ValueError: If a value has already been assigned.
        """
        if self._resolved:
            raise ValueError("DeferredReference value already set")
        self._value = value
        self._resolved = True

class Byte(Terminal):
__slots__ = ("byte", "temperature")
Expand Down Expand Up @@ -347,31 +369,6 @@ def __init__(self, name):
self.name = name


def replace_grammar_node(grammar, target, replacement):
    """Replace occurrences of ``target`` with ``replacement`` in place.

    Walks the grammar graph iteratively, starting at ``grammar``. For every
    non-leaf node, each entry of ``values`` that compares equal to ``target``
    is overwritten with ``replacement``; all other entries are queued for a
    later visit. Nothing is returned.
    """
    # Node kinds that are never descended into.
    leaf_types = (Terminal, ModelVariable, Placeholder)

    seen = set()  # nodes already handled, so shared/cyclic nodes are visited once
    pending = [grammar]  # explicit stack instead of recursion

    while pending:
        node = pending.pop()

        # Skip anything that was already processed.
        if node in seen:
            continue
        seen.add(node)

        # Only non-leaf nodes carry child values worth inspecting.
        if not isinstance(node, leaf_types):
            for idx, child in enumerate(node.values):
                if child == target:
                    node.values[idx] = replacement
                else:
                    pending.append(child)


def replace_model_variables(grammar, model, allowed_vars=None):
"""Replace all the ModelVariable nodes with their values in an iterative manner."""
visited_set = set()
Expand Down Expand Up @@ -447,11 +444,6 @@ def commit_point(value, hidden=False):
raise NotImplementedError("commit_point is not implemented (may remove in the future)")


class Placeholder(GrammarFunction):
    """Stand-in grammar node with no capture name.

    The guidance decorator stores an instance on the decorated function while
    a no-argument grammar is being built and returns it for recursive
    self-calls; the instance is afterwards swapped for the real node by
    ``replace_grammar_node``, which treats placeholders as leaves.
    """

    def __init__(self) -> None:
        # A placeholder is constructed without a capture name.
        super().__init__(capture_name=None)


class Join(GrammarFunction):
__slots__ = (
"values",
Expand Down Expand Up @@ -1139,6 +1131,14 @@ def process(self, node: GrammarFunction):
"literal": "",
}
}
elif isinstance(node, DeferredReference):
if node.value is None:
raise ValueError("Cannot serialize DeferredReference with unset value")
obj = {
"Join": {
"sequence": [self.node(node.value)],
}
}
else:
raise Exception("Unknown node type:", type(node))
tp = next(iter(obj))
Expand Down
20 changes: 10 additions & 10 deletions guidance/_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect

from . import models
from ._grammar import Placeholder, RawFunction, Terminal, replace_grammar_node, string
from ._grammar import RawFunction, Terminal, string, DeferredReference
from ._utils import strip_multiline_string_indents


Expand Down Expand Up @@ -40,18 +40,18 @@ def wrapped(*args, **kwargs):
callable(stateless) and stateless(*args, **kwargs)
):

# if we have a placeholder set then we must be in a recursive definition and so we return the placeholder
placeholder = getattr(f, "_self_call_placeholder_", None)
if placeholder is not None:
return placeholder
# if we have a (deferred) reference set, then we must be in a recursive definition and so we return the reference
reference = getattr(f, "_self_call_reference_", None)
if reference is not None:
return reference

# otherwise we call the function to generate the grammar
else:

# set a placeholder for recursive calls (only if we don't have arguments that might make caching a bad idea)
# set a DeferredReference for recursive calls (only if we don't have arguments that might make caching a bad idea)
no_args = len(args) + len(kwargs) == 0
if no_args:
f._self_call_placeholder_ = Placeholder()
f._self_call_reference_ = DeferredReference()

try:
# call the function to get the grammar node
Expand All @@ -61,12 +61,12 @@ def wrapped(*args, **kwargs):
else:
if not isinstance(node, (Terminal, str)):
node.name = f.__name__
# replace all the placeholders with our generated node
# set the reference value with our generated node
if no_args:
replace_grammar_node(node, f._self_call_placeholder_, node)
f._self_call_reference_.value = node
finally:
if no_args:
del f._self_call_placeholder_
del f._self_call_reference_

return node

Expand Down

0 comments on commit 2b252fc

Please sign in to comment.