Skip to content

Commit

Permalink
Merge pull request #77 from RDFLib/iterative_rule
Browse files Browse the repository at this point in the history
Add the ability for SHACL rules to operate iteratively.
Closes #76
  • Loading branch information
ashleysommer authored May 26, 2021
2 parents 4f491dc + 7a86d6b commit 4166ee1
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 60 deletions.
12 changes: 12 additions & 0 deletions pyshacl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def __call__(self, parser, namespace, values, option_string=None):
default=False,
help='Enable features from the SHACL-JS Specification.',
)
parser.add_argument(
'--iterate-rules',
dest='iterate_rules',
action='store_true',
default=False,
help="Run Shape's SHACL Rules iteratively until the data_graph reaches a steady state.",
)
parser.add_argument('--abort', dest='abort', action='store_true', default=False, help='Abort on first error.')
parser.add_argument(
'-d', '--debug', dest='debug', action='store_true', default=False, help='Output additional runtime messages.'
Expand Down Expand Up @@ -151,6 +158,11 @@ def main():
validator_kwargs['advanced'] = True
if args.js:
validator_kwargs['js'] = True
if args.iterate_rules:
if not args.advanced:
sys.stderr.write("Iterate-Rules option only works when you enable Advanced Mode.\n")
else:
validator_kwargs['iterate_rules'] = True
if args.abort:
validator_kwargs['abort_on_error'] = True
if args.shacl_file_format:
Expand Down
59 changes: 40 additions & 19 deletions pyshacl/extras/js/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import typing

from pyshacl.consts import SH
from pyshacl.errors import ReportableRuntimeError
from pyshacl.rules.shacl_rule import SHACLRule

from .js_executable import JSExecutable


if typing.TYPE_CHECKING:
from pyshacl.pytypes import GraphLike
from pyshacl.shape import Shape
from pyshacl.shapes_graph import ShapesGraph

Expand All @@ -18,26 +20,45 @@
class JSRule(SHACLRule):
__slots__ = ('js_exe',)

def __init__(self, shape: 'Shape', rule_node):
super(JSRule, self).__init__(shape, rule_node)
def __init__(self, shape: 'Shape', rule_node, **kwargs):
super(JSRule, self).__init__(shape, rule_node, **kwargs)
shapes_graph = shape.sg # type: ShapesGraph
self.js_exe = JSExecutable(shapes_graph, rule_node)

def apply(self, data_graph):
def apply(self, data_graph: 'GraphLike') -> int:
focus_nodes = self.shape.focus_nodes(data_graph) # uses target nodes to find focus nodes
applicable_nodes = self.filter_conditions(focus_nodes, data_graph)
sets_to_add = []
for a in applicable_nodes:
args_map = {"this": a}
results = self.js_exe.execute(data_graph, args_map, mode="construct")
triples = results['_result']
if triples is not None and isinstance(triples, (list, tuple)):
set_to_add = set()
for t in triples:
s, p, o = t[:3]
set_to_add.add((s, p, o))
sets_to_add.append(set_to_add)
for s in sets_to_add:
for t in s:
data_graph.add(t)
return
all_added = 0
iterate_limit = 100
while True:
if iterate_limit < 1:
raise ReportableRuntimeError("Local rule iteration exceeded iteration limit of 100.")
iterate_limit -= 1
added = 0
applicable_nodes = self.filter_conditions(focus_nodes, data_graph)
sets_to_add = []
for a in applicable_nodes:
args_map = {"this": a}
results = self.js_exe.execute(data_graph, args_map, mode="construct")
triples = results['_result']
this_added = False
if triples is not None and isinstance(triples, (list, tuple)):
set_to_add = set()
for t in triples:
s, p, o = tr = t[:3]
if not this_added and tr not in data_graph:
this_added = True
set_to_add.add(tr)
sets_to_add.append(set_to_add)
if this_added:
added += 1
if added > 0:
all_added += added
for s in sets_to_add:
for t in s:
data_graph.add(t)
if self.iterate:
continue # Jump up to iterate
else:
break # Don't iterate
break
return all_added
32 changes: 24 additions & 8 deletions pyshacl/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union

from pyshacl.consts import RDF_type, SH_rule, SH_SPARQLRule, SH_TripleRule
from pyshacl.errors import RuleLoadError
from pyshacl.errors import ReportableRuntimeError, RuleLoadError
from pyshacl.pytypes import GraphLike
from pyshacl.rules.sparql import SPARQLRule
from pyshacl.rules.triple import TripleRule
Expand All @@ -16,7 +16,7 @@
from .shacl_rule import SHACLRule


def gather_rules(shacl_graph: 'ShapesGraph') -> Dict['Shape', List['SHACLRule']]:
def gather_rules(shacl_graph: 'ShapesGraph', iterate_rules=False) -> Dict['Shape', List['SHACLRule']]:
"""
:param shacl_graph:
Expand Down Expand Up @@ -63,7 +63,7 @@ def gather_rules(shacl_graph: 'ShapesGraph') -> Dict['Shape', List['SHACLRule']]
"https://www.w3.org/TR/shacl-af/#rules-syntax",
)
if obj in triple_rule_nodes:
rule: SHACLRule = TripleRule(shape, obj)
rule: SHACLRule = TripleRule(shape, obj, iterate=iterate_rules)
elif obj in sparql_rule_nodes:
rule = SPARQLRule(shape, obj)
elif use_JSRule and callable(use_JSRule) and obj in js_rule_nodes:
Expand All @@ -77,13 +77,29 @@ def gather_rules(shacl_graph: 'ShapesGraph') -> Dict['Shape', List['SHACLRule']]
return ret_rules


def apply_rules(shapes_rules: Dict, data_graph: GraphLike):
def apply_rules(shapes_rules: Dict, data_graph: GraphLike, iterate=False) -> int:
# sort the shapes dict by shapes sh:order before execution
sorted_shapes_rules: List[Tuple[Any, Any]] = sorted(shapes_rules.items(), key=lambda x: x[0].order)
total_modified = 0
for shape, rules in sorted_shapes_rules:
# sort the rules by the sh:order before execution
rules = sorted(rules, key=lambda x: x.order)
for r in rules:
if r.deactivated:
continue
r.apply(data_graph)
iterate_limit = 100
while True:
if iterate_limit < 1:
raise ReportableRuntimeError("SHACL Shape Rule iteration exceeded iteration limit of 100.")
iterate_limit -= 1
this_modified = 0
for r in rules:
if r.deactivated:
continue
n_modified = r.apply(data_graph)
this_modified += n_modified
if this_modified > 0:
total_modified += this_modified
if iterate:
continue
else:
break
break
return total_modified
6 changes: 4 additions & 2 deletions pyshacl/rules/shacl_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def validate_condition(self, data_graph, focus_node):


class SHACLRule(object):
__slots__ = ("shape", "node", "_deactivated")
__slots__ = ("shape", "node", "iterate", "_deactivated")

def __init__(self, shape, rule_node):
def __init__(self, shape, rule_node, iterate=False):
"""
:param shape:
Expand All @@ -38,6 +38,8 @@ def __init__(self, shape, rule_node):
super(SHACLRule, self).__init__()
self.shape = shape
self.node = rule_node
self.iterate = False

deactivated_nodes = list(self.shape.sg.objects(self.node, SH_deactivated))
self._deactivated = len(deactivated_nodes) > 0 and bool(deactivated_nodes[0])

Expand Down
60 changes: 42 additions & 18 deletions pyshacl/rules/sparql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


if TYPE_CHECKING:
from pyshacl.pytypes import GraphLike
from pyshacl.shape import Shape

XSD_string = XSD.term('string')
Expand All @@ -23,15 +24,15 @@
class SPARQLRule(SHACLRule):
__slots__ = ("_constructs", "_qh")

def __init__(self, shape: 'Shape', rule_node: 'rdflib.term.Identifier'):
def __init__(self, shape: 'Shape', rule_node: 'rdflib.term.Identifier', **kwargs):
"""
:param shape:
:type shape: Shape
:param rule_node:
:type rule_node: rdflib.term.Identifier
"""
super(SPARQLRule, self).__init__(shape, rule_node)
super(SPARQLRule, self).__init__(shape, rule_node, **kwargs)
construct_nodes = set(self.shape.sg.objects(self.node, SH_construct))
if len(construct_nodes) < 1:
raise RuleLoadError("No sh:construct on SPARQLRule", "https://www.w3.org/TR/shacl-af/#SPARQLRule")
Expand All @@ -49,21 +50,44 @@ def __init__(self, shape: 'Shape', rule_node: 'rdflib.term.Identifier'):
query_helper.collect_prefixes()
self._qh = query_helper

def apply(self, data_graph):
def apply(self, data_graph: 'GraphLike') -> int:
focus_nodes = self.shape.focus_nodes(data_graph) # uses target nodes to find focus nodes
applicable_nodes = self.filter_conditions(focus_nodes, data_graph)
construct_graphs = set()
all_added = 0
SPARQLQueryHelper = get_query_helper_cls()
for a in applicable_nodes:
for c in self._constructs:
init_bindings = {}
found_this = SPARQLQueryHelper.bind_this_regex.search(c)
if found_this:
init_bindings['this'] = a
c = self._qh.apply_prefixes(c)
results = data_graph.query(c, initBindings=init_bindings)
if results.type != "CONSTRUCT":
raise ReportableRuntimeError("Query executed by a SHACL SPARQLRule must be CONSTRUCT query.")
construct_graphs.add(results.graph)
for g in construct_graphs:
data_graph = clone_graph(g, target_graph=data_graph)
iterate_limit = 100
while True:
if iterate_limit < 1:
raise ReportableRuntimeError("Local rule iteration exceeded iteration limit of 100.")
iterate_limit -= 1
added = 0
applicable_nodes = self.filter_conditions(focus_nodes, data_graph)
construct_graphs = set()
for a in applicable_nodes:
for c in self._constructs:
init_bindings = {}
found_this = SPARQLQueryHelper.bind_this_regex.search(c)
if found_this:
init_bindings['this'] = a
c = self._qh.apply_prefixes(c)
results = data_graph.query(c, initBindings=init_bindings)
if results.type != "CONSTRUCT":
raise ReportableRuntimeError("Query executed by a SHACL SPARQLRule must be CONSTRUCT query.")
this_added = False
for i in results.graph:
if not this_added and i not in data_graph:
this_added = True
# We only need to know that at least one new triple would be added, then break!
break
if this_added:
added += 1
construct_graphs.add(results.graph)
if added > 0:
for g in construct_graphs:
data_graph = clone_graph(g, target_graph=data_graph)
all_added += added
if self.iterate:
continue # Jump up to iterate
else:
break # Don't iterate
break # We've reached a local steady state
return all_added
43 changes: 33 additions & 10 deletions pyshacl/rules/triple/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
class TripleRule(SHACLRule):
__slots__ = ("s", "p", "o")

def __init__(self, shape: 'Shape', rule_node: 'rdflib.term.Identifier'):
def __init__(self, shape: 'Shape', rule_node: 'rdflib.term.Identifier', **kwargs):
"""
:param shape:
:type shape: Shape
:param rule_node:
:type rule_node: rdflib.term.Identifier
"""
super(TripleRule, self).__init__(shape, rule_node)
super(TripleRule, self).__init__(shape, rule_node, **kwargs)
my_subject_nodes = set(self.shape.sg.objects(self.node, SH_subject))
if len(my_subject_nodes) < 1:
raise RuntimeError("No sh:subject")
Expand Down Expand Up @@ -183,13 +183,36 @@ def get_nodes_from_node_expression(
else:
raise NotImplementedError("Unsupported expression s, p, or o, in SHACL TripleRule")

def apply(self, data_graph):
def apply(self, data_graph: 'GraphLike') -> int:
focus_nodes = self.shape.focus_nodes(data_graph) # uses target nodes to find focus nodes
applicable_nodes = self.filter_conditions(focus_nodes, data_graph)
for a in applicable_nodes:
s_set = self.get_nodes_from_node_expression(self.s, a, data_graph)
p_set = self.get_nodes_from_node_expression(self.p, a, data_graph)
o_set = self.get_nodes_from_node_expression(self.o, a, data_graph)
new_triples = itertools.product(s_set, p_set, o_set)
for i in iter(new_triples):
data_graph.add(i)
all_added = 0
iterate_limit = 100
while True:
if iterate_limit < 1:
raise ReportableRuntimeError("sh:rule iteration exceeded iteration limit of 100.")
iterate_limit -= 1
added = 0
to_add = []
for a in applicable_nodes:
s_set = self.get_nodes_from_node_expression(self.s, a, data_graph)
p_set = self.get_nodes_from_node_expression(self.p, a, data_graph)
o_set = self.get_nodes_from_node_expression(self.o, a, data_graph)
new_triples = itertools.product(s_set, p_set, o_set)
this_added = False
for i in iter(new_triples):
if not this_added and i not in data_graph:
this_added = True
to_add.append(i)
if this_added:
added += 1
if added > 0:
for i in to_add:
data_graph.add(i)
all_added += added
if self.iterate:
continue # Jump up to iterate
else:
break # Don't iterate
break
return all_added
12 changes: 9 additions & 3 deletions pyshacl/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _load_default_options(cls, options_dict: dict):
options_dict.setdefault('inference', 'none')
options_dict.setdefault('inplace', False)
options_dict.setdefault('use_js', False)
options_dict.setdefault('iterate_rules', False)
options_dict.setdefault('abort_on_error', False)
if 'logger' not in options_dict:
options_dict['logger'] = logging.getLogger(__name__)
Expand Down Expand Up @@ -221,10 +222,13 @@ def run(self):
self._target_graph = the_target_graph

shapes = self.shacl_graph.shapes # This property getter triggers shapes harvest.

iterate_rules = self.options.get("iterate_rules", False)
if self.options['advanced']:
target_types = gather_target_types(self.shacl_graph)
advanced = {'functions': gather_functions(self.shacl_graph), 'rules': gather_rules(self.shacl_graph)}
advanced = {
'functions': gather_functions(self.shacl_graph),
'rules': gather_rules(self.shacl_graph, iterate_rules=iterate_rules),
}
for s in shapes:
s.set_advanced(True)
apply_target_types(target_types)
Expand All @@ -245,7 +249,7 @@ def run(self):
for g in named_graphs:
if advanced:
apply_functions(advanced['functions'], g)
apply_rules(advanced['rules'], g)
apply_rules(advanced['rules'], g, iterate=iterate_rules)
for s in shapes:
_is_conform, _reports = s.validate(g)
non_conformant = non_conformant or (not _is_conform)
Expand Down Expand Up @@ -372,6 +376,7 @@ def validate(
)
rdflib_bool_unpatch()
use_js = kwargs.pop('js', None)
iterate_rules = kwargs.pop('iterate_rules', False)
validator = None
try:
validator = Validator(
Expand All @@ -383,6 +388,7 @@ def validate(
'inplace': inplace,
'abort_on_error': abort_on_error,
'advanced': advanced,
'iterate_rules': iterate_rules,
'use_js': use_js,
'logger': log,
},
Expand Down
Loading

0 comments on commit 4166ee1

Please sign in to comment.