Skip to content

Commit

Permalink
Merge pull request #62 from AGhafaryy/AGhafaryy-patch-1
Browse files Browse the repository at this point in the history
Create test_graph_schema_extraction.py
  • Loading branch information
zprobst authored Aug 9, 2023
2 parents 7028652 + a998312 commit 502492b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nodestream/schema/printers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .graphql_schema_printer import GraphQLSchemaPrinter
from .graph_schema_extraction import LargeLanguageModelSchemaPrinter
from .plain_text_schema_printer import PlainTestSchemaPrinter
from .schema_printer import SCHEMA_PRINTER_SUBCLASS_REGISTRY, SchemaPrinter

__all__ = (
"SchemaPrinter",
"GraphQLSchemaPrinter",
"LargeLanguageModelSchemaPrinter",
"PlainTestSchemaPrinter",
"SCHEMA_PRINTER_SUBCLASS_REGISTRY",
)
33 changes: 33 additions & 0 deletions nodestream/schema/printers/graph_schema_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .schema_printer import SchemaPrinter
from collections import defaultdict

class LargeLanguageModelSchemaPrinter(SchemaPrinter, alias="genaillm"):
def return_nodes_props(self, schema):
return {
str(node_shape.object_type): node_shape.property_names()
for node_shape in schema.known_node_types()
}

def return_rels_props(self, schema):
return {
str(rel_shape.object_type): rel_shape.property_names()
for rel_shape in schema.known_relationship_types()
}

def return_rels(self, schema):
rels = defaultdict(list)
for elem in schema.relationships:
from_node = elem.from_object_type
to_node = elem.to_object_type
rel_name = elem.relationship_type
rels[str(from_node)] = defaultdict(list)
rels[str(from_node)][str(rel_name)].append(str(to_node))
return rels

def print_schema_to_string(self, schema: GraphSchema) -> str:
representation = str(self.return_nodes_props(schema))
representation += ". "
representation += str(self.return_rels_props(schema))
representation += ". "
representation += str(self.return_rels(schema))
return representation
4 changes: 4 additions & 0 deletions nodestream/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def include(self, other: "GraphObjectShape"):

self.properties.update(other.properties)

def property_names(self):
all_props = self.properties.properties
return [all_props[prop].name for prop in all_props.keys()]

def resolve_types(self, shapes: Iterable["GraphObjectShape"]):
if object_type := self.object_type.resolve_type(shapes):
self.object_type = object_type
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/schema/printers/test_graph_schema_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from collections import defaultdict
from hamcrest import assert_that, equal_to

from nodestream.schema.printers.graph_schema_extraction import LargeLanguageModelSchemaPrinter

EXPECTED_NODE_PROPS = {'Person': ['name', 'age'], 'Organization': ['name', 'industry']}
EXPECTED_RELS = defaultdict(list, {'Person': defaultdict(list, {'BEST_FRIEND_OF': ['Person']}), 'Organization': defaultdict(list, {'HAS_EMPLOYEE': ['Person']})})
EXPECTED_RELS_PROPS = {'BEST_FRIEND_OF': ['since'], 'HAS_EMPLOYEE': ['since']}
EXPECTED_PRINTED_SCHEMA = "{'Person': ['name', 'age'], 'Organization': ['name', 'industry']}. {'BEST_FRIEND_OF': ['since'], 'HAS_EMPLOYEE': ['since']}. defaultdict(<class 'list'>, {'Person': defaultdict(<class 'list'>, {'BEST_FRIEND_OF': ['Person']}), 'Organization': defaultdict(<class 'list'>, {'HAS_EMPLOYEE': ['Person']})})"

def test_outputs_schema_correctly(basic_schema):
printer = LargeLanguageModelSchemaPrinter()
output = printer.print_schema_to_string(basic_schema)
assert_that(output, equal_to(str(EXPECTED_PRINTED_SCHEMA)))

def test_ensure_nodes_props(basic_schema):
subject = LargeLanguageModelSchemaPrinter()
result = subject.return_nodes_props(basic_schema)
assert_that(result, equal_to(EXPECTED_NODE_PROPS))

def test_ensure_rels(basic_schema):
subject = LargeLanguageModelSchemaPrinter()
result = subject.return_rels(basic_schema)
assert_that(result, equal_to(EXPECTED_RELS))

def test_ensure_rels_props(basic_schema):
subject = LargeLanguageModelSchemaPrinter()
result = subject.return_rels_props(basic_schema)
assert_that(result, equal_to(EXPECTED_RELS_PROPS))

0 comments on commit 502492b

Please sign in to comment.