Skip to content

Commit

Permalink
[Feature] Support required properties in JSON schemas (#1009)
Browse files Browse the repository at this point in the history
Add support for `required` JSON schema properties, allowing the model to omit any properties not in this list
  • Loading branch information
hudson-ai authored Sep 6, 2024
1 parent 2b252fc commit ed7e8a7
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 109 deletions.
106 changes: 47 additions & 59 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,67 +144,54 @@ def _gen_json_object(
*,
properties: Mapping[str, Any],
additional_properties: Union[bool, Mapping[str, Any]],
required: Sequence[str],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if additional_properties is True:
# True means that anything goes
additional_properties = {}

lm += "{"
if properties:
lm += _process_properties(properties=properties, definitions=definitions)
if properties and additional_properties is not False:
lm += optional(
","
+ _process_additional_properties(
additional_properties=additional_properties, definitions=definitions
)
)
elif additional_properties is not False:
lm += optional(
_process_additional_properties(
additional_properties=additional_properties, definitions=definitions
)
)
lm += "}"
return lm


@guidance(stateless=True)
def _process_properties(
lm,
*,
properties: Mapping[str, Any],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
properties_added = 0
for name, property_schema in properties.items():
lm += '"' + name + '"'

lm += ":"
lm += _gen_json(
json_schema=property_schema,
definitions=definitions,
)
properties_added += 1
if properties_added < len(properties):
lm += ","
return lm


@guidance(stateless=True)
def _process_additional_properties(
lm,
*,
additional_properties: Mapping[str, Any],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
item = (
_gen_json_string()
+ ":"
+ _gen_json(json_schema=additional_properties, definitions=definitions)
)
return lm + sequence(item + ",") + item
if any(k not in properties for k in required):
raise ValueError(f"Required properties not in properties: {set(required) - set(properties)}")

grammars = tuple(f'"{name}":' + _gen_json(json_schema=schema, definitions=definitions) for name, schema in properties.items())
required_items = tuple(name in required for name in properties)

if additional_properties is not False:
if additional_properties is True:
# True means that anything goes
additional_properties = {}
additional_item_grammar = _gen_json_string() + ':' + _gen_json(json_schema=additional_properties, definitions=definitions)
additional_items_grammar = sequence(additional_item_grammar + ',') + additional_item_grammar
grammars += (additional_items_grammar,)
required_items += (False,)

return lm + "{" + _gen_list(
elements = grammars,
required = required_items,
) + "}"

@guidance(stateless=True, cache=True)
def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool, ...], prefixed: bool = False):
if not elements:
return lm

elem, elements = elements[0], elements[1:]
is_required, required = required[0], required[1:]

if prefixed:
if is_required:
# If we know we have preceeding elements, we can safely just add a (',' + e)
return lm + (',' + elem + _gen_list(elements=elements, required=required, prefixed=True))
# If we know we have preceeding elements, we can safely just add an optional(',' + e)
return lm + (optional(',' + elem) + _gen_list(elements=elements, required=required, prefixed=True))
if is_required:
# No preceding elements, and our element is required, so we just add the element
return lm + (elem + _gen_list(elements=elements, required=required, prefixed=True))

# No preceding elements, and our element is optional, so we add a select between the two options.
# The first option is the recursive call with no preceding elements, the second is the recursive call
# with the current element as a prefix.
return lm + select([
_gen_list(elements=elements, required=required, prefixed=False),
elem + _gen_list(elements=elements, required=required, prefixed=True)
])


@guidance(stateless=True)
Expand Down Expand Up @@ -393,6 +380,7 @@ def _gen_json(
return lm + _gen_json_object(
properties=json_schema.get("properties", {}),
additional_properties=json_schema.get("additionalProperties", True),
required=json_schema.get("required", set()),
definitions=definitions,
)
raise ValueError(f"Unsupported type in schema: {target_type}")
Expand Down
Loading

0 comments on commit ed7e8a7

Please sign in to comment.