Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support required properties in JSON schemas #1009

Merged
merged 11 commits into from
Sep 6, 2024
106 changes: 47 additions & 59 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,67 +144,54 @@ def _gen_json_object(
*,
properties: Mapping[str, Any],
additional_properties: Union[bool, Mapping[str, Any]],
required: Sequence[str],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if additional_properties is True:
# True means that anything goes
additional_properties = {}

lm += "{"
if properties:
lm += _process_properties(properties=properties, definitions=definitions)
if properties and additional_properties is not False:
lm += optional(
","
+ _process_additional_properties(
additional_properties=additional_properties, definitions=definitions
)
)
elif additional_properties is not False:
lm += optional(
_process_additional_properties(
additional_properties=additional_properties, definitions=definitions
)
)
lm += "}"
return lm


@guidance(stateless=True)
def _process_properties(
    lm,
    *,
    properties: Mapping[str, Any],
    definitions: Mapping[str, Callable[[], GrammarFunction]],
):
    """Append a comma-separated ``"name":value`` grammar for every property.

    Properties are emitted in the iteration order of ``properties``; each
    value grammar is produced by ``_gen_json`` from the property's schema.
    No surrounding braces are added here.

    Args:
        lm: The grammar/model object being extended.
        properties: Mapping of property name to that property's JSON schema.
        definitions: Forwarded unchanged to ``_gen_json``.

    Returns:
        ``lm`` extended with the property grammars (unchanged if
        ``properties`` is empty).
    """
    # Local import so this block does not depend on the module's import list.
    import json

    for index, (name, property_schema) in enumerate(properties.items()):
        if index:
            # Separator goes between items only, never after the last one.
            lm += ","
        # json.dumps escapes quotes, backslashes and control characters so the
        # key is always a valid JSON string; ensure_ascii=False leaves
        # non-ASCII names exactly as the plain '"' + name + '"' form did.
        lm += json.dumps(name, ensure_ascii=False) + ":"
        lm += _gen_json(
            json_schema=property_schema,
            definitions=definitions,
        )
    return lm


@guidance(stateless=True)
def _process_additional_properties(
    lm,
    *,
    additional_properties: Mapping[str, Any],
    definitions: Mapping[str, Callable[[], GrammarFunction]],
):
    """Append a grammar for a comma-separated run of ``key:value`` pairs.

    Each key is any JSON string (``_gen_json_string``) and each value is
    generated by ``_gen_json`` from the ``additional_properties`` schema.
    The result is ``sequence(pair + ",")`` followed by one final pair, so at
    least one pair is always present and there is no trailing comma.

    Args:
        lm: The grammar/model object being extended.
        additional_properties: Schema every additional value must satisfy.
        definitions: Forwarded unchanged to ``_gen_json``.

    Returns:
        ``lm`` extended with the additional-properties grammar.
    """
    key_grammar = _gen_json_string()
    value_grammar = _gen_json(
        json_schema=additional_properties, definitions=definitions
    )
    pair = key_grammar + ":" + value_grammar

    # Leading pairs each carry their own trailing comma; the last one does not.
    # NOTE(review): presumably `sequence` matches zero or more repetitions -- confirm.
    leading_pairs = sequence(pair + ",")
    return lm + leading_pairs + pair
if any(k not in properties for k in required):
raise ValueError(f"Required properties not in properties: {set(required) - set(properties)}")

grammars = tuple(f'"{name}":' + _gen_json(json_schema=schema, definitions=definitions) for name, schema in properties.items())
required_items = tuple(name in required for name in properties)

if additional_properties is not False:
if additional_properties is True:
# True means that anything goes
additional_properties = {}
additional_item_grammar = _gen_json_string() + ':' + _gen_json(json_schema=additional_properties, definitions=definitions)
additional_items_grammar = sequence(additional_item_grammar + ',') + additional_item_grammar
grammars += (additional_items_grammar,)
required_items += (False,)

return lm + "{" + _gen_list(
elements = grammars,
required = required_items,
) + "}"

@guidance(stateless=True, cache=True)
def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool, ...], prefixed: bool = False):
    """Append a comma-separated list in which some elements may be omitted.

    ``elements[i]`` must appear when ``required[i]`` is true and may be left
    out otherwise. Commas are placed only between elements that are actually
    emitted, so the result never has a leading or trailing comma.

    Both arguments are tuples rather than lists on purpose: tuples are
    hashable whenever their items are, and this recursive implementation
    relies on caching (``cache=True``) to avoid O(2^N) behaviour, which needs
    hashable arguments.

    Args:
        lm: The grammar/model object being extended.
        elements: Grammars for the individual list items, in output order.
        required: One flag per entry of ``elements``; must be the same length.
        prefixed: True when an earlier element has already been emitted, so
            anything added now must be preceded by a comma.

    Returns:
        ``lm`` extended with the list grammar (unchanged for empty input).

    Raises:
        ValueError: If ``elements`` and ``required`` differ in length.
    """
    # Should never happen for internal callers, but a mismatch would otherwise
    # surface as a bare IndexError or silently ignore trailing flags.
    if len(elements) != len(required):
        raise ValueError(
            "elements and required must have the same length, "
            f"got {len(elements)} and {len(required)}"
        )

    if not elements:
        return lm

    elem, elements = elements[0], elements[1:]
    is_required, required = required[0], required[1:]

    # Every branch below continues with "something has been emitted before"
    # for the remaining elements, so build that tail once.
    rest_prefixed = _gen_list(elements=elements, required=required, prefixed=True)

    if prefixed:
        if is_required:
            # If we know we have preceding elements, we can safely just add a (',' + e)
            return lm + (',' + elem + rest_prefixed)
        # If we know we have preceding elements, we can safely just add an optional(',' + e)
        return lm + (optional(',' + elem) + rest_prefixed)

    if is_required:
        # No preceding elements, and our element is required, so we just add the element
        return lm + (elem + rest_prefixed)

    # No preceding elements, and our element is optional, so we add a select between the two options.
    # The first option is the recursive call with no preceding elements, the second is the recursive call
    # with the current element as a prefix.
    return lm + select([
        _gen_list(elements=elements, required=required, prefixed=False),
        elem + rest_prefixed,
    ])


@guidance(stateless=True)
Expand Down Expand Up @@ -393,6 +380,7 @@ def _gen_json(
return lm + _gen_json_object(
properties=json_schema.get("properties", {}),
additional_properties=json_schema.get("additionalProperties", True),
required=json_schema.get("required", set()),
definitions=definitions,
)
raise ValueError(f"Unsupported type in schema: {target_type}")
Expand Down
Loading
Loading