Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow oneOf in JSON schemas (with limited support) #982

Merged
merged 8 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Type,
TYPE_CHECKING,
)
import warnings

try:
import jsonschema
Expand Down Expand Up @@ -39,6 +40,7 @@ def _to_compact_json(target: Any) -> str:
class Keyword(str, Enum):
ANYOF = "anyOf"
ALLOF = "allOf"
ONEOF = "oneOf"
REF = "$ref"
CONST = "const"
ENUM = "enum"
Expand All @@ -55,6 +57,7 @@ class Keyword(str, Enum):
IGNORED_KEYS = {
"$schema",
"$id",
"id",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that our problem with id (in the FHIR schema) is only at top level. I think that id is a perfectly fine name for a property on an object.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate_json_node_keys is only being called on dictionaries that represent full (sub)schemas, which doesn't include the dictionary specified by the properties key of a (sub)schema. Therefore the set of ignored keys should have no impact on what property names are valid -- worth checking if you have doubts :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest adding a test with an object with an id property.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to do so :)

"$comment",
"title",
"description",
Expand All @@ -63,6 +66,14 @@ class Keyword(str, Enum):
"required", # TODO: implement and remove from ignored list
}

# `discriminator` is an OpenAPI 3.1 keyword, not part of JSON Schema itself,
# so it is registered as an ignored key rather than handled as a schema keyword:
# https://json-schema.org/blog/posts/validating-openapi-and-json-schema
# TODO: While ignoring this key shouldn't lead to invalid outputs, forcing
# the model to choose the value of the marked field before other fields
# are generated (statefully or statelessly) would reduce grammar ambiguity
# and possibly improve quality.
IGNORED_KEYS.add("discriminator")

TYPE_SPECIFIC_KEYS = {
"array": {"items", "prefixItems", "minItems", "maxItems"},
"object": {"properties", "additionalProperties"},
Expand Down Expand Up @@ -339,6 +350,13 @@ def _gen_json(
raise ValueError("Only support allOf with exactly one item")
return lm + _gen_json(allof_list[0], definitions)

if Keyword.ONEOF in json_schema:
oneof_list = json_schema[Keyword.ONEOF]
if len(oneof_list) == 1:
return lm + _gen_json(oneof_list[0], definitions)
warnings.warn("oneOf not fully supported, falling back to anyOf. This may cause validation errors in some cases.")
return lm + _process_anyOf(anyof_list=oneof_list, definitions=definitions)

if Keyword.REF in json_schema:
return lm + _get_definition(reference=json_schema[Keyword.REF], definitions=definitions)

Expand Down
43 changes: 42 additions & 1 deletion tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from guidance import json as gen_json
from guidance import models

from guidance.library._json import _to_compact_json, WHITESPACE
from guidance.library._json import _to_compact_json, WHITESPACE, IGNORED_KEYS

from ...utils import check_match_failure as _check_match_failure
from ...utils import check_run_with_temperature
Expand Down Expand Up @@ -1319,6 +1319,37 @@ def test_allOf_bad_schema(self):
lm += gen_json(name=CAPTURE_KEY, schema=schema_obj)
assert ve.value.args[0] == "Only support allOf with exactly one item"

class TestOneOf:
    """Generation checks for schemas that use the ``oneOf`` keyword."""

    @pytest.mark.parametrize("target_obj", [123, 42])
    def test_oneOf_simple(self, target_obj):
        """A ``oneOf`` with a single integer subschema is generated and checked."""
        # JSON text is kept verbatim; it is only ever consumed via json.loads.
        raw_schema = """{
"oneOf" : [{ "type": "integer" }]
}
"""
        parsed_schema = json.loads(raw_schema)

        # Sanity check of the fixture itself: the target must satisfy the schema.
        validate(instance=target_obj, schema=parsed_schema)

        # The check under test.
        generate_and_check(target_obj, parsed_schema)

    @pytest.mark.parametrize("target_obj", [123, True])
    def test_oneOf_compound(self, target_obj):
        """A ``oneOf`` with two subschemas is generated and emits one warning."""
        # JSON text is kept verbatim; it is only ever consumed via json.loads.
        raw_schema = """{
"oneOf" : [{ "type": "integer" }, { "type": "boolean" }]
}
"""
        parsed_schema = json.loads(raw_schema)

        # Sanity check of the fixture itself: the target must satisfy the schema.
        validate(instance=target_obj, schema=parsed_schema)

        # oneOf is not fully supported, so the generation is expected to warn.
        with pytest.warns() as caught:
            generate_and_check(target_obj, parsed_schema)

        assert len(caught) == 1
        first_message = caught[0].message.args[0]
        assert first_message.startswith("oneOf not fully supported")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for including this check (wrote my 'main' comment before seeing the code).



class TestEnum:
simple_schema = """{
Expand Down Expand Up @@ -1950,3 +1981,13 @@ def test_no_additionalProperties(self, compact):
maybe_whitespace=True,
compact=compact,
)

def test_ignored_keys_allowed_as_properties():
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@riedgar-ms is this test acceptable to you here? It doesn't explicitly test id; rather it checks against all ignored keys. Happy to specialize it to id if you like that better.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even better to check them all

schema_obj = {
"type": "object",
"properties": {
key: {"type": "string"} for key in IGNORED_KEYS
}
}
target_obj = {key: "value" for key in IGNORED_KEYS}
generate_and_check(target_obj, schema_obj)
43 changes: 43 additions & 0 deletions tests/unit/library/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,46 @@ def test_bad_generic(
maybe_whitespace=maybe_whitespace,
compact=compact,
)

class TestDiscriminatedUnion:
    """
    Tests for a pydantic model whose ``pet`` field is a union discriminated
    by the string field ``pet_type``.

    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-str-discriminators
    """

    # NOTE: the nested models are documented with `#` comments rather than
    # docstrings so that nothing extra ends up in their generated schemas.

    # Union member selected by pet_type == 'cat'.
    class Cat(pydantic.BaseModel):
        pet_type: Literal['cat']
        meows: int


    # Union member selected by pet_type == 'dog'.
    class Dog(pydantic.BaseModel):
        pet_type: Literal['dog']
        barks: float


    # Union member selected by either of two pet_type values.
    class Lizard(pydantic.BaseModel):
        pet_type: Literal['reptile', 'lizard']
        scales: bool


    # Top-level model: `pet` is the discriminated union, `n` is a plain sibling field.
    class Model(pydantic.BaseModel):
        # Members are given as string forward references to the nested classes.
        pet: Union[
            'TestDiscriminatedUnion.Cat',
            'TestDiscriminatedUnion.Dog',
            'TestDiscriminatedUnion.Lizard',
        ] = pydantic.Field(..., discriminator='pet_type')
        n: int

    def test_good(self):
        """An object whose ``pet`` has the Dog fields is passed to generate_and_check."""
        obj = {"pet": {"pet_type": "dog", "barks": 3.14}, "n": 42}
        generate_and_check(obj, self.Model)

    def test_bad(self):
        """A dog without ``barks`` is passed to check_match_failure, failing at the closing brace."""
        check_match_failure(
            bad_obj={"pet": {"pet_type": "dog"}, "n": 42},
            good_bytes=b'{"pet":{"pet_type":"dog"',
            failure_byte=b"}",
            allowed_bytes={b","}, # expect a comma to continue the object with "barks"
            pydantic_model=self.Model,
            maybe_whitespace=False,
            compact=True
        )
Loading