Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create test_graph_schema_extraction.py #62

Merged
merged 9 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nodestream/schema/printers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .graphql_schema_printer import GraphQLSchemaPrinter
from .graph_schema_extraction import LargeLanguageModelSchemaPrinter
from .plain_text_schema_printer import PlainTestSchemaPrinter
from .schema_printer import SCHEMA_PRINTER_SUBCLASS_REGISTRY, SchemaPrinter

# Names exported by this package; each entry matches one of the names
# imported above from the printer modules.
__all__ = (
"SchemaPrinter",
"GraphQLSchemaPrinter",
"LargeLanguageModelSchemaPrinter",
"PlainTestSchemaPrinter",
"SCHEMA_PRINTER_SUBCLASS_REGISTRY",
)
33 changes: 33 additions & 0 deletions nodestream/schema/printers/graph_schema_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .schema_printer import SchemaPrinter
from collections import defaultdict

class LargeLanguageModelSchemaPrinter(SchemaPrinter, alias="genaillm"):
    """Render a graph schema as a compact, dict-style string.

    The printed form is three ``str()``-formatted mappings joined by
    ``". "``: node types to property names, relationship types to property
    names, and source node types to their outgoing relationships.
    """

    def return_nodes_props(self, schema):
        """Map each known node type (as a string) to its property names.

        Args:
            schema: Object exposing ``known_node_types()``; every item it
                yields must provide ``object_type`` and ``property_names()``.

        Returns:
            A dict keyed by ``str(object_type)`` whose values are whatever
            ``property_names()`` returns for that shape.
        """
        return {
            str(node_shape.object_type): node_shape.property_names()
            for node_shape in schema.known_node_types()
        }

    def return_rels_props(self, schema):
        """Map each known relationship type (as a string) to its property names.

        Args:
            schema: Object exposing ``known_relationship_types()``; every
                item it yields must provide ``object_type`` and
                ``property_names()``.

        Returns:
            A dict keyed by ``str(object_type)`` whose values are whatever
            ``property_names()`` returns for that shape.
        """
        return {
            str(rel_shape.object_type): rel_shape.property_names()
            for rel_shape in schema.known_relationship_types()
        }

    def return_rels(self, schema):
        """Group relationship targets by source node type and relationship type.

        Args:
            schema: Object exposing an iterable ``relationships`` attribute
                whose items provide ``from_object_type``,
                ``relationship_type`` and ``to_object_type``.

        Returns:
            A ``defaultdict`` of the form
            ``{source: {relationship: [target, ...]}}`` with all keys and
            targets converted to strings. The inner mappings are
            ``defaultdict(list)`` instances.
        """
        # The outer container stays a ``defaultdict(list)`` so that ``str()``
        # of the result (used by ``print_schema_to_string``) is unchanged.
        rels = defaultdict(list)
        for elem in schema.relationships:
            from_node = str(elem.from_object_type)
            # Create the inner mapping only the first time a source node is
            # seen. Re-creating it on every iteration would throw away the
            # relationships already recorded for that source node.
            if from_node not in rels:
                rels[from_node] = defaultdict(list)
            rel_name = str(elem.relationship_type)
            rels[from_node][rel_name].append(str(elem.to_object_type))
        return rels

    # ``GraphSchema`` is written as a string forward reference because the
    # name is not imported in this module; a bare annotation would be
    # evaluated at definition time and raise ``NameError``.
    def print_schema_to_string(self, schema: "GraphSchema") -> str:
        """Return the textual representation of ``schema``.

        Args:
            schema: The schema to describe; it is passed unchanged to
                ``return_nodes_props``, ``return_rels_props`` and
                ``return_rels``.

        Returns:
            ``str()`` of the three mappings, in that order, joined by ``". "``.
        """
        sections = (
            self.return_nodes_props(schema),
            self.return_rels_props(schema),
            self.return_rels(schema),
        )
        return ". ".join(str(section) for section in sections)
4 changes: 4 additions & 0 deletions nodestream/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def include(self, other: "GraphObjectShape"):

self.properties.update(other.properties)

def property_names(self):
    """Return the ``name`` of every entry in ``self.properties.properties``.

    The names are listed in the mapping's iteration order.
    """
    return [entry.name for entry in self.properties.properties.values()]

def resolve_types(self, shapes: Iterable["GraphObjectShape"]):
if object_type := self.object_type.resolve_type(shapes):
self.object_type = object_type
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/schema/printers/test_graph_schema_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from collections import defaultdict
from hamcrest import assert_that, equal_to

from nodestream.schema.printers.graph_schema_extraction import LargeLanguageModelSchemaPrinter

# Expected results for the `basic_schema` fixture; each constant is compared
# against one printer method in the tests below.

# Compared with the result of `return_nodes_props`.
EXPECTED_NODE_PROPS = {'Person': ['name', 'age'], 'Organization': ['name', 'industry']}
# Compared with the result of `return_rels`.
EXPECTED_RELS = defaultdict(list, {'Person': defaultdict(list, {'BEST_FRIEND_OF': ['Person']}), 'Organization': defaultdict(list, {'HAS_EMPLOYEE': ['Person']})})
# Compared with the result of `return_rels_props`.
EXPECTED_RELS_PROPS = {'BEST_FRIEND_OF': ['since'], 'HAS_EMPLOYEE': ['since']}
# Compared with the result of `print_schema_to_string`; the text matches
# str() of the three mappings above joined by ". ".
EXPECTED_PRINTED_SCHEMA = "{'Person': ['name', 'age'], 'Organization': ['name', 'industry']}. {'BEST_FRIEND_OF': ['since'], 'HAS_EMPLOYEE': ['since']}. defaultdict(<class 'list'>, {'Person': defaultdict(<class 'list'>, {'BEST_FRIEND_OF': ['Person']}), 'Organization': defaultdict(<class 'list'>, {'HAS_EMPLOYEE': ['Person']})})"

def test_outputs_schema_correctly(basic_schema):
    """The printed schema string equals EXPECTED_PRINTED_SCHEMA."""
    rendered = LargeLanguageModelSchemaPrinter().print_schema_to_string(basic_schema)
    expected = str(EXPECTED_PRINTED_SCHEMA)
    assert_that(rendered, equal_to(expected))

def test_ensure_nodes_props(basic_schema):
    """`return_nodes_props` yields EXPECTED_NODE_PROPS for the basic schema."""
    node_props = LargeLanguageModelSchemaPrinter().return_nodes_props(basic_schema)
    assert_that(node_props, equal_to(EXPECTED_NODE_PROPS))

def test_ensure_rels(basic_schema):
    """`return_rels` yields EXPECTED_RELS for the basic schema."""
    adjacency = LargeLanguageModelSchemaPrinter().return_rels(basic_schema)
    assert_that(adjacency, equal_to(EXPECTED_RELS))

def test_ensure_rels_props(basic_schema):
    """`return_rels_props` yields EXPECTED_RELS_PROPS for the basic schema."""
    rel_props = LargeLanguageModelSchemaPrinter().return_rels_props(basic_schema)
    assert_that(rel_props, equal_to(EXPECTED_RELS_PROPS))