Skip to content

Commit

Permalink
json: revamp $ref handling (ggerganov#8073)
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Jun 26, 2024
1 parent 8854044 commit 2f4a1c0
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 133 deletions.
158 changes: 81 additions & 77 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <iostream>

using json = nlohmann::ordered_json;

Expand Down Expand Up @@ -392,10 +393,27 @@ class SchemaConverter {
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
std::unordered_map<std::string, json> _external_refs;
std::vector<json> _ref_context;

// Scope guard for the converter's `_ref_context` stack.
//
// When constructed with a non-null `target`, a copy of `*target` is pushed onto
// `_ref_context`; the destructor pops it again. A null `target` makes the guard
// a no-op in both directions.
struct with_context {
    SchemaConverter * _this;
    const json * _target;

    with_context(SchemaConverter * converter, const json * target) : _this(converter), _target(target) {
        if (_target) {
            // The stack stores a copy of the schema, not the pointer.
            _this->_ref_context.push_back(*_target);
        }
    }

    // Non-copyable: a copied guard would run the popping destructor twice.
    with_context(const with_context &) = delete;
    with_context & operator=(const with_context &) = delete;

    ~with_context() {
        if (_target) {
            GGML_ASSERT(_this->_ref_context.back() == *_target); // should not have been modified
            _this->_ref_context.pop_back();
        }
    }
};

std::string _add_rule(const std::string & name, const std::string & rule) {
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
Expand Down Expand Up @@ -683,17 +701,6 @@ class SchemaConverter {
return out.str();
}

std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
ref_name = visit(resolved, ref_name);
_refs_being_resolved.erase(ref);
}
return ref_name;
}

std::string _build_object_rule(
const std::vector<std::pair<std::string, json>> & properties,
const std::unordered_set<std::string> & required,
Expand Down Expand Up @@ -815,78 +822,70 @@ class SchemaConverter {
_rules["space"] = SPACE_RULE;
}

void resolve_refs(json & schema, const std::string & url) {
    /*
     * Resolves all $ref fields in the given schema, fetching any remote schemas,
     * replacing each local ("#/...") $ref with its absolute form (url + ref) and
     * populating _refs with the respective referenced (sub)schema dictionaries.
     *
     * Problems are appended to _errors rather than thrown from here.
     */
    std::function<void(json &)> visit_refs = [&](json & n) {
        if (n.is_array()) {
            for (auto & x : n) {
                visit_refs(x);
            }
        } else if (n.is_object()) {
            if (n.contains("$ref")) {
                std::string ref = n["$ref"];
                // Refs already recorded in _refs are left alone.
                if (_refs.find(ref) == _refs.end()) {
                    json target;
                    if (ref.find("https://") == 0) {
                        std::string base_url = ref.substr(0, ref.find('#'));
                        auto it = _refs.find(base_url);
                        if (it != _refs.end()) {
                            target = it->second;
                        } else {
                            // Fetch the referenced schema and resolve its refs
                            auto referenced = _fetch_json(ref);
                            resolve_refs(referenced, base_url);
                            _refs[base_url] = referenced;
                            // The fragment below must be looked up in the document
                            // that was just fetched, not in a null value.
                            target = referenced;
                        }
                        // No fragment (or an empty one): nothing further to select.
                        if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
                            return;
                        }
                    } else if (ref.find("#/") == 0) {
                        target = schema;
                        n["$ref"] = url + ref;
                        ref = url + ref;
                    } else {
                        _errors.push_back("Unsupported ref: " + ref);
                        return;
                    }
                    // Walk the '/'-separated fragment; tokens[0] is the text before the first '/'.
                    std::string pointer = ref.substr(ref.find('#') + 1);
                    std::vector<std::string> tokens = split(pointer, "/");
                    for (size_t i = 1; i < tokens.size(); ++i) {
                        std::string sel = tokens[i];
                        if (target.is_null() || !target.contains(sel)) {
                            _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
                            return;
                        }
                        target = target[sel];
                    }
                    _refs[ref] = target;
                }
            } else {
                for (auto & kv : n.items()) {
                    visit_refs(kv.value());
                }
            }
        }
    };

    visit_refs(schema);
}
// const std::unordered_map<std::string, json> & get_refs() const {
// return _refs;
// }

// Builds the rule text for a constant: the value is serialized with dump()
// and the resulting string is passed through format_literal().
std::string _generate_constant_rule(const json & value) {
    const std::string serialized = value.dump();
    return format_literal(serialized);
}

// Resolves `ref` (of the form "<url>#<pointer>") to the schema it points at.
//
// Returns {target, is_local}, where is_local is true when the part before '#'
// is empty. For a local ref the lookup starts from the innermost entry of
// `_ref_context`; otherwise it starts from the document cached in
// `_external_refs` for the url, fetching it through `_fetch_json` on first use.
// On failure an entry is appended to `_errors` and {null json, false} is returned.
std::pair<json, bool> _resolve_ref(const std::string & ref) {
    auto parts = split(ref, "#");
    if (parts.size() != 2) {
        _errors.push_back("Unsupported ref: " + ref);
        return {json(), false};
    }
    const auto & url = parts[0];
    json target;
    bool is_local = url.empty();
    if (is_local) {
        if (_ref_context.empty()) {
            _errors.push_back("Error resolving ref " + ref + ": no context");
            return {json(), false};
        }
        target = _ref_context.back();
    } else {
        auto it = _external_refs.find(url);
        if (it == _external_refs.end()) {
            // First use of this url: fetch the document and cache it.
            it = _external_refs.emplace(url, _fetch_json(url)).first;
        }
        // Cached or freshly fetched, the lookup always starts from the document.
        target = it->second;
    }
    // Walk the '/'-separated pointer; tokens[0] is the text before the first '/'.
    std::vector<std::string> tokens = split(parts[1], "/");
    for (size_t i = 1; i < tokens.size(); ++i) {
        std::string sel = tokens[i];
        if (target.is_null() || !target.contains(sel)) {
            _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
            // Same failure value as the other error paths.
            return {json(), false};
        }
        target = target[sel];
    }
    return {target, is_local};
}

std::string visit(const json & schema, const std::string & name) {
json schema_type = schema.contains("type") ? schema["type"] : json();
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;

if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
with_context wc(this, _ref_context.empty() ? &schema : nullptr);

if (schema.contains("$ref") && schema["$ref"].is_string()) {
const auto & ref = schema["$ref"].get<std::string>();
auto pair = _resolve_ref(ref);
auto target = pair.first;
auto is_local = pair.second;
if (target.is_null()) {
return "";
}
std::cout << target.dump(4) << std::endl;
with_context wc(this, is_local ? nullptr : &target);
return visit(target, name);
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
Expand Down Expand Up @@ -932,8 +931,9 @@ class SchemaConverter {
std::vector<std::pair<std::string, json>> properties;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
add_component(_refs[comp_schema["$ref"]], is_required);
if (comp_schema.contains("$ref") && comp_schema["$ref"].is_string()) {
auto target = _resolve_ref(schema["$ref"].get<std::string>());
add_component(target, is_required);
} else if (comp_schema.contains("properties")) {
for (const auto & prop : comp_schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
Expand Down Expand Up @@ -1038,7 +1038,11 @@ class SchemaConverter {
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
// converter.resolve_refs(copy, "input");
std::cout << copy.dump(4) << std::endl;
// for (const auto & [n, j] : converter.get_refs()) {
// std::cout << "REF: " << n << " -> " << j.dump(4) << "\n";
// }
converter.visit(copy, "");
converter.check_errors();
return converter.format_grammar();
Expand Down
83 changes: 27 additions & 56 deletions examples/json_schema_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
self._external_refs = {}
# self._refs_being_resolved = set()
self._ref_context = []

def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
Expand Down Expand Up @@ -332,51 +333,6 @@ def _add_rule(self, name, rule):
self._rules[key] = rule
return key

def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
# Recursive walker over the schema tree; handles both lists and dicts.
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
# Only refs that are not yet recorded in self._refs are processed.
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
# Remote ref: fetching must have been explicitly allowed.
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests

# Split off the fragment; the part before '#' is the cache key for the document.
frag_split = ref.split('#')
base_url = frag_split[0]

target = self._refs.get(base_url)
if target is None:
# Not cached yet: download, resolve the document's own refs, then cache it.
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target

# No fragment, or an empty one: the fetched document itself is returned.
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
# Local ref: start from the schema passed to resolve_refs and
# rewrite the $ref in place with url prepended.
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')

# Walk the '/'-separated fragment down from target, skipping the first piece.
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]

self._refs[ref] = target
else:
# Recurse into the values of this dict.
for v in n.values():
visit(v)

return n
return visit(schema)

def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
Expand Down Expand Up @@ -541,18 +497,33 @@ def join_seq():
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")


def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name

def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))

def _resolve_ref(self, ref):
parts = ref.split('#')
assert len(parts) == 2, f'Unsupported ref: {ref}'
url = parts[0]
is_local = url == ''
if is_local:
assert self._refs_being_resolved, 'Error resolving ref {ref}: no context'
target = self._refs_being_resolved[-1]
else:
if url in self._refs:
target = self._refs[url]
else:
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
referenced = requests.get(url).json()
self._refs[url] = referenced
target = referenced

for sel in parts[1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]

return target

def visit(self, schema, name):
schema_type = schema.get('type')
schema_format = schema.get('format')
Expand Down
41 changes: 41 additions & 0 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,47 @@ static void test_json_schema() {
// R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
}
);

test_schema(
"refs",
// Schema
R"""({
"type": "array",
"minItems": 15,
"maxItems": 15,
"items": { "$ref": "#/$defs/TALK" },
"$defs": {
"characters": { "enum": ["Biff", "Alice"] },
"emotes": { "enum": ["EXCLAMATION", "CONFUSION", "CHEERFUL", "LOVE", "ANGRY"] },
"TALK": {
"type": "object",
"required": [ "character", "emote", "dialog" ],
"properties": {
"character": { "$ref": "#/$defs/characters" },
"emote": { "$ref": "#/$defs/emotes" },
"dialog": {
"type": "string",
"minLength": 1,
"maxLength": 200
}
}
}
}
})""",
// Passing strings
{
R"""({
"character": "Alice",
"emote": "EXCLAMATION",
"dialog": "Hello, world!"
})""",
},
// Failing strings
{
}
);
}

int main() {
Expand Down
Loading

0 comments on commit 2f4a1c0

Please sign in to comment.