Skip to content

Commit

Permalink
[Feature] Allow oneOf in JSON schemas (with limited support) (#982)
Browse files Browse the repository at this point in the history
- `oneOf` is allowed when only a single schema is provided in the
`oneOf` list
- Fall back to `anyOf` if multiple schemas are provided, raising a
warning to the user.
- Add `id` and `discriminator` to ignored keys, expanding the schemas we
support
  • Loading branch information
hudson-ai authored Sep 3, 2024
1 parent 3bf3d14 commit 958145c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 1 deletion.
18 changes: 18 additions & 0 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Type,
TYPE_CHECKING,
)
import warnings

try:
import jsonschema
Expand Down Expand Up @@ -39,6 +40,7 @@ def _to_compact_json(target: Any) -> str:
class Keyword(str, Enum):
ANYOF = "anyOf"
ALLOF = "allOf"
ONEOF = "oneOf"
REF = "$ref"
CONST = "const"
ENUM = "enum"
Expand All @@ -55,6 +57,7 @@ class Keyword(str, Enum):
IGNORED_KEYS = {
"$schema",
"$id",
"id",
"$comment",
"title",
"description",
Expand All @@ -63,6 +66,14 @@ class Keyword(str, Enum):
"required", # TODO: implement and remove from ignored list
}

# discriminator is part of OpenAPI 3.1, not JSON Schema itself
# https://json-schema.org/blog/posts/validating-openapi-and-json-schema
# TODO: While ignoring this key shouldn't lead to invalid outputs, forcing
# the model to choose the value of the marked field before other fields
# are generated (statefully or statelessly) would reduce grammar ambiguity
# and possibly improve quality.
IGNORED_KEYS.add("discriminator")

TYPE_SPECIFIC_KEYS = {
"array": {"items", "prefixItems", "minItems", "maxItems"},
"object": {"properties", "additionalProperties"},
Expand Down Expand Up @@ -339,6 +350,13 @@ def _gen_json(
raise ValueError("Only support allOf with exactly one item")
return lm + _gen_json(allof_list[0], definitions)

if Keyword.ONEOF in json_schema:
oneof_list = json_schema[Keyword.ONEOF]
if len(oneof_list) == 1:
return lm + _gen_json(oneof_list[0], definitions)
warnings.warn("oneOf not fully supported, falling back to anyOf. This may cause validation errors in some cases.")
return lm + _process_anyOf(anyof_list=oneof_list, definitions=definitions)

if Keyword.REF in json_schema:
return lm + _get_definition(reference=json_schema[Keyword.REF], definitions=definitions)

Expand Down
43 changes: 42 additions & 1 deletion tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from guidance import json as gen_json
from guidance import models

from guidance.library._json import _to_compact_json, WHITESPACE
from guidance.library._json import _to_compact_json, WHITESPACE, IGNORED_KEYS

from ...utils import check_match_failure as _check_match_failure
from ...utils import check_run_with_temperature
Expand Down Expand Up @@ -1319,6 +1319,37 @@ def test_allOf_bad_schema(self):
lm += gen_json(name=CAPTURE_KEY, schema=schema_obj)
assert ve.value.args[0] == "Only support allOf with exactly one item"

class TestOneOf:
    """Tests for schemas that use the ``oneOf`` keyword."""

    @pytest.mark.parametrize("target_obj", [123, 42])
    def test_oneOf_simple(self, target_obj):
        """A ``oneOf`` list holding one integer subschema accepts integer targets."""
        schema = (
            "{\n"
            '"oneOf" : [{ "type": "integer" }]\n'
            "}\n"
        )
        parsed_schema = json.loads(schema)

        # Sanity check: the target must be valid against the schema itself
        validate(instance=target_obj, schema=parsed_schema)

        # The actual check
        generate_and_check(target_obj, parsed_schema)

    @pytest.mark.parametrize("target_obj", [123, True])
    def test_oneOf_compound(self, target_obj):
        """A ``oneOf`` list with two subschemas still generates, but emits a warning."""
        schema = (
            "{\n"
            '"oneOf" : [{ "type": "integer" }, { "type": "boolean" }]\n'
            "}\n"
        )
        parsed_schema = json.loads(schema)

        # Sanity check: the target must be valid against the schema itself
        validate(instance=target_obj, schema=parsed_schema)

        # The actual check; exactly one warning is expected because oneOf is
        # not fully supported
        with pytest.warns() as caught:
            generate_and_check(target_obj, parsed_schema)
        assert len(caught) == 1
        first_message = caught[0].message.args[0]
        assert first_message.startswith("oneOf not fully supported")


class TestEnum:
simple_schema = """{
Expand Down Expand Up @@ -1950,3 +1981,13 @@ def test_no_additionalProperties(self, compact):
maybe_whitespace=True,
compact=compact,
)

def test_ignored_keys_allowed_as_properties():
    """Every key in ``IGNORED_KEYS`` can still be used as an object property name."""
    property_schemas = {}
    target_obj = {}
    # One string-typed property and one matching value per ignored key
    for key in IGNORED_KEYS:
        property_schemas[key] = {"type": "string"}
        target_obj[key] = "value"

    schema_obj = {"type": "object", "properties": property_schemas}
    generate_and_check(target_obj, schema_obj)
43 changes: 43 additions & 0 deletions tests/unit/library/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,46 @@ def test_bad_generic(
maybe_whitespace=maybe_whitespace,
compact=compact,
)

class TestDiscriminatedUnion:
    """
    Tests for a pydantic model whose ``pet`` field is a union discriminated
    by ``pet_type``; see
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-str-discriminators
    """

    # Union member tagged with pet_type == "cat"
    class Cat(pydantic.BaseModel):
        pet_type: Literal["cat"]
        meows: int

    # Union member tagged with pet_type == "dog"
    class Dog(pydantic.BaseModel):
        pet_type: Literal["dog"]
        barks: float

    # Union member that accepts two different pet_type tags
    class Lizard(pydantic.BaseModel):
        pet_type: Literal["reptile", "lizard"]
        scales: bool

    # Container model: `pet` is the union, with `pet_type` as its discriminator
    class Model(pydantic.BaseModel):
        pet: Union[
            "TestDiscriminatedUnion.Cat",
            "TestDiscriminatedUnion.Dog",
            "TestDiscriminatedUnion.Lizard",
        ] = pydantic.Field(..., discriminator="pet_type")
        n: int

    def test_good(self):
        """A dog with both of its fields, plus ``n``, is generated and checked."""
        dog = {"pet_type": "dog", "barks": 3.14}
        generate_and_check({"pet": dog, "n": 42}, self.Model)

    def test_bad(self):
        """A dog without ``barks`` must fail where the pet object would close."""
        incomplete_dog = {"pet_type": "dog"}
        check_match_failure(
            bad_obj={"pet": incomplete_dog, "n": 42},
            good_bytes=b'{"pet":{"pet_type":"dog"',
            failure_byte=b"}",
            # expect a comma to continue the object with "barks"
            allowed_bytes={b","},
            pydantic_model=self.Model,
            maybe_whitespace=False,
            compact=True,
        )

0 comments on commit 958145c

Please sign in to comment.